# 0. Prepare the dependencies

In [284]:
# 1. install pytorch
  ## Please follow instruction in https://pytorch.org/get-started/locally/

# 2. install torch-geometric
!pip install torch-geometric



In [12]:
import numpy as np
import os
import sys
import torch
from torch.utils.data import Dataset, DataLoader, DistributedSampler, WeightedRandomSampler
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
import torch.nn as nn
from functools import lru_cache

# 1. Prepare data structures and configuration

In [2]:
# Root directory of the serialized trees. TreeDataset joins this path with one
# sub-directory name per class and lists the files found there.
tree_records_base_path = "../data/serialized_tree"

In [3]:
class SyntaxTreeNode:
    """Node of a binary syntax tree.

    Args:
        value: Payload of the node (a tuple of 6 values when produced by
            ``deserialize_tree``).
        left: Left child node, or None.
        right: Right child node, or None.
    """

    def __init__(self, value, left=None, right=None):
        self.value = value  # Tuple of 6 values
        self.left = left
        self.right = right

    @property
    def children(self):
        """Return ``[left, right]``; a missing child appears as None."""
        return [self.left, self.right]

    def __repr__(self):
        return f"SyntaxTreeNode(value={self.value}, left={self.left}, right={self.right})"


def deserialize_tree(tree_str):
    """Rebuild a binary tree from its whitespace-separated pre-order serialization.

    Every node is written as six tokens (int, int, int, int, float, int) followed
    by its left subtree and then its right subtree. A missing child is written
    as ``#``; running out of tokens is treated like a missing child.

    The tree is rebuilt with an explicit stack instead of recursion, so the depth
    of the tree is not limited by Python's recursion limit.

    Args:
        tree_str: Serialized tree.

    Returns:
        The root ``SyntaxTreeNode``, or None when the input is empty or starts
        with ``#``.

    Raises:
        ValueError: If a node field cannot be parsed as a number.
        IndexError: If the input ends in the middle of a six-token node record.
    """
    tokens = tree_str.split()
    token_count = len(tokens)
    index = 0

    def _read_node():
        """Consume one node record (or one '#') and return the childless node or None."""
        nonlocal index
        if index >= token_count:
            return None

        if tokens[index] == "#":
            index += 1
            return None

        # Extract the 6 values for the current node
        value = (
            int(tokens[index]),        # std::get<0>
            int(tokens[index + 1]),    # std::get<1>
            int(tokens[index + 2]),    # std::get<2>
            int(tokens[index + 3]),    # std::get<3>
            float(tokens[index + 4]),  # std::get<4>
            int(tokens[index + 5])     # std::get<5>
        )
        index += 6
        return SyntaxTreeNode(value)

    root = _read_node()
    if root is None:
        return None

    # Each frame is [node, number_of_children_already_read]. The node on top of
    # the stack is always the one whose next child comes next in the token
    # stream, which reproduces the pre-order (left subtree first) layout.
    pending = [[root, 0]]
    while pending:
        frame = pending[-1]
        node, children_read = frame
        if children_read == 2:
            pending.pop()
            continue

        frame[1] = children_read + 1
        child = _read_node()
        if children_read == 0:
            node.left = child
        else:
            node.right = child

        if child is not None:
            pending.append([child, 0])

    return root


In [4]:
def tree_to_graph(root):
    """
    Convert a SyntaxTreeNode tree to a PyTorch Geometric graph.

    Nodes are numbered in pre-order (a node, then its left subtree, then its
    right subtree) and one directed edge parent -> child is emitted per child.
    The traversal uses an explicit stack, so tree depth is not bounded by the
    recursion limit.

    Per-node feature row (float): [value[4], id(value[0]), id(value[1]),
    id(value[2]), id(value[5])], where id(v) = v + 1 and the result 65536
    is remapped to 0. value[3] is not used.

    Args:
        root: Root SyntaxTreeNode whose ``value`` has at least 6 entries.

    Returns:
        Data with ``x`` of shape [num_nodes, 5] and ``edge_index`` of shape
        [2, num_nodes - 1] (shape [2, 0] for a single-node tree).

    Raises:
        ValueError: If ``root`` is None (an empty tree has no graph).
    """
    if root is None:
        raise ValueError("Cannot convert an empty tree (root is None) to a graph")

    def _shift_id(raw):
        # Shift by one; the value that lands on 65536 is folded back to 0.
        shifted = int(raw) + 1
        return 0 if shifted == 65536 else shifted

    rows = []
    edges = []

    # Stack of (node, parent_id). Children are pushed right-first so the left
    # child is popped (and numbered) first, matching a recursive pre-order DFS.
    stack = [(root, None)]
    while stack:
        node, parent_id = stack.pop()

        # Assign a unique ID to the node
        node_id = len(rows)

        # Add edge from parent to current node
        if parent_id is not None:
            edges.append((parent_id, node_id))

        features = node.value
        rows.append([
            features[4],
            _shift_id(features[0]),
            _shift_id(features[1]),
            _shift_id(features[2]),
            _shift_id(features[5]),
        ])

        for child in reversed(node.children):
            if child is not None:
                stack.append((child, node_id))

    # reshape(-1, 2) keeps the [2, num_edges] layout even when there are no edges.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

    return Data(
        x=torch.tensor(rows, dtype=torch.float),  # Node feature matrix
        edge_index=edge_index  # Edge index
    )

