In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch_geometric
import networkx as nx
import metis
from scipy import sparse as sp
import time
from tqdm import tqdm
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, Batch
from torch_geometric.datasets import MNISTSuperpixels
from torch_sparse import SparseTensor
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
import os
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Report the accelerator in use; torch.cuda.get_device_name(0) raises when no CUDA device exists.
print(torch.cuda.get_device_name(0) if device.type == 'cuda' else 'CPU')
def LapPE(edge_index, pos_enc_dim, num_nodes):
    """
    Graph positional encoding using Laplacian eigenvectors.

    Builds the symmetric normalized Laplacian L = I - D^{-1/2} A D^{-1/2} from a
    dense 0/1 adjacency A and returns eigenvectors 1..pos_enc_dim (the first,
    smallest-eigenvalue eigenvector is skipped), zero-padded on the right when
    the graph has too few nodes.

    Args:
        edge_index (torch.Tensor): (2, E) edge indices of the graph. Duplicate
            edges collapse into a single adjacency entry.
        pos_enc_dim (int): Number of positional encoding dimensions.
        num_nodes (int): Number of nodes in the graph.

    Returns:
        torch.Tensor: Positional encodings of shape (num_nodes, pos_enc_dim).

    Note:
        Eigenvectors are only defined up to sign, so the encoding is not unique.
    """
    device = edge_index.device

    adj = torch.zeros((num_nodes, num_nodes), device=device)
    adj[edge_index[0], edge_index[1]] = 1.0

    # Degree is taken from the same deduplicated adjacency that gets normalized,
    # so repeated edges in edge_index cannot skew the normalization.
    degree = adj.sum(dim=1)

    # Isolated nodes get degree 1 to avoid division by zero.
    d_inv_sqrt = degree.clamp(min=1).pow(-0.5)
    # Same as diag(d) @ adj @ diag(d) without materializing the diagonal matrices.
    lap = torch.eye(num_nodes, device=device) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]

    # eigh returns eigenvalues in ascending order, so no extra sort is needed.
    _, eigvecs = torch.linalg.eigh(lap)

    pos_enc = eigvecs[:, 1:pos_enc_dim + 1]

    # Zero-pad if fewer eigenvectors are available (num_nodes - 1 < pos_enc_dim).
    if pos_enc.size(1) < pos_enc_dim:
        padding = pos_enc.new_zeros((num_nodes, pos_enc_dim - pos_enc.size(1)))
        pos_enc = torch.cat([pos_enc, padding], dim=1)
    return pos_enc
    
def k_hop_subgraph(edge_index, num_nodes, num_hops, is_directed=False):
    """
    Dense k-hop neighbourhood masks for every node of a graph.

    Args:
        edge_index (torch.Tensor): (2, E) edge indices (row = source, col = target).
        num_nodes (int): Number of nodes N.
        num_hops (int): Neighbourhood radius k; 0 yields the identity mask.
        is_directed (bool): If True, every edge is mirrored first so that
            neighbourhoods are computed on the undirected version of the graph.

    Returns:
        torch.Tensor: (N, N) bool mask. Entry [i, j] is True iff node j reaches
        node i by a walk of at most ``num_hops`` edges (following edge direction).
        For undirected / symmetrized graphs this is simply "j is within
        ``num_hops`` hops of i". Every node is in its own neighbourhood.
    """
    row, col = edge_index
    if is_directed:
        # Symmetrize the edge set actually used to build the adjacency.
        row, col = torch.cat([row, col]), torch.cat([col, row])

    device = edge_index.device
    adj = torch.zeros((num_nodes, num_nodes), device=device)
    adj[row, col] = 1.0

    # `reach[r, j]` marks walks of exactly i steps from r to j; `node_mask` accumulates i <= num_hops.
    reach = torch.eye(num_nodes, dtype=torch.bool, device=device)
    node_mask = reach.clone()
    for _ in range(num_hops):
        reach = (adj @ reach.float()) > 0
        node_mask |= reach

    # Transpose so row i lists the nodes whose walks end at i.
    return node_mask.T.contiguous()
    
