# 0. Prepare the dependencies

In [None]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

In [None]:
# 1. install pytorch
  ## Please follow the instructions at https://pytorch.org/get-started/locally/

# 2. install torch-geometric
!pip install torch-geometric

In [1]:
from tqdm import tqdm
import torch

# 1. Prepare data structure and configurations

In [2]:
from data_structures.tree import SyntaxTreeNode
from dataset.tree_dataset import TreeDataset

# 2. Prepare Dataset

In [3]:
from dataset.samplers import ProportionalWeightedRandomSamplerBuilder
from dataset.dataset_splitter import RatioBasedDataSetSplitter
from dataset.data_loader_builder import default_supervised_collate_fn, DataLoaderBuilder

In [1]:
# Number of label classes (this notebook assumes normal, seizure, pre-epileptic).
n_classes = 3

In [None]:
# Base path (relative) of the serialized tree records.
tree_records_base_path = "../data/serialized_tree"
# Record categories to load.
# NOTE(review): presumably one category per class label -- confirm the
# category-to-label mapping against TreeDataset.
dataset_types = ["normal-retained", "seizure", "pre-epileptic"]
dataset = TreeDataset(dataset_types, tree_records_base_path)

In [None]:
# Splitter over the full dataset. No ratios are passed here, so whatever
# defaults RatioBasedDataSetSplitter defines are used.
dataset_splitter = RatioBasedDataSetSplitter(dataset)

In [None]:
# Perform the split. The result is read below through the keys "train_set",
# "val_set", "test_set" and "train_set_indexes".
splitted_dataset = dataset_splitter.split_dataset()

In [None]:
# Pull the three subsets out of the split result in one assignment.
train_subset, val_subset, test_subset = (
    splitted_dataset["train_set"],
    splitted_dataset["val_set"],
    splitted_dataset["test_set"],
)

In [None]:
# Labels of the training samples only, in the order of the training indexes;
# they are used to build the training sampler.
train_labels = [dataset.labels[i] for i in splitted_dataset["train_set_indexes"]]

In [None]:
# Build the training sampler (and its weights) from the training labels.
sampler_builder = ProportionalWeightedRandomSamplerBuilder()
# Pass the notebook-level n_classes constant rather than repeating the literal
# 3, so the class count is defined in one place.
sampler, weights = sampler_builder.build(
    n_classes=n_classes, labels=train_labels, return_weights=True
)

In [None]:
# Notebook display of the sampling weights and the sampler built above.
weights, sampler

In [None]:
# A function that tells PyTorch how to assemble a batch from individual samples.
# It is needed when the dataset returns a data structure that PyTorch cannot collate by default, e.g., the TreeNode above.
def collate_fn(batch):
    """Collate a list of dataset items into one batch dictionary.

    Each item must provide a "tree" entry and a "labels" tensor entry. The
    trees are gathered into a plain Python list, while the label tensors are
    stacked into a single tensor along a new leading dimension.
    """
    trees = []
    label_tensors = []
    for sample in batch:
        trees.append(sample["tree"])
        label_tensors.append(sample["labels"])

    return {"tree": trees, "labels": torch.stack(label_tensors)}

In [None]:
# Build the train/val/test loaders. Only the training loader gets the sampler;
# all three use the custom collate function defined in this notebook.
data_loader_builder = DataLoaderBuilder()
data_loaders = data_loader_builder.build(
    train_subset,
    val_subset,
    test_subset,
    train_sampler=sampler,
    batch_size=32,
    collate_fn_train=collate_fn,
    collate_fn_val=collate_fn,
    collate_fn_test=collate_fn,
)

In [None]:
# Pull the three loaders out of the builder result in one assignment.
train_loader, val_loader, test_loader = (
    data_loaders["train_loader"],
    data_loaders["val_loader"],
    data_loaders["test_loader"],
)

In [None]:
# Function to print the count of each label in the dataset
def print_label_counts(loader, dataset_type="train"):
    """Print how many samples of each class label a data loader yields.

    Args:
        loader: Iterable of batches. Each batch must provide a "labels" entry
            whose elements expose ``.item()`` (e.g. scalar tensors).
        dataset_type: Name of the split being counted; shown in the progress
            bar description.

    Classes 0, 1 and 2 are always reported, even with zero samples. Any other
    label value found in the data is counted and reported as well instead of
    raising ``KeyError``.
    """
    # Seed the expected classes so they are printed even when absent
    # (assuming normal=0, seizure=1, pre-epileptic=2).
    label_counts = {0: 0, 1: 0, 2: 0}

    # Iterate over every batch in the loader and tally each label.
    for sample in tqdm(loader, desc=f"Counting {dataset_type} labels"):
        for label in sample["labels"]:
            key = label.item()
            label_counts[key] = label_counts.get(key, 0) + 1

    for label, count in label_counts.items():
        print(f"Class {label}: {count} samples")