In [5]:
# Example serialized tree string: two node records of six fields each, followed
# by '#' markers for the missing children (see deserialize_tree).
tree_str = "1 2 3 4 5.0 6 7 8 9 10 11.0 12 # # #"

# Deserialize the tree
root = deserialize_tree(tree_str)
print(root)
# Convert the tree to a graph
graph = tree_to_graph(root)

# Print the graph summary, then the node feature matrix and the edge list
print(graph)
print("Node features:", graph.x)
print("Edge index:", graph.edge_index)

SyntaxTreeNode(value=(1, 2, 3, 4, 5.0, 6), left=SyntaxTreeNode(value=(7, 8, 9, 10, 11.0, 12), left=None, right=None), right=None)
Data(x=[2, 5], edge_index=[2, 1])
Node features: tensor([[ 5.,  2.,  3.,  4.,  7.],
        [11.,  8.,  9., 10., 13.]])
Edge index: tensor([[0],
        [1]])


In [9]:
from torch_geometric.data import Batch

class TreeDataset(Dataset):
    """Dataset of serialized syntax trees stored as one file per sample.

    Files are read from ``tree_records_base_path/<area_type>/`` and labelled
    through ``label_map``. Directory listings are sorted so that the mapping
    from index to file is the same on every run.

    Args:
        area_types: Iterable of class directory names; each must be a key of
            ``label_map`` ('normal', 'seizure', 'pre-epileptic').
    """

    def __init__(self, area_types):
        super().__init__()
        files = []
        labels = []
        self.label_map = {
            'normal': 0,
            'seizure': 1,
            'pre-epileptic': 2
        }

        for area_type in area_types:
            dataset_base_path = os.path.join(tree_records_base_path, area_type)
            # os.listdir returns entries in arbitrary order; sort for a stable index.
            files_this_area_type = [
                os.path.join(dataset_base_path, file)
                for file in sorted(os.listdir(dataset_base_path))
            ]
            files += files_this_area_type
            labels += [self.label_map[area_type]] * len(files_this_area_type)
            print(f'Add {len(files_this_area_type)} for category {area_type}')

        self.files = files    # file path of every sample
        self.labels = labels  # integer class label, parallel to self.files
        assert(len(self.files) == len(self.labels))

    def __len__(self):
        """Return the number of sample files."""
        return len(self.files)

    # NOTE(review): the cache is unbounded and keyed on (self, idx), so every
    # graph that is loaded once stays in memory for the life of the process.
    @lru_cache(maxsize=None)
    def __getitem__(self, idx):
        """Load, parse and convert the tree stored in file number ``idx``.

        Returns:
            dict with "graph" (result of tree_to_graph) and "label"
            (0-dim long tensor).

        Raises:
            ValueError: If ``idx`` is not smaller than the number of files.
        """
        if(idx >= len(self.files)):
            raise ValueError(f'idx = {idx} >= total amount of files = {len(self.files)}')
        file = self.files[idx]
        with open(file, "r") as f:
            serialized_tree = f.read().strip()
        root = deserialize_tree(serialized_tree)
        graph = tree_to_graph(root)
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return {"graph": graph, "label":label}


# 2. Prepare Dataset

In [10]:
# Class directories to load; each name must be a key of TreeDataset.label_map.
dataset_types = ["normal", "seizure", "pre-epileptic"]