# Function for METIS-based graph partitioning into subgraphs (patches)
def metis_subgraph(g, n_patches, drop_rate=0.0, num_hops=1, is_directed=False):
    """
    Partition a graph into ``n_patches`` (possibly overlapping) patches.

    Nodes are assigned to patches with ``metis.part_graph``. In the undirected
    case a random subset of edges is dropped first as data augmentation. Each
    patch is then expanded with the ``num_hops``-hop neighbourhood of its nodes.

    Args:
        g (torch_geometric.data.Data): Graph providing ``edge_index`` and ``num_nodes``.
        n_patches (int): Number of patches P.
        drop_rate (float): Probability of dropping each edge before partitioning
            (undirected case only). 0 keeps every edge.
        num_hops (int): Hops of neighbourhood expansion per patch (0 = none).
        is_directed (bool): Treat the graph as directed.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``node_mask`` of shape (P, N) and
        ``edge_mask`` of shape (P, E), both bool. ``edge_mask[p, e]`` is True iff
        both endpoints of edge ``e`` belong to patch ``p``.
    """
    device = g.edge_index.device

    if g.num_nodes < n_patches:
        # Too few nodes for METIS: give every node its own patch id.
        membership = torch.arange(g.num_nodes) if is_directed else torch.randperm(n_patches)
    elif is_directed:
        G = torch_geometric.utils.to_networkx(g, to_undirected="lower")
        _, membership = metis.part_graph(G, n_patches, recursive=True)
    else:
        # Augmentation: keep each edge independently with probability 1 - drop_rate
        # (">=" so drop_rate=0 really keeps every edge; torch.rand can return 0.0).
        adjlist = g.edge_index.t().cpu()
        selected = torch.rand(adjlist.size(0)) >= drop_rate
        G = nx.Graph()
        G.add_nodes_from(np.arange(g.num_nodes))
        G.add_edges_from(adjlist[selected].tolist())
        _, membership = metis.part_graph(G, n_patches, recursive=True)

    assert len(membership) >= g.num_nodes
    membership = torch.as_tensor(membership[:g.num_nodes], dtype=torch.long, device=device)
    # Shift ids so the largest used id becomes n_patches - 1.
    membership = membership + (n_patches - (membership.max() + 1))

    # node_mask[p, v] is True iff node v is assigned to patch p.
    node_mask = membership.unsqueeze(0) == torch.arange(n_patches, device=device).unsqueeze(1)

    if num_hops > 0:
        k_hop_node_mask = k_hop_subgraph(g.edge_index, g.num_nodes, num_hops, is_directed)
        # Patch p absorbs the k-hop neighbourhood of each of its nodes. The k-hop mask
        # contains its diagonal, so the original members stay in the patch.
        node_mask = (node_mask.float() @ k_hop_node_mask.float()) > 0

    # Mask for edges whose both endpoints are within each patch.
    edge_mask = node_mask[:, g.edge_index[0]] & node_mask[:, g.edge_index[1]]
    return node_mask, edge_mask


NVIDIA GeForce RTX 4080 SUPER


In [3]:

# GCN Model with Patch-Based Message Passing and Positional Encoding Injection
class PatchGCN(nn.Module):
    """
    Patch encoder for the Graph MLP-Mixer.

    For every graph in a batch it:
      1. splits the graph into ``n_patches`` overlapping patches with ``metis_subgraph``;
      2. adds ``T_node(node Laplacian PE)`` to ``U_node(node features)``;
      3. runs a 4-layer GCN inside each patch and mean-pools it;
      4. adds ``T_patch(patch Laplacian PE)`` to ``U_patch(projected patch vector)``.

    ``forward`` returns a (num_graphs * n_patches, hidden_channels) tensor laid out
    graph by graph, with exactly ``n_patches`` rows per graph.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, pe_dim, patch_pe_dim, n_patches, num_hops=1, drop_rate=0.1):
        super(PatchGCN, self).__init__()
        if out_channels != hidden_channels:
            # conv4's output is added residually to hidden_channels-wide patch features
            # and then fed to a hidden_channels -> hidden_channels projection.
            raise ValueError(
                f"out_channels ({out_channels}) must equal hidden_channels ({hidden_channels})"
            )
        # conv1 consumes the PE-injected node features, which already live in hidden_channels.
        self.conv1 = GCNConv(hidden_channels, hidden_channels*2)
        self.conv2 = GCNConv(hidden_channels*2, hidden_channels*2)
        self.conv3 = GCNConv(hidden_channels*2, hidden_channels)
        self.conv4 = GCNConv(hidden_channels, out_channels)
        self.T_node = nn.Linear(pe_dim, hidden_channels, bias=False)  # T^0: node PE
        self.U_node = nn.Linear(in_channels, hidden_channels, bias=True)  # U^0: node features
        self.T_patch = nn.Linear(patch_pe_dim, hidden_channels, bias=False)  # T^hat: patch PE
        self.U_patch = nn.Linear(hidden_channels, hidden_channels, bias=True)  # U^hat: patch embedding
        self.patch_embedding_projection = nn.Linear(hidden_channels, hidden_channels, bias=True)
        self.n_patches = n_patches
        self.num_hops = num_hops
        self.drop_rate = drop_rate
        self.patch_pe_dim = patch_pe_dim
        self.pe_dim = pe_dim
        self.gelu = nn.GELU()

    def process_graph(self, graph):
        """
        Partition one graph and compute its positional encodings.

        Returns:
            tuple: ``(node_pos_enc (N, pe_dim), patch_pos_enc (P, patch_pe_dim),
            node_mask (P, N) bool, edge_mask (P, E) bool)``.
        """
        # Edge dropping is a training-time augmentation; disable it in eval mode so
        # validation/test partitions are not randomly perturbed.
        drop_rate = self.drop_rate if self.training else 0.0
        node_mask, edge_mask = metis_subgraph(graph, n_patches=self.n_patches, drop_rate=drop_rate, num_hops=self.num_hops)

        node_pos_enc = LapPE(graph.edge_index, pos_enc_dim=self.pe_dim, num_nodes=graph.num_nodes)

        # Two patches are adjacent iff they share at least one node (diagonal included).
        patch_adj = node_mask.float() @ node_mask.float().T
        patch_edge_index = torch.nonzero(patch_adj, as_tuple=False).T
        patch_pos_enc = LapPE(patch_edge_index, pos_enc_dim=self.patch_pe_dim, num_nodes=self.n_patches)
        return node_pos_enc, patch_pos_enc, node_mask, edge_mask

    def _encode_patch(self, node_features, patch_edge_index, node_idx):
        """
        Run the GCN stack on one patch's induced subgraph and mean-pool it.

        Args:
            node_features (torch.Tensor): (N, hidden_channels) features of the whole graph.
            patch_edge_index (torch.Tensor): (2, E_p) patch edges in global node ids.
            node_idx (torch.Tensor): (N_p,) global ids of the patch's nodes.

        Returns:
            torch.Tensor: (hidden_channels,) pooled vector; zeros for an empty patch.
        """
        if node_idx.numel() == 0:
            # METIS can leave a partition empty; a zero vector keeps the per-graph
            # patch count at exactly n_patches.
            return node_features.new_zeros(node_features.size(-1))

        patch_x = node_features[node_idx]

        # Vectorized global -> local id remap. edge_mask only selects edges whose both
        # endpoints are in the patch, so no -1 entry is ever looked up.
        local_id = torch.full((node_features.size(0),), -1, dtype=torch.long, device=node_features.device)
        local_id[node_idx] = torch.arange(node_idx.numel(), device=node_features.device)
        edge_index = local_id[patch_edge_index]

        # GCNConv adds self-loops by default, so a patch with zero edges is still encoded.
        x = self.gelu(self.conv1(patch_x, edge_index))
        x = self.gelu(self.conv2(x, edge_index))
        x = self.gelu(self.conv3(x, edge_index))
        x = self.gelu(self.conv4(x, edge_index)) + patch_x
        return x.mean(dim=0)

    def forward(self, batch):
        """
        Encode every graph of ``batch`` into ``n_patches`` patch embeddings.

        Returns:
            torch.Tensor: (batch.num_graphs * n_patches, hidden_channels).
        """
        all_patch_embeddings = []
        for graph in batch.to_data_list():
            node_pos_enc, patch_pos_enc, node_mask, edge_mask = self.process_graph(graph)

            # Inject node positional encodings.
            node_features = self.T_node(node_pos_enc) + self.U_node(graph.x)

            patch_embeddings = []
            for i in range(self.n_patches):
                node_idx = node_mask[i].nonzero(as_tuple=True)[0]
                pooled = self._encode_patch(node_features, graph.edge_index[:, edge_mask[i]], node_idx)
                patch_embedding = self.patch_embedding_projection(pooled)
                # Inject patch positional encodings.
                patch_embeddings.append(self.T_patch(patch_pos_enc[i]) + self.U_patch(patch_embedding))

            all_patch_embeddings.append(torch.stack(patch_embeddings, dim=0))

        return torch.cat(all_patch_embeddings, dim=0)


In [4]:
class MixerLayer(nn.Module):
    """
    One MLP-Mixer block over patch tokens.

    A token-mixing MLP runs across the patch axis, followed by a channel-mixing
    MLP across the feature axis. Each is applied as ``x + MLP(LayerNorm(x))``.
    One LayerNorm module (``self.layernorm``) is shared by both sub-blocks.
    """

    def __init__(self, n_patches, hidden_channels, ds, dc):
        """
        Args:
            n_patches (int): Number of patches (P).
            hidden_channels (int): Hidden dimensionality (d).
            ds (int): Intermediate width of the token-mixing MLP (d_s).
            dc (int): Intermediate width of the channel-mixing MLP (d_c).
        """
        super().__init__()
        self.layernorm = nn.LayerNorm(hidden_channels)

        # Token mixer: operates on the patch axis, P -> d_s -> P.
        self.token_mixer_w1 = nn.Linear(n_patches, ds)
        self.token_mixer_w2 = nn.Linear(ds, n_patches)

        # Channel mixer: operates on the feature axis, d -> d_c -> d.
        self.channel_mixer_w3 = nn.Linear(hidden_channels, dc)
        self.channel_mixer_w4 = nn.Linear(dc, hidden_channels)

        self.gelu = nn.GELU()

    def _mix_tokens(self, x):
        """Token-mixing MLP on a (B, P, d) tensor; returns a (B, P, d) update."""
        per_channel = self.layernorm(x).transpose(1, 2)  # (B, d, P)
        mixed = self.token_mixer_w2(self.gelu(self.token_mixer_w1(per_channel)))
        return mixed.transpose(1, 2)  # back to (B, P, d)

    def _mix_channels(self, x):
        """Channel-mixing MLP on a (B, P, d) tensor; returns a (B, P, d) update."""
        normed = self.layernorm(x)
        return self.channel_mixer_w4(self.gelu(self.channel_mixer_w3(normed)))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (batch_size, n_patches, hidden_channels).

        Returns:
            torch.Tensor: Tensor of the same shape as ``x``.
        """
        tokens_mixed = x + self._mix_tokens(x)
        return tokens_mixed + self._mix_channels(tokens_mixed)