# # Print label counts for train, validation, and test loaders
# print_label_counts(train_loader, dataset_type="train")
# print_label_counts(val_loader, dataset_type="validation")
# print_label_counts(test_loader, dataset_type="test")

# 3. Prepare Model

In [None]:
from models.tree_lstm import SeizurePredictionInputEmbeddingPreprocessor, BinaryTreeLSTMCell, BinaryTreeLSTM
from utils.utils import calculate_metrics

In [None]:
import torch.optim as optim
import torch
from torch import nn
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import numpy as np

# Hyperparameters
# The embedding widths are named once so that the model input size and the
# embedding preprocessor below cannot drift apart.
symbol_embedding_size = 32
grammar_embedding_size = 64
# Size of the node value tuple.
# NOTE(review): assumes each node contributes three symbol embeddings, one
# grammar embedding and one extra scalar feature -- confirm against
# SeizurePredictionInputEmbeddingPreprocessor.
input_size = symbol_embedding_size * 3 + grammar_embedding_size + 1
hidden_size = 64
num_classes = 3  # Normal, seizure, pre-epileptic
learning_rate = 0.001
num_epochs = 10

# Initialize model, loss function, and optimizer
embedding_model = SeizurePredictionInputEmbeddingPreprocessor(
    unique_symbols=96,
    symbol_embedding_size=symbol_embedding_size,
    unique_grammar=182,
    grammar_embedding_size=grammar_embedding_size,
)
model = BinaryTreeLSTM(input_size, hidden_size, num_classes, input_embedding_model=embedding_model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [None]:
def forwarding(batch):
    """Run the model on one collated batch and compute its loss.

    Uses the notebook-level ``model`` and ``criterion``.

    Args:
        batch: Mapping with a "tree" entry (the model input) and a "labels"
            entry (the targets for the loss).

    Returns:
        Tuple ``(logits, labels, loss)``.
    """
    trees = batch["tree"]
    labels = batch["labels"]
    # The model also returns node and edge outputs. They are not needed for the
    # loss, and printing them on every batch floods the notebook output.
    logits, _nodes, _edges = model(trees)
    loss = criterion(logits, labels)
    return logits, labels, loss

# When True, predictions and labels are collected every epoch and summarized
# through a confusion matrix and calculate_metrics.
enable_summary_confusion_matrix = True

# Training loop
for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    total_loss = 0  # Sum of the per-batch training losses for this epoch.
    if enable_summary_confusion_matrix:
        all_preds = []
        all_labels = []

    for batch in tqdm(train_loader):
        # Forward pass
        logits, labels, loss = forwarding(batch)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if enable_summary_confusion_matrix:
            # Store predictions and labels for metrics
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    print(f"Training Loss: {total_loss:.4f}")

    if enable_summary_confusion_matrix:
        # Calculate and report training metrics once per epoch
        train_conf_matrix = confusion_matrix(all_labels, all_preds)
        train_metrics = calculate_metrics(train_conf_matrix)
        print(f"Training Metrics: {train_metrics}")

    # ---- Validation ----
    model.eval()
    val_loss = 0  # Sum of the per-batch validation losses for this epoch.
    if enable_summary_confusion_matrix:
        val_preds = []
        val_labels = []

    with torch.no_grad():
        for batch in tqdm(val_loader):
            trees = batch["tree"]
            labels = batch["labels"]

            # Forward pass. The model returns (logits, nodes, edges); only the
            # logits are passed to the loss.
            logits, _, _ = model(trees)
            loss = criterion(logits, labels)

            val_loss += loss.item()

            if enable_summary_confusion_matrix:
                preds = torch.argmax(logits, dim=1)
                val_preds.extend(preds.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

    print(f"Validation Loss: {val_loss:.4f}")

    if enable_summary_confusion_matrix:
        val_conf_matrix = confusion_matrix(val_labels, val_preds)
        val_metrics = calculate_metrics(val_conf_matrix)
        print(f"Validation Metrics: {val_metrics}")

    # Epoch summary with the accumulated losses, not just the last batch's loss.
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {total_loss:.4f}, Val Loss: {val_loss:.4f}")