# Lists the files of every class directory and prints one count per class.
dataset = TreeDataset(dataset_types)


Add 565638 for category normal
Add 28084 for category seizure
Add 152313 for category pre-epileptic


In [13]:
from sklearn.model_selection import train_test_split
import torch
# Assuming 'dataset' is an instance of TreeDataset
# Fixed seed so that the stratified split is identical after a kernel restart.
SPLIT_SEED = 42

# 80% train / 20% held out, preserving the class proportions.
train_idx, temp_idx = train_test_split(
    list(range(len(dataset))),
    test_size=0.2,
    stratify=dataset.labels,
    random_state=SPLIT_SEED,
)

# Split the held-out 20% evenly into validation and test.
val_idx, test_idx = train_test_split(
    temp_idx,
    test_size=0.5,
    stratify=[dataset.labels[i] for i in temp_idx],  # temp_idx contains the original indices
    random_state=SPLIT_SEED,
)

# Create subsets
train_subset = torch.utils.data.Subset(dataset, train_idx)
val_subset = torch.utils.data.Subset(dataset, val_idx)
test_subset = torch.utils.data.Subset(dataset, test_idx)


# Sampling weights are derived from the training subset only (not the full dataset)
train_labels = [dataset.labels[i] for i in train_idx]

# Calculate class counts for the training set (one slot per entry of label_map)
class_counts = [0] * len(dataset.label_map)
for label in train_labels:
    class_counts[label] += 1

total_samples = len(train_labels)
# Inverse-frequency weights; a class with no training sample gets weight 0
# instead of raising ZeroDivisionError (no sample carries that weight anyway).
class_weights = [total_samples / count if count else 0.0 for count in class_counts]

# Calculate sample weights for the training set
sample_weights = [class_weights[label] for label in train_labels]

# Create the sampler for the training set (draws with replacement, one epoch = len(train_labels) draws)
train_sampler = WeightedRandomSampler(sample_weights, len(train_labels), replacement=True)


# Print sizes to confirm
print(f"Train size: {len(train_subset)}")
print(f"Validation size: {len(val_subset)}")
print(f"Test size: {len(test_subset)}")


Train size: 596828
Validation size: 74603
Test size: 74604


In [14]:
def custom_collate(batch):
    """
    Merge a list of dataset samples into one batched graph and one label tensor.

    Args:
        batch: List of dictionaries, each with a "graph" and a "label" entry.
    Returns:
        Tuple ``(batched_graphs, batched_labels)``: the graphs combined with
        ``Batch.from_data_list`` and the labels stacked into a single tensor.
    """
    graph_list = []
    label_list = []
    for sample in batch:
        graph_list.append(sample["graph"])
        label_list.append(sample["label"])

    # PyTorch Geometric merges the individual graphs; labels are stacked along dim 0
    return Batch.from_data_list(graph_list), torch.stack(label_list)
    
# NOTE(review): `DataLoader` is imported twice at the top of the notebook
# (torch.utils.data, then torch_geometric.loader); the later import wins, so
# these are PyG loaders. PyG's DataLoader appears to install its own collater
# and to ignore a user-supplied `collate_fn` -- confirm against the installed
# version. The training loop in this notebook reads batches as dicts
# (batch['graph'], batch['label']), not as the tuple custom_collate returns.

# Create DataLoader for the training set with WeightedRandomSampler
train_loader = DataLoader(
    train_subset,
    batch_size=32,
    sampler=train_sampler,
    collate_fn=custom_collate  # Use custom collate function
)

# Create DataLoader for the validation and test sets without any sampling (just shuffle them)
val_loader = DataLoader(
    val_subset,
    batch_size=32,
    shuffle=True,
    collate_fn=custom_collate  # Use custom collate function
)

test_loader = DataLoader(
    test_subset,
    batch_size=32,
    shuffle=True,
    collate_fn=custom_collate  # Use custom collate function
)