In [5]:
class GraphMixer(nn.Module):
    """
    Graph MLP-Mixer classifier.

    Pipeline: patch tokens from ``patch_gcn`` -> stack of ``MixerLayer``s ->
    mean over patches -> linear classifier.
    """

    def __init__(self, patch_gcn, n_patches, hidden_channels, ds, dc, num_classes, num_mixer_layers):
        """
        Args:
            patch_gcn (PatchGCN): Patch encoder producing the patch tokens.
            n_patches (int): Number of patches (P).
            hidden_channels (int): Hidden dimensionality (d).
            ds (int): Intermediate width of token mixing (d_s).
            dc (int): Intermediate width of channel mixing (d_c).
            num_classes (int): Number of output classes.
            num_mixer_layers (int): Number of Mixer Layers.
        """
        super().__init__()
        self.patch_gcn = patch_gcn
        blocks = [MixerLayer(n_patches, hidden_channels, ds, dc) for _ in range(num_mixer_layers)]
        self.mixer_layers = nn.ModuleList(blocks)
        self.classifier = nn.Linear(hidden_channels, num_classes)

    def forward(self, batch):
        """
        Args:
            batch (torch_geometric.data.Batch): Input batch of graphs.

        Returns:
            torch.Tensor: Logits of shape (batch_size, num_classes).
        """
        tokens = self.patch_gcn(batch)  # (num_graphs * P, d), laid out graph by graph
        width = tokens.size(-1)
        tokens = tokens.view(batch.num_graphs, -1, width)  # (num_graphs, P, d)

        for block in self.mixer_layers:
            tokens = block(tokens)

        # Mean-pool over patches, then classify.
        return self.classifier(tokens.mean(dim=1))




In [6]:
# MNIST superpixel graphs; the train split is further divided into train/val below.
trainset = MNISTSuperpixels(root='data/MNIST', train=True)
testset = MNISTSuperpixels(root='data/MNIST', train=False)

In [7]:
# Reproducibility: seeds torch (CPU and all CUDA devices) and numpy, covering weight
# init, DataLoader shuffling and the edge-dropping augmentation. METIS's own internal
# randomness is not controlled here.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Hyperparameters
batch_size = 36
num_epochs = 100
learning_rate = 0.001
weight_decay = 1e-5

# PatchGCN Parameters
in_channels = trainset.num_node_features
hidden_channels = 128
out_channels = hidden_channels  # PatchGCN adds conv4's output residually to hidden-size features
pe_dim = 16
n_patches = 10
patch_pe_dim = n_patches - 1  # patch graph has n_patches nodes -> n_patches - 1 non-trivial eigenvectors
num_hops = 1
drop_rate = 0.1

