In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, GATConv, TransformerConv
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.loader import DataLoader
import math
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl
import pandas as pd

In [60]:
# 1. Positional Encoding for Graph Nodes: GraphPositionalEncoding
class GraphPositionalEncoding(nn.Module):
    """
    Positional encodings for spatiotemporal graph nodes (temporal + spatial).

    Each node's encoding is the concatenation of two learnable lookup tables:
    a temporal table indexed by ``time_ids`` (initialised with sinusoids) and a
    spatial table indexed by ``node_ids`` (initialised with small Gaussian
    noise). ``forward`` only returns the encoding; the caller adds it to the
    node features.

    The temporal part has ``d_model // 2`` channels and the spatial part the
    remaining ``d_model - d_model // 2``. The output therefore always has
    exactly ``d_model`` channels, including for odd ``d_model``.
    """
    def __init__(self, d_model, max_nodes=1000, max_time=500):
        super().__init__()
        self.d_model = d_model

        temporal_dim = d_model // 2
        spatial_dim = d_model - temporal_dim  # equals temporal_dim when d_model is even

        # Temporal positional encoding (sinusoidal init, then learned)
        self.temporal_pe = nn.Parameter(torch.zeros(max_time, temporal_dim))

        # Spatial positional encoding (learnable)
        self.spatial_pe = nn.Parameter(torch.zeros(max_nodes, spatial_dim))

        self._init_positional_encoding()

    def _init_positional_encoding(self):
        """Fill temporal_pe with sin/cos features and spatial_pe with N(0, 0.02^2) noise."""
        max_time, dim = self.temporal_pe.shape

        if dim > 0:  # d_model < 2 leaves no temporal channels (and would divide by zero)
            position = torch.arange(0, max_time, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, dim, 2).float() *
                                 (-math.log(10000.0) / dim))
            angles = position * div_term  # (max_time, ceil(dim / 2))

            with torch.no_grad():
                self.temporal_pe[:, 0::2] = torch.sin(angles)
                # With an odd dim there is one fewer cosine column than sine columns
                self.temporal_pe[:, 1::2] = torch.cos(angles[:, : dim // 2])

        # Initialize spatial randomly (will be learned)
        nn.init.normal_(self.spatial_pe, mean=0, std=0.02)

    def forward(self, x, node_ids, time_ids):
        """
        Look up positional encodings for a set of nodes.

        Args:
            x: (num_nodes, feature_dim) node features. Not read here; kept in the
               signature for the caller, which adds the result to x itself.
            node_ids: (num_nodes,) integer spatial ids, each < max_nodes
            time_ids: (num_nodes,) integer timestep ids, each < max_time

        Returns:
            (num_nodes, d_model) tensor laid out as [temporal | spatial].
        """
        temporal_enc = self.temporal_pe[time_ids]  # (num_nodes, d_model // 2)
        spatial_enc = self.spatial_pe[node_ids]    # (num_nodes, d_model - d_model // 2)
        return torch.cat([temporal_enc, spatial_enc], dim=-1)


In [61]:
# 2. Graph Transformer Encoder Layer: GraphTransformerLayer
class GraphTransformerLayer(nn.Module):
    """
    One encoder layer combining local graph attention with global self-attention.

    Sub-layers, each followed by a residual connection and LayerNorm (post-norm):
      1. GATConv over ``edge_index`` (attention restricted to graph neighbours);
      2. nn.MultiheadAttention over all nodes of the same graph in the batch
         (dense within a graph; nodes never attend to other graphs);
      3. a position-wise feed-forward network with GELU.

    Args:
        d_model: node embedding width (input and output).
        nhead: attention heads, used by both GATConv and MultiheadAttention.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout rate for attention and residual branches.
        edge_dim: optional width of ``edge_attr``, forwarded to GATConv only when
            given. NOTE(review): with the default None, confirm whether the
            installed GATConv uses ``edge_attr`` at all.
    """
    def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1,
                 edge_dim=None):
        super().__init__()

        # Graph convolution for local structure; edge_dim is only passed when set,
        # so the default construction is exactly as before
        gat_kwargs = {} if edge_dim is None else {'edge_dim': edge_dim}
        self.graph_conv = GATConv(d_model, d_model, heads=nhead,
                                  concat=False, dropout=dropout, **gat_kwargs)

        # Self-attention among the nodes of each graph for global dependencies
        self.self_attn = nn.MultiheadAttention(d_model, nhead,
                                               dropout=dropout, batch_first=True)

        # Feedforward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Layer normalization (one per sub-layer)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # Dropout on each residual branch
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    @staticmethod
    def _pad_by_graph(x, batch, num_graphs):
        """
        Scatter node rows into a zero-padded per-graph tensor for MultiheadAttention.

        Args:
            x: (num_nodes, d) node features.
            batch: (num_nodes,) graph id of every node, in [0, num_graphs).
            num_graphs: number of graphs (rows of the padded tensor).

        Returns:
            padded: (num_graphs, max_nodes, d), same dtype and device as x.
            padding_mask: (num_graphs, max_nodes) bool, True at padding slots
                (the key_padding_mask convention of nn.MultiheadAttention).
            slot: (num_nodes,) position of each node inside its graph's row, so
                ``padded[batch, slot]`` recovers x. Nodes keep their original
                relative order, i.e. row i starts with ``x[batch == i]``.
        """
        batch = batch.long()
        num_nodes = batch.numel()
        node_index = torch.arange(num_nodes, device=batch.device)

        # Nodes per graph, and the offset of each graph in graph-major order
        counts = torch.zeros(num_graphs, dtype=torch.long, device=batch.device)
        counts.index_add_(0, batch, torch.ones(num_nodes, dtype=torch.long, device=batch.device))
        starts = counts.cumsum(0) - counts

        # Graph-major node order. The keys are unique, so nodes of the same graph
        # stay in their original order without relying on a stable sort.
        order = torch.argsort(batch * num_nodes + node_index)
        slot = torch.empty_like(batch)
        slot[order] = node_index - starts[batch[order]]

        max_nodes = int(counts.max()) if num_graphs > 0 else 0
        padded = x.new_zeros((num_graphs, max_nodes, x.size(1)))
        padded[batch, slot] = x
        padding_mask = torch.ones(num_graphs, max_nodes, dtype=torch.bool, device=x.device)
        padding_mask[batch, slot] = False
        return padded, padding_mask, slot

    def forward(self, x, edge_index, batch, edge_attr=None):
        """
        Args:
            x: (num_nodes, d_model) - node features
            edge_index: (2, num_edges) - graph connectivity
            batch: (num_nodes,) - graph id of each node; None means a single graph
            edge_attr: (num_edges, edge_dim) - optional edge features, passed to GATConv

        Returns:
            (num_nodes, d_model) updated node features.
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # 1) Local structure: graph attention, residual + norm
        x = self.norm1(x + self.dropout1(self.graph_conv(x, edge_index, edge_attr)))

        # 2) Global context: dense attention restricted to nodes of the same graph
        num_graphs = int(batch.max()) + 1
        padded, padding_mask, slot = self._pad_by_graph(x, batch, num_graphs)
        attn_out, _ = self.self_attn(padded, padded, padded,
                                     key_padding_mask=padding_mask,
                                     need_weights=False)
        # Gather each node's attention output back into the flat node layout
        x = self.norm2(x + self.dropout2(attn_out[batch.long(), slot]))

        # 3) Position-wise feedforward, residual + norm
        ff = self.linear2(self.dropout(F.gelu(self.linear1(x))))
        x = self.norm3(x + self.dropout3(ff))

        return x
        

In [62]:
# 3. Spatiotemporal Graph Transformer
class SpatiotemporalGraphTransformer(nn.Module):
    """
    Main model: embeds spatiotemporal graphs using a transformer architecture.

    Pipeline: input projection -> optional positional encoding from
    ``node_ids`` / ``time_ids`` -> ``num_layers`` GraphTransformerLayer blocks ->
    graph-level pooling -> optional MLP classifier.

    Args:
        node_feature_dim: width of the raw node features ``data.x``.
        d_model: hidden width used throughout the encoder.
        nhead: attention heads per GraphTransformerLayer.
        num_layers: number of GraphTransformerLayer blocks.
        dim_feedforward: hidden width of each layer's feed-forward network.
        dropout: dropout rate (the classifier head uses ``1.5 * dropout``).
        num_classes: if given, ``forward`` returns logits of this width;
            otherwise it returns pooled graph embeddings.
        max_nodes, max_time: sizes of the spatial / temporal positional tables.
        use_layer_norm, use_dropout: toggle LayerNorm / Dropout in the input projection.
        pooling_type: 'attention' (default; learned softmax weights per graph),
            'mean', 'max', or 'add'.

    Raises:
        ValueError: if ``pooling_type`` is not one of the supported options.
    """
    _POOLING_TYPES = ('mean', 'max', 'add', 'attention')

    def __init__(self, node_feature_dim, d_model=256, nhead=8,
                 num_layers=6, dim_feedforward=1024, dropout=0.1,
                 num_classes=None, max_nodes=1000, max_time=500,
                 use_layer_norm=True, use_dropout=True,
                 pooling_type='attention'):
        super().__init__()

        # Validate before building any sub-modules
        if pooling_type not in self._POOLING_TYPES:
            raise ValueError(f"Unknown pooling type: {pooling_type}")

        self.d_model = d_model
        self.num_classes = num_classes
        self.use_dropout = use_dropout

        # Input projection with dropout
        self.input_proj = nn.Sequential(
            nn.Linear(node_feature_dim, d_model),
            nn.LayerNorm(d_model) if use_layer_norm else nn.Identity(),
            nn.Dropout(dropout) if use_dropout else nn.Identity()
        )

        # Positional encoding
        self.pos_encoder = GraphPositionalEncoding(d_model, max_nodes, max_time)

        # Transformer layers
        self.layers = nn.ModuleList([
            GraphTransformerLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])

        # Pooling for the graph-level representation
        self.pooling_type = pooling_type

        if self.pooling_type == 'attention':
            # Scores each node; scores are softmax-normalised per graph in _pool_graph
            self.attention_pool = nn.Sequential(
                nn.Linear(d_model, d_model // 2),
                nn.Tanh(),
                nn.Dropout(dropout),
                nn.Linear(d_model // 2, 1)
            )

        # Classification head (if num_classes provided) with stronger regularization
        if num_classes is not None:
            self.classifier = nn.Sequential(
                nn.Linear(d_model, d_model // 2),
                nn.LayerNorm(d_model // 2),
                nn.ReLU(),
                nn.Dropout(dropout * 1.5),  # Stronger dropout
                nn.Linear(d_model // 2, d_model // 4),
                nn.LayerNorm(d_model // 4),
                nn.ReLU(),
                nn.Dropout(dropout * 1.5),
                nn.Linear(d_model // 4, num_classes)
            )

    def forward(self, data, return_embeddings=False):
        """
        Args:
            data: PyG Data/Batch object with:
                - x: (num_nodes, node_feature_dim)
                - edge_index: (2, num_edges)
                - batch: (num_nodes,) - optional; missing/None means a single graph
                - node_ids: (num_nodes,) - spatial identifiers (optional)
                - time_ids: (num_nodes,) - temporal identifiers (optional)
                - edge_attr: (num_edges, edge_dim) - optional
            return_embeddings: if True, skip the classifier and return pooled embeddings.

        Returns:
            logits: (batch_size, num_classes) or embeddings: (batch_size, d_model)
        """
        x, edge_index = data.x, data.edge_index
        edge_attr = getattr(data, 'edge_attr', None)
        batch = getattr(data, 'batch', None)
        if batch is None:
            # No batch vector (e.g. an un-batched Data object): all nodes form graph 0
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # Input projection
        x = self.input_proj(x)  # (num_nodes, d_model)

        # Add positional encoding when both id vectors are present
        node_ids = getattr(data, 'node_ids', None)
        time_ids = getattr(data, 'time_ids', None)
        if node_ids is not None and time_ids is not None:
            x = x + self.pos_encoder(x, node_ids, time_ids)

        # Pass through transformer layers
        for layer in self.layers:
            x = layer(x, edge_index, batch, edge_attr)

        # Graph-level pooling
        graph_embeddings = self._pool_graph(x, batch)

        if return_embeddings or self.num_classes is None:
            return graph_embeddings

        # Classification
        return self.classifier(graph_embeddings)

    def _pool_graph(self, x, batch):
        """
        Pool node features (num_nodes, d_model) into (num_graphs, d_model).

        ``batch`` gives the graph id of every node; None means one single graph.
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        if self.pooling_type == 'mean':
            return global_mean_pool(x, batch)

        if self.pooling_type == 'max':
            return global_max_pool(x, batch)

        if self.pooling_type == 'add':
            return global_add_pool(x, batch)

        if self.pooling_type == 'attention':
            scores = self.attention_pool(x)  # (num_nodes, 1)
            num_graphs = int(batch.max()) + 1

            pooled = []
            for graph_id in range(num_graphs):
                in_graph = batch == graph_id
                # Softmax over this graph's nodes only, then weighted sum of their features
                weights = F.softmax(scores[in_graph], dim=0)
                pooled.append((x[in_graph] * weights).sum(dim=0))

            return torch.stack(pooled)

        raise ValueError(f"Unknown pooling type: {self.pooling_type}")


In [6]:
# 5. Training Pipeline: train_epoch
def train_epoch(model, dataloader, optimizer, criterion, device, epoch=None,
                max_grad_norm=1.0):
    """
    Train ``model`` for one pass over ``dataloader``.

    Args:
        model: module mapping a batch to logits of shape (batch_size, num_classes).
        dataloader: iterable of batches exposing ``.to(device)`` and integer labels ``.y``.
        optimizer: optimizer over the model's parameters.
        criterion: loss callable taking (logits, targets) - presumably nn.CrossEntropyLoss.
        device: torch.device or device string ('cpu', 'cuda', 'mps', ...).
        epoch: optional epoch number shown in the progress bar (epoch 0 is shown too).
        max_grad_norm: max total gradient norm passed to clip_grad_norm_.

    Returns:
        (avg_loss, accuracy): mean of the per-batch losses and the fraction of
        correctly classified samples. Both are NaN if ``dataloader`` yields nothing.
    """
    device = torch.device(device)  # accept strings as well as torch.device
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    # Create progress bar
    desc = f'Epoch {epoch} [Train]' if epoch is not None else 'Training'
    pbar = tqdm(dataloader, desc=desc, leave=True, ncols=100)

    try:
        for batch in pbar:
            batch = batch.to(device)
            # Flatten labels to (batch_size,). squeeze() would turn a one-sample
            # batch into a 0-d tensor, which cross-entropy-style losses reject.
            targets = batch.y.reshape(-1)

            # Forward pass
            logits = model(batch)
            loss = criterion(logits, targets)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping (important for stability on MPS)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            # Metrics - .item() pulls scalars to the CPU for accumulation
            total_loss += loss.item()
            correct += (logits.argmax(dim=1) == targets).sum().item()
            total += targets.numel()
            num_batches += 1

            # Update progress bar
            running_acc = correct / total if total else float('nan')
            pbar.set_postfix({
                'loss': f'{total_loss / num_batches:.4f}',
                'acc': f'{running_acc:.4f}'
            })

            # Clear the MPS cache every 100 batches to avoid memory build-up
            if device.type == 'mps' and num_batches % 100 == 0:
                torch.mps.empty_cache()
    finally:
        pbar.close()

    if num_batches == 0:
        return float('nan'), float('nan')

    avg_loss = total_loss / num_batches
    accuracy = correct / total if total else float('nan')
    return avg_loss, accuracy
# model evaluator: evaluate
@torch.no_grad()
def evaluate(model, dataloader, criterion, device, epoch=None, split='Val'):
    """
    Evaluate ``model`` on one pass over ``dataloader`` without tracking gradients.

    Args:
        model: module mapping a batch to logits of shape (batch_size, num_classes).
        dataloader: iterable of batches exposing ``.to(device)`` and integer labels ``.y``.
        criterion: loss callable taking (logits, targets) - presumably nn.CrossEntropyLoss.
        device: torch.device or device string ('cpu', 'cuda', 'mps', ...).
        epoch: optional epoch number shown in the progress bar (epoch 0 is shown too).
        split: label for the progress bar, e.g. 'Val' or 'Test'.

    Returns:
        (avg_loss, accuracy): mean of the per-batch losses and the fraction of
        correctly classified samples. Both are NaN if ``dataloader`` yields nothing.
    """
    device = torch.device(device)  # accept strings as well as torch.device
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    # Create progress bar
    desc = f'Epoch {epoch} [{split}]' if epoch is not None else f'{split}'
    pbar = tqdm(dataloader, desc=desc, leave=True, ncols=100)

    try:
        for batch in pbar:
            batch = batch.to(device)
            # Flatten labels to (batch_size,); squeeze() breaks one-sample batches
            targets = batch.y.reshape(-1)

            logits = model(batch)
            loss = criterion(logits, targets)

            # Move metrics to CPU
            total_loss += loss.item()
            correct += (logits.argmax(dim=1) == targets).sum().item()
            total += targets.numel()
            num_batches += 1

            # Update progress bar
            running_acc = correct / total if total else float('nan')
            pbar.set_postfix({
                'loss': f'{total_loss / num_batches:.4f}',
                'acc': f'{running_acc:.4f}'
            })

            # Clear cache on MPS after every batch
            if device.type == 'mps':
                torch.mps.empty_cache()
    finally:
        pbar.close()

    if num_batches == 0:
        return float('nan'), float('nan')

    avg_loss = total_loss / num_batches
    accuracy = correct / total if total else float('nan')
    return avg_loss, accuracy



In [9]:
# 6. Complete Training Pipeline: create_example_data
def build_spatiotemporal_edge_index(num_nodes, num_timesteps):
    """
    Build the edge_index shared by every example graph.

    Node ``t * num_nodes + i`` is spatial node ``i`` at timestep ``t``. Edges:
      - spatial ring within each timestep: i -> i+1 and i -> i-1 (mod num_nodes);
      - temporal links between node i at t and at t+1, in both directions.

    Returns:
        (2, num_edges) long tensor; shape (2, 0) when there are no edges.
    """
    edges = []

    # Spatial edges (within each timestep)
    for t in range(num_timesteps):
        offset = t * num_nodes
        for i in range(num_nodes):
            for j in ((i + 1) % num_nodes, (i - 1) % num_nodes):
                edges.append([offset + i, offset + j])

    # Temporal edges (across timesteps), bidirectional
    for t in range(num_timesteps - 1):
        for i in range(num_nodes):
            curr_idx = t * num_nodes + i
            next_idx = curr_idx + num_nodes
            edges.append([curr_idx, next_idx])
            edges.append([next_idx, curr_idx])

    # reshape keeps a (2, 0) shape for the empty case instead of (0,)
    return torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()


def create_example_data(num_graphs=1000, num_nodes=50, num_timesteps=10,
                        node_feature_dim=32, num_classes=5, seed=None):
    """
    Create random example spatiotemporal graph data.

    Each graph stacks ``num_timesteps`` snapshots of ``num_nodes`` nodes. All
    graphs share the same connectivity (see build_spatiotemporal_edge_index) and
    differ only in their random Gaussian node features.

    Args:
        num_graphs: number of graphs to generate.
        num_nodes: spatial nodes per timestep.
        num_timesteps: number of temporal snapshots per graph.
        node_feature_dim: width of the random node features.
        num_classes: labels are drawn uniformly from [0, num_classes).
        seed: optional int. If given, features and labels come from a private
            torch.Generator seeded with it; otherwise the global RNG is used.

    Returns:
        (graphs, labels): a list of Data objects with x, edge_index, node_ids and
        time_ids, and a parallel list of int labels (random, so unrelated to x).
    """
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    # Structure is identical for every graph, so build it once
    total_nodes = num_nodes * num_timesteps
    edge_index = build_spatiotemporal_edge_index(num_nodes, num_timesteps)
    node_ids = torch.arange(num_nodes).repeat(num_timesteps)
    time_ids = torch.arange(num_timesteps).repeat_interleave(num_nodes)

    graphs = []
    labels = []
    for _ in range(num_graphs):
        # Node features (num_nodes * num_timesteps, feature_dim)
        x = torch.randn(total_nodes, node_feature_dim, generator=generator)

        # Clone shared tensors so graphs never alias each other's storage
        graphs.append(Data(x=x, edge_index=edge_index.clone(),
                           node_ids=node_ids.clone(), time_ids=time_ids.clone()))
        labels.append(torch.randint(0, num_classes, (1,), generator=generator).item())

    return graphs, labels


In [10]:
# Sanity check: build a single example graph; the cell displays the returned (graphs, labels) tuple
create_example_data(num_graphs=1)

torch.Size([500, 32])


([Data(x=[500, 32], edge_index=[2, 1900], node_ids=[500], time_ids=[500])],
 [0])

In [57]:
# Load Features
# NOTE(review): absolute, user-specific paths - this cell only runs on the original
# author's machine. pd.read_pickle / allow_pickle=True unpickle the files, so only
# point them at trusted data.
basals = pd.read_pickle(
    '/Users/xies/Library/CloudStorage/OneDrive-Stanford/Skin/Mesa et al/Lineage models/Dataset pickles/basals.pkl'
)

# Split each TrackID index value on '_' into a region prefix and the remainder
# (presumably '<region>_<track>', e.g. 'R1_...'; TODO confirm format)
basals['Region', 'Meta'] = [
    track.split('_')[0] for track in basals.index.get_level_values('TrackID').values
]
basals['TrackID', 'Meta'] = [
    track.split('_')[1] for track in basals.index.get_level_values('TrackID').values
]

# Split R1 / R2
R1 = basals.loc[basals['Region', 'Meta'] == 'R1']
R2 = basals.loc[basals['Region', 'Meta'] == 'R2']

# Load graph: one unpickled object per time point t = 0..14
# (presumably {trackID: neighbours} adjacency dicts - confirm)
adjDicts = [
    np.load(
        f'/Users/xies/Library/CloudStorage/OneDrive-Stanford/Skin/Mesa et al/W-R1/Mastodon/basal_connectivity_3d/adjacenct_trackIDs_t{t}.npy',
        allow_pickle=True,
    ).item()
    for t in range(15)
]


In [59]:
# Convert each time point to a spatial graph

import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
import numpy as np

# Load the t0 adjacency file as a Python object (.item() unwraps the saved object array).
# NOTE(review): machine-specific absolute path; allow_pickle=True unpickles the file,
# so only load trusted data. Not used by the code visible in this notebook so far.
adj_dict = np.load('/Users/allisonlam/Downloads/adjacenct_trackIDs_t0.npy', allow_pickle=True).item()

def adjacency_to_edge_index(adj_dict):
    """
    Relabel an adjacency dict's node ids to contiguous indices 0..N-1 and build edge_index.

    Node order: the dict's keys in iteration order, followed by any id that only
    appears as a neighbour (in first-seen order). Ids may be any hashable values
    (ints, strings, ...).

    Args:
        adj_dict (dict): {node_id: iterable of neighbour ids}.

    Returns:
        (node_ids, edge_index): ``node_ids[k]`` is the original id of node k, and
        ``edge_index`` is a (2, num_edges) long tensor with one column per
        (src, dst) pair in ``adj_dict``, expressed in the new indices.
    """
    index_of = {}
    for node in adj_dict:
        index_of.setdefault(node, len(index_of))
    for neighbors in adj_dict.values():
        for node in neighbors:
            index_of.setdefault(node, len(index_of))

    pairs = [(index_of[src], index_of[dst])
             for src, neighbors in adj_dict.items()
             for dst in neighbors]
    # reshape keeps a (2, 0) shape when there are no edges
    edge_index = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
    return list(index_of), edge_index


def dict_to_graph(adj_dict, undirected=True, node_features=None):
    """
    Convert a dictionary of adjacency lists into a PyTorch Geometric graph.

    Node ids are relabelled to contiguous indices (see adjacency_to_edge_index)
    because edge_index must index rows of the node set. Raw ids such as track
    ids need not be 0..N-1 and may not be integers at all.

    Args:
        adj_dict (dict): {node_id: [neighbor_ids, ...]} mapping.
        undirected (bool): If True, add reverse edges (via to_undirected).
        node_features (torch.Tensor or None): Optional tensor of shape
            [num_nodes, num_features]. Row k must describe the node with original
            id ``node_ids[k]``: the dict keys in order, then neighbour-only ids.

    Returns:
        torch_geometric.data.Data: graph with ``edge_index`` over the relabelled
        nodes, ``num_nodes`` set explicitly, and ``x`` equal to ``node_features``
        if given, otherwise the list of original node ids in index order.

    Raises:
        ValueError: if the dictionary has no edges, or ``node_features`` has a
            row count different from the number of nodes.
    """
    node_ids, edge_index = adjacency_to_edge_index(adj_dict)

    if edge_index.size(1) == 0:
        raise ValueError("The adjacency dictionary has no edges.")

    num_nodes = len(node_ids)
    if node_features is not None and len(node_features) != num_nodes:
        raise ValueError(
            f"node_features has {len(node_features)} rows but the graph has {num_nodes} nodes."
        )

    # Make graph undirected if requested
    if undirected:
        edge_index = to_undirected(edge_index, num_nodes=num_nodes)

    # Without features, x keeps the original ids, now aligned with the edge indices
    x = node_ids if node_features is None else node_features
    return Data(x=x, edge_index=edge_index, num_nodes=num_nodes)


In [None]:
# Connect each time point 