In [None]:
# Function to print the count of each label in the dataset
def print_label_counts(loader, dataset_type="train"):
    """Iterate once over ``loader`` and print how many samples each class has.

    Batches may be dicts with a "label" entry (what the training loop in this
    notebook reads from the loaders) or ``(graphs, labels)`` tuples (what
    ``custom_collate`` returns).

    Args:
        loader: Iterable of batches whose labels are an integer tensor with
            values 0, 1 or 2.
        dataset_type: Name printed in the header line.

    Raises:
        KeyError: If a label outside {0, 1, 2} is encountered.
    """
    # Imported here because the notebook imports tqdm only in a later cell.
    from tqdm import tqdm

    # Initialize label counts
    label_counts = {0: 0, 1: 0, 2: 0}  # Assuming 3 classes (normal=0, seizure=1, pre-epileptic=2)

    # Iterate over the dataset in the loader to count each label
    for batch in tqdm(loader):
        labels = batch["label"] if isinstance(batch, dict) else batch[1]
        for label in labels.tolist():
            label_counts[label] += 1

    print(f"{dataset_type} set label counts:")
    for label, count in label_counts.items():
        print(f"Class {label}: {count} samples")
    print()

# Print label counts for train, validation, and test loaders.
# Each call iterates its loader once from start to end.
print_label_counts(train_loader, dataset_type="train")
print_label_counts(val_loader, dataset_type="validation")
print_label_counts(test_loader, dataset_type="test")


# 3. Prepare Model

In [17]:
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool


In [52]:

class TreeGNN(torch.nn.Module):
    """Graph-level classifier: embeddings -> 2 x (GCNConv, BatchNorm, ReLU) -> mean pool -> linear.

    Input node features are rows ``[possibility, id0, id1, id2, id5]``: one
    continuous value followed by four integer ids stored as floats.

    Args:
        hidden_dim: Width of both GCN layers.
        num_classes: Number of output logits per graph.
        embedding_dim: Width of each of the four id embeddings.
        vocab_sizes: Number of embedding rows for (id0, id1, id2, id5); every
            id in the data must be smaller than the matching entry.
    """

    def __init__(self, hidden_dim=128, num_classes = 3, embedding_dim=32,
                 vocab_sizes=(96, 96, 96, 182)):
        super().__init__()
        size0, size1, size2, size5 = vocab_sizes

        # Define embedding layers
        self.dim0_embedding = nn.Embedding(num_embeddings=size0, embedding_dim=embedding_dim)
        self.dim1_embedding = nn.Embedding(num_embeddings=size1, embedding_dim=embedding_dim)
        self.dim2_embedding = nn.Embedding(num_embeddings=size2, embedding_dim=embedding_dim)
        self.dim5_embedding = nn.Embedding(num_embeddings=size5, embedding_dim=embedding_dim)

        # Define GNN layers
        # Input size: 1 continuous value + 4 embeddings (129 with the defaults)
        self.conv1 = GCNConv(1 + 4 * embedding_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)


    def forward(self, x, edge_index, batch):
        """Return class logits of shape [num_graphs, num_classes].

        Args:
            x: Node features, [num_nodes, 5], laid out as described on the class.
            edge_index: Edge list, [2, num_edges].
            batch: Graph index of every node, [num_nodes].
        """
        # Separate features
        possibility = x[:, 0].unsqueeze(1)  # Shape: [num_nodes, 1]
        dim0 = x[:, 1].long()
        dim1 = x[:, 2].long()
        dim2 = x[:, 3].long()
        dim5 = x[:, 4].long()

        # Embed categorical features
        dim0_embedded = self.dim0_embedding(dim0)  # Shape: [num_nodes, embedding_dim]
        dim1_embedded = self.dim1_embedding(dim1)  # Shape: [num_nodes, embedding_dim]
        dim2_embedded = self.dim2_embedding(dim2)  # Shape: [num_nodes, embedding_dim]
        dim5_embedded = self.dim5_embedding(dim5)  # Shape: [num_nodes, embedding_dim]

        # Concatenate all features
        x = torch.cat([possibility, dim0_embedded, dim1_embedded, dim2_embedded, dim5_embedded], dim=1)
        # Pass through GNN layers
        x = F.relu(self.bn1(self.conv1(x, edge_index)))  # Shape: [num_nodes, hidden_dim]
        x = F.relu(self.bn2(self.conv2(x, edge_index)))  # Shape: [num_nodes, hidden_dim]
        x = global_mean_pool(x, batch)  # Aggregate node features into graph-level features

        x = self.fc(x)
        return x


In [75]:
# Classifier instance used by the training loop: 128 hidden units, 3 output classes.
model = TreeGNN(hidden_dim=128,  num_classes = 3)