# Instantiate PatchGCN
patch_gcn = PatchGCN(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    pe_dim=pe_dim,
    patch_pe_dim=patch_pe_dim,
    n_patches=n_patches,
    num_hops=num_hops,
    drop_rate=drop_rate,
).to(device)

# GraphMixer Parameters
ds = 128  # Intermediate token mixing dimension
dc = 256  # Intermediate channel mixing dimension
num_classes = 10  # MNIST classification (digits 0-9)
num_mixer_layers = 4  # Number of Mixer Layers

# Instantiate GraphMixer
model = GraphMixer(
    patch_gcn=patch_gcn,
    n_patches=n_patches,
    hidden_channels=hidden_channels,
    ds=ds,
    dc=dc,
    num_classes=num_classes,
    num_mixer_layers=num_mixer_layers,
).to(device)


In [8]:
# A fixed random_state keeps the train/val split identical across re-runs. Without it,
# re-executing this cell could move graphs the model already trained on into validation.
train_dataset, val_dataset = train_test_split(trainset, test_size=0.1, random_state=42)
test_dataset = testset

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [9]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Halve the LR after 5 epochs without validation-loss improvement. The `verbose` flag is
# deprecated in recent PyTorch; read the current LR from optimizer.param_groups instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)



In [10]:
train_losses = []
val_losses = []


def train():
    """Run one training epoch over ``train_loader``; returns the mean batch loss."""
    model.train()
    total_loss = 0.0
    batch_times = []  # per-batch wall time in seconds
    progress_bar = tqdm(train_loader, desc="Training", leave=False)
    for data in progress_bar:
        start_time = time.time()

        data = data.to(device)
        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, data.y)
        loss.backward()
        optimizer.step()

        # .item() waits for the device to finish this step, so the timing includes the
        # actual GPU work rather than only the asynchronous kernel launches.
        batch_loss = loss.item()
        batch_times.append(time.time() - start_time)

        total_loss += batch_loss
        progress_bar.set_postfix({"Loss": batch_loss})

    avg_batch_time = sum(batch_times) / max(len(batch_times), 1)
    print(f"Average Batch Time: {avg_batch_time:.4f} seconds, Total Batches: {len(train_loader)}")
    return total_loss / max(len(train_loader), 1)


@torch.no_grad()  # evaluation needs no autograd graph; saves memory and time
def _evaluate(loader, desc):
    """Evaluate ``model`` on ``loader``; returns (mean batch loss, accuracy)."""
    model.eval()
    total_loss = 0.0
    correct = 0
    num_graphs = 0
    progress_bar = tqdm(loader, desc=desc, leave=False)
    for data in progress_bar:
        start_time = time.time()

        data = data.to(device)
        logits = model(data)
        total_loss += criterion(logits, data.y).item()
        correct += logits.argmax(dim=1).eq(data.y).sum().item()
        num_graphs += data.num_graphs

        progress_bar.set_postfix({"Batch Time (s)": f"{time.time() - start_time:.4f}"})

    return total_loss / max(len(loader), 1), correct / max(num_graphs, 1)


def validate():
    """Mean loss and accuracy on the validation split."""
    return _evaluate(val_loader, "Validating")


def test():
    """Accuracy on the test split."""
    return _evaluate(test_loader, "Testing")[1]


# Training loop with timing and tqdm
for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}/{num_epochs}")
    start_epoch_time = time.time()

    train_loss = train()
    val_loss, val_accuracy = validate()
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    test_accuracy = test()
    scheduler.step(val_loss)  # Update learning rate based on validation loss

    epoch_time = time.time() - start_epoch_time

    print(
        f"Epoch {epoch}, "
        f"Train Loss: {train_loss:.4f}, "
        f"Val Loss: {val_loss:.4f}, "
        f"Val Accuracy: {val_accuracy:.4f}, "
        f"Test Accuracy: {test_accuracy:.4f}, "
        f"LR: {optimizer.param_groups[0]['lr']:.2e}, "
        f"Epoch Time: {epoch_time:.2f} seconds"
    )


Epoch 1/100


                                                                        

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt  # not imported by the notebook's import cell

# Plot only the epochs that actually completed; training may have been interrupted
# before num_epochs, in which case range(1, num_epochs + 1) would not match the lists.
n_done = min(len(train_losses), len(val_losses))
epochs_run = range(1, n_done + 1)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs_run, train_losses[:n_done], label='Training Loss', marker='o')
ax.plot(epochs_run, val_losses[:n_done], label='Validation Loss', marker='o')
ax.set_title('Training vs Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.grid()
plt.show()