In [76]:
import torch
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import numpy as np

# Cross-entropy loss for classification, with per-class weights.
# NOTE(review): this rebinds `class_weights`, which the split cell defined as
# the list used to build the sampler weights.
class_weights = torch.tensor([1.0, 2.0, 1.0])  # Higher weight for class 1 (seizure)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

# Adam optimizer over all parameters of `model`.
optimizer = optim.Adam(model.parameters(), lr=0.001)
def calculate_metrics(conf_matrix):
    """
    Calculate accuracy, TPR, FPR, TNR, FNR, F1, and F2 from a confusion matrix.

    Rows are true classes and columns are predicted classes. All rates are
    computed one-vs-rest per class. A rate whose denominator is zero is
    reported as 0.0 (e.g. TPR of a class that has no samples).

    Args:
        conf_matrix: Square confusion matrix (3x3 for 3 classes).
    Returns:
        Dictionary with "accuracy" (scalar) and "TPR", "FPR", "TNR", "FNR",
        "F1", "F2" (one float per class).
    """
    conf_matrix = np.asarray(conf_matrix)

    def _safe_divide(numerator, denominator):
        # np.divide(..., where=mask) leaves masked entries uninitialized unless
        # an explicit `out` buffer is supplied, so start from zeros.
        numerator = np.asarray(numerator, dtype=float)
        denominator = np.asarray(denominator, dtype=float)
        return np.divide(
            numerator,
            denominator,
            out=np.zeros_like(numerator),
            where=denominator != 0,
        )

    metrics = {}

    # True Positives (diagonal of the confusion matrix)
    TP = np.diag(conf_matrix)

    # False Positives (sum of columns minus diagonal)
    FP = np.sum(conf_matrix, axis=0) - TP

    # False Negatives (sum of rows minus diagonal)
    FN = np.sum(conf_matrix, axis=1) - TP

    # True Negatives (total samples minus TP, FP, FN)
    total = np.sum(conf_matrix)
    TN = total - (TP + FP + FN)

    # Accuracy (0.0 for an empty matrix)
    metrics["accuracy"] = np.sum(TP) / total if total else 0.0

    # True Positive Rate (Recall)
    metrics["TPR"] = _safe_divide(TP, TP + FN)

    # False Positive Rate
    metrics["FPR"] = _safe_divide(FP, FP + TN)

    # True Negative Rate
    metrics["TNR"] = _safe_divide(TN, TN + FP)

    # False Negative Rate
    metrics["FNR"] = _safe_divide(FN, TP + FN)

    # Precision
    precision = _safe_divide(TP, TP + FP)
    recall = metrics["TPR"]

    # F1 Score
    metrics["F1"] = _safe_divide(2 * (precision * recall), precision + recall)

    # F2 Score
    metrics["F2"] = _safe_divide(5 * (precision * recall), 4 * precision + recall)

    return metrics

# Training loop: 100 epochs, each one full pass over train_loader followed by
# one full pass over val_loader. Metrics are printed per epoch; nothing is saved.
for epoch in range(100):
    model.train()
    total_loss = 0  # sum (not mean) of the per-batch losses of this epoch
    all_preds = []
    all_labels = []
    
    for batch in tqdm(train_loader):
        optimizer.zero_grad()

        # Forward pass: get predictions
        # Batches are read as dicts with a batched graph and a label tensor.
        graphs, labels = batch['graph'], batch['label']
        output = model(graphs.x, graphs.edge_index, graphs.batch)
        
        # Compute the loss
        loss = criterion(output, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        
        # Store predictions and labels for metrics
        preds = torch.argmax(output, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    # Calculate training metrics
    train_conf_matrix = confusion_matrix(all_labels, all_preds)
    train_metrics = calculate_metrics(train_conf_matrix)
    
    print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")
    print(f"Training Metrics: {train_metrics}")

    # Validation: same forward pass and loss, without gradients or weight updates
    model.eval()
    val_loss = 0  # sum of the per-batch validation losses
    val_preds = []
    val_labels = []
    
    with torch.no_grad():
        for batch in tqdm(val_loader):
            graphs, labels = batch['graph'], batch['label']
            output = model(graphs.x, graphs.edge_index, graphs.batch)
            loss = criterion(output, labels)
            val_loss += loss.item()
            
            # Store predictions and labels for metrics
            preds = torch.argmax(output, dim=1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    # Calculate validation metrics
    val_conf_matrix = confusion_matrix(val_labels, val_preds)
    val_metrics = calculate_metrics(val_conf_matrix)
    
    print(f"Validation Loss: {val_loss:.4f}")
    print(f"Validation Metrics: {val_metrics}")

  0%|                                        | 50/18651 [00:02<15:44, 19.69it/s]

KeyboardInterrupt



In [53]:


class TreeContrastiveModel(torch.nn.Module):
    """Two-layer GCN encoder followed by mean pooling into graph embeddings.

    Args:
        input_dim: Number of features per node.
        hidden_dim: Width of the first GCN layer.
        embedding_dim: Width of the output embedding.
    """

    def __init__(self, input_dim, hidden_dim, embedding_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, embedding_dim)

    def forward(self, x, edge_index, batch=None):
        """Encode the nodes and mean-pool them per graph.

        Args:
            x: Node features, [num_nodes, input_dim].
            edge_index: Edge list, [2, num_edges].
            batch: Optional graph index of every node, [num_nodes]. When
                omitted, all nodes are pooled into a single embedding.

        Returns:
            Tensor of shape [num_graphs, embedding_dim] ([1, embedding_dim]
            when ``batch`` is omitted).
        """
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        if batch is None:
            # Build the index on the same device as x to avoid a CPU/GPU mismatch.
            batch = x.new_zeros(x.size(0), dtype=torch.long)
        return global_mean_pool(x, batch)  # Global pooling

# Contrastive loss (NT-Xent)
def contrastive_loss(z1, z2, temperature=0.5):
    """Cross-entropy over cosine similarities between two batches of embeddings.

    Row i of ``z1`` is scored against every row of ``z2``; the matching row i
    is the target class and all other rows act as negatives.

    Args:
        z1: Embeddings of the first view, [batch, dim].
        z2: Embeddings of the second view, [batch, dim].
        temperature: Divisor applied to the similarities before the softmax.

    Returns:
        Scalar loss tensor.
    """
    anchors = F.normalize(z1, dim=1)
    candidates = F.normalize(z2, dim=1)

    # Unit-length rows, so the dot products are cosine similarities.
    similarity = (anchors @ candidates.T) / temperature

    # Positive pairs sit on the diagonal of the similarity matrix.
    targets = torch.arange(z1.size(0), device=z1.device)
    return F.cross_entropy(similarity, targets)

# Example usage (unfinished experiment).
# NOTE(review): `dataloader` and `augment_tree` are not defined anywhere in the
# visible notebook; the recorded output of this cell is a NameError for
# `augment_tree`.
# NOTE(review): `model` and `optimizer` rebind the names the classifier cells
# use for the TreeGNN and its optimizer.
# NOTE(review): input_dim=6, whereas tree_to_graph builds 5 features per node.
model = TreeContrastiveModel(input_dim=6, hidden_dim=16, embedding_dim=8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training loop
for epoch in range(100):
    for batch in dataloader:
        optimizer.zero_grad()
        # Create two views of the same tree (e.g., by random masking)
        view1 = batch  # Original graph
        view2 = augment_tree(batch)  # Augmented graph
        # Encode both views
        # No batch vector is passed, so each call pools all nodes of the
        # batch into one embedding row.
        z1 = model(view1.x, view1.edge_index)
        z2 = model(view2.x, view2.edge_index)
        # Compute contrastive loss
        loss = contrastive_loss(z1, z2)
        loss.backward()
        optimizer.step()
    # Reports the loss of the last batch of the epoch only.
    print(f"Epoch {epoch}, Loss: {loss.item()}")

NameError: name 'augment_tree' is not defined

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class VGAE(torch.nn.Module):
    """Variational graph auto-encoder with a GCN encoder and inner-product decoder.

    Args:
        input_dim: Number of features per node.
        hidden_dim: Width of the two shared GCN layers.
        embedding_dim: Width of the latent node embedding.
    """

    def __init__(self, input_dim, hidden_dim, embedding_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv_mu = GCNConv(hidden_dim, embedding_dim)
        self.conv_logvar = GCNConv(hidden_dim, embedding_dim)

    def encode(self, x, edge_index):
        """Return ``(mu, logvar)`` of the latent distribution of every node."""
        hidden = x
        for layer in (self.conv1, self.conv2):
            hidden = F.relu(layer(hidden, edge_index))
        mu = self.conv_mu(hidden, edge_index)
        logvar = self.conv_logvar(hidden, edge_index)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z, batch):
        """Return one sigmoid(z z^T) adjacency estimate per graph id in ``batch``."""
        reconstructions = []
        for graph_id in torch.unique(batch):
            latent = z[batch == graph_id]
            if latent.size(0) != 0:  # graphs without nodes contribute nothing
                reconstructions.append(torch.sigmoid(latent @ latent.t()))
        return reconstructions

    def forward(self, x, edge_index, batch):
        """Encode and sample; returns ``(z, mu, logvar)``. ``batch`` is not used here."""
        mu, logvar = self.encode(x, edge_index)
        return self.reparameterize(mu, logvar), mu, logvar

# Loss function
def vgae_loss(z, mu, logvar, edge_index, batch, kl_weight):
    """Reconstruction loss plus weighted KL divergence for a batch of graphs.

    For every graph id in ``batch`` the adjacency matrix is reconstructed as
    sigmoid(z z^T) over that graph's nodes (the same inner-product decoder as
    ``VGAE.decode``) and compared to the true directed adjacency with binary
    cross-entropy (mean over the matrix entries). The per-graph losses are
    summed.

    Args:
        z: Sampled node embeddings, [num_nodes, dim].
        mu: Latent means, [num_nodes, dim].
        logvar: Latent log-variances, [num_nodes, dim].
        edge_index: Edge list with batch-global node indices, [2, num_edges].
        batch: Graph id of every node, [num_nodes].
        kl_weight: Multiplier of the KL term (e.g. from ``kl_annealing``).

    Returns:
        Scalar tensor: ``recon_loss + kl_weight * kl_loss``.
    """
    recon_loss = 0

    # Graph id of the two end points of every edge (computed once).
    source_graph = batch[edge_index[0]]
    target_graph = batch[edge_index[1]]

    # Position of every node inside its own graph, filled in graph by graph.
    local_id = torch.zeros_like(batch)

    for graph_idx in torch.unique(batch):
        node_idx = (batch == graph_idx).nonzero(as_tuple=True)[0]
        num_nodes = node_idx.numel()
        local_id[node_idx] = torch.arange(num_nodes, device=batch.device)

        # Inner-product decoder restricted to the nodes of this graph
        z_graph = z[node_idx]
        adj_pred = torch.sigmoid(torch.matmul(z_graph, z_graph.t()))

        # Filter edges for the current graph and make their indices graph-local
        edge_mask = (source_graph == graph_idx) & (target_graph == graph_idx)
        edge_index_graph = local_id[edge_index[:, edge_mask]]

        # Create the true adjacency matrix (stays all-zero for a graph without edges)
        adj_true = torch.zeros_like(adj_pred)
        adj_true[edge_index_graph[0], edge_index_graph[1]] = 1

        # Compute reconstruction loss
        recon_loss = recon_loss + F.binary_cross_entropy(adj_pred, adj_true)

    # KL divergence between N(mu, exp(logvar)) and N(0, I), summed over all entries
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + kl_weight * kl_loss

def kl_annealing(epoch, max_epochs):
    """Linear KL warm-up: ``epoch / max_epochs``, capped at 1.0."""
    progress = epoch / max_epochs
    return 1.0 if 1.0 < progress else progress


# NOTE(review): input_dim=6, whereas tree_to_graph builds 5 features per node.
# NOTE(review): `model` and `optimizer` rebind the names used by earlier cells.
model = VGAE(input_dim=6, hidden_dim=256, embedding_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


# Training loop
# NOTE(review): `dataloader` is not defined anywhere in the visible notebook.
max_epochs = 100
for epoch in range(max_epochs):
    # KL weight for this epoch: epoch / max_epochs, so 0.0 in the first epoch.
    kl_weight = kl_annealing(epoch, max_epochs)
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        z, mu, logvar = model(batch.x, batch.edge_index, batch.batch)
        loss = vgae_loss(z, mu, logvar, batch.edge_index, batch.batch, kl_weight)
        loss.backward()
        optimizer.step()
    # Reports the loss of the last batch of the epoch only.
    print(f"Epoch {epoch}, Loss: {loss.item()}, KL Weight: {kl_weight}")