<a href="https://colab.research.google.com/github/shubhii0206/Adaptive-Graph-Pooling-for-Protein-Structure-Classification-Using-GCNs/blob/main/GCN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the third-party packages used by this notebook (notebook shell magics).
# NOTE(review): the import cell below also uses scipy, seaborn, tqdm and
# scikit-learn, none of which are installed here -- presumably they are already
# present in the runtime; confirm before running in a fresh environment.
!pip3 install numpy
!pip3 install torch
!pip3 install networkx
!pip3 install matplotlib
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m23.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
import os
import torch
import scipy
import seaborn
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, random_split
from torch_geometric.datasets import TUDataset
from sklearn.model_selection import train_test_split
from torch_geometric.utils import to_dense_adj
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, TopKPooling

In [None]:
# Implementation of GCN Model
class GCN(nn.Module):
    """Single graph-convolution layer (GCNConv) followed by a ReLU.

    Used as the basic GNN building block of the classifier.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply the convolution to node features ``x`` over ``edge_index``
        and return the ReLU-activated result."""
        return F.relu(self.conv(x, edge_index))

In [None]:
# Implementation of DownSamplePool Model
class DownSamplePool(nn.Module):
    """Down-sample & pool module built on ``TopKPooling``.

    Args:
        in_channels: Node feature size handed to ``TopKPooling``.
        out_channels: Accepted for signature symmetry with the other
            modules; it is not used by this layer.
        k: Passed to ``TopKPooling`` as its ``ratio`` argument.
    """

    def __init__(self, in_channels, out_channels, k):
        super().__init__()
        # Top-K pooling performs the adaptive node selection (gPool-style layer).
        self.pool = TopKPooling(in_channels, ratio=k)

    def forward(self, x, edge_index, batch):
        """Pool the graph and return ``(x, edge_index, batch)`` for the kept nodes.

        No edge attributes are supplied (``None``); of the six values the
        pooling layer returns, only the first, second and fourth are kept.
        """
        kept_x, kept_edges, _edge_attr, kept_batch, _perm, _score = self.pool(
            x, edge_index, None, batch
        )
        return kept_x, kept_edges, kept_batch


In [None]:
class Model(nn.Module):
    """Overall graph classifier.

    Pipeline: two GCN layers -> down-sample & pool -> two GCN layers ->
    down-sample & pool -> global mean readout -> linear head -> log-softmax.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of every GCN layer and of the readout vector.
        out_channels: Accepted but not used by this class.
        num_classes: Number of output classes of the linear head.
        k1: ``k`` forwarded to the first ``DownSamplePool``.
        k2: ``k`` forwarded to the second ``DownSamplePool``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_classes, k1, k2):
        super().__init__()

        # Stage 1: two GNN layers, then the first down-sample & pool.
        self.gnn1 = GCN(in_channels, hidden_channels)
        self.gnn2 = GCN(hidden_channels, hidden_channels)
        self.pool1 = DownSamplePool(hidden_channels, hidden_channels, k=k1)

        # Stage 2: two GNN layers, then the second down-sample & pool.
        self.gnn3 = GCN(hidden_channels, hidden_channels)
        self.gnn4 = GCN(hidden_channels, hidden_channels)
        self.pool2 = DownSamplePool(hidden_channels, hidden_channels, k=k2)

        # Classification head applied to the graph-level readout.
        self.fc = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Return per-graph log-probabilities (log-softmax over dim 1)."""
        stages = (
            (self.gnn1, self.gnn2, self.pool1),
            (self.gnn3, self.gnn4, self.pool2),
        )
        for first_conv, second_conv, pool in stages:
            x = first_conv(x, edge_index)
            x = second_conv(x, edge_index)
            # Pooling shrinks the graph, so edges and batch vector are replaced too.
            x, edge_index, batch = pool(x, edge_index, batch)

        # Graph-level readout followed by the linear classifier.
        graph_embedding = global_mean_pool(x, batch)
        logits = self.fc(graph_embedding)
        return F.log_softmax(logits, dim=1)

In [None]:
# Implementation of Additional Classes for loading dataset and preparing data loader (If Needed)
# Dataset loading and preparation
def load_dataset(name, split_ratio=(0.8, 0.1, 0.1), seed=None):
    """Load a TU dataset and randomly split it into train/val/test subsets.

    Args:
        name: TUDataset name; the data is stored under ``'/tmp/' + name``.
        split_ratio: Fractions for the splits. Only the first two entries
            (train, validation) are read; the test split receives every
            remaining sample so the three sizes always sum to the dataset
            length.
        seed: Optional integer seed. When given, the split is drawn from a
            dedicated ``torch.Generator`` so repeated calls yield the same
            partition. When ``None`` (default) the global RNG is used, as
            before.

    Returns:
        The three subsets produced by ``random_split``, in the order
        train, validation, test.

    Raises:
        ValueError: If the train/validation fractions are negative or sum
            to more than 1 (which would give a negative test size).
    """
    train_frac, val_frac = split_ratio[0], split_ratio[1]
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1:
        raise ValueError(
            f"train/val fractions must be non-negative and sum to at most 1, "
            f"got {train_frac} and {val_frac}"
        )

    dataset = TUDataset(root='/tmp/' + name, name=name)
    total = len(dataset)
    train_size = int(train_frac * total)
    val_size = int(val_frac * total)
    # The remainder goes to the test split so no sample is dropped by rounding.
    test_size = total - train_size - val_size
    sizes = [train_size, val_size, test_size]

    if seed is None:
        return random_split(dataset, sizes)
    # A private generator keeps the split reproducible without touching the
    # global RNG state used for weight initialisation and shuffling.
    return random_split(dataset, sizes, generator=torch.Generator().manual_seed(seed))

# DataLoader preparation using PyTorch Geometric's DataLoader
def get_data_loaders(dataset_splits, batch_size=32):
    """Wrap the train/val/test splits in PyTorch Geometric ``DataLoader`` objects.

    Args:
        dataset_splits: Indexable holding the train, validation and test
            datasets at positions 0, 1 and 2.
        batch_size: Batch size used for all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles its data.
    """
    train_split = dataset_splits[0]
    val_split = dataset_splits[1]
    test_split = dataset_splits[2]
    return (
        DataLoader(train_split, batch_size=batch_size, shuffle=True),
        DataLoader(val_split, batch_size=batch_size, shuffle=False),
        DataLoader(test_split, batch_size=batch_size, shuffle=False),
    )

# Accuracy calculation function
def calculate_accuracy(loader, model):
    """Compute the classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and left there; gradients are
    disabled while iterating.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y`` attributes.
        model: Callable module invoked as ``model(x, edge_index, batch)`` and
            returning one score row per graph.

    Returns:
        Accuracy as a percentage in ``[0, 100]``. Returns ``0.0`` when the
        loader yields no samples (e.g. a split that rounded down to zero
        graphs) instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            out = model(batch.x, batch.edge_index, batch.batch)
            # Predicted class is the index of the highest score in each row.
            predicted = out.argmax(dim=1)
            total += batch.y.size(0)
            correct += (predicted == batch.y).sum().item()
    if total == 0:
        return 0.0
    return (correct / total) * 100

In [None]:
def main(dataset_name, num_classes, epochs, k1, k2):
    """Train the graph classifier on one TU dataset and report test accuracy.

    Args:
        dataset_name: Name of the TUDataset to load.
        num_classes: Number of target classes for the classification head.
        epochs: Number of passes over the training loader.
        k1: ``k`` for the first down-sample & pool stage.
        k2: ``k`` for the second down-sample & pool stage.

    Returns:
        Test accuracy (percentage) of the model after the final epoch.
    """
    # Hyperparameters
    hidden_channels = 64
    out_channels = 64
    learning_rate = 0.001

    # The input width is read from the dataset rather than hard-coded.
    dataset = TUDataset(root='/tmp/' + dataset_name, name=dataset_name)
    in_channels = dataset.num_node_features

    # Load and prepare dataset
    dataset_splits = load_dataset(dataset_name)
    train_loader, val_loader, test_loader = get_data_loaders(dataset_splits)

    # Initialize model, loss, optimizer
    model = Model(in_channels, hidden_channels, out_channels, num_classes, k1=k1, k2=k2)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Model.forward already returns log-probabilities (log_softmax), so the
    # matching criterion is NLLLoss; CrossEntropyLoss would apply log-softmax
    # a second time.
    criterion = nn.NLLLoss()

    # Guard against an empty training loader when averaging the loss.
    num_batches = max(len(train_loader), 1)

    # Training loop
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.batch)
            loss = criterion(out, batch.y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Report the mean per-batch loss. Dividing by a fixed constant made the
        # printed value depend on the number of batches in the dataset.
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/num_batches:.4f}')

        # Validation accuracy (calculate_accuracy puts the model in eval mode;
        # model.train() at the top of the loop restores training mode).
        val_accuracy = calculate_accuracy(val_loader, model)
        print(f'Validation Accuracy: {val_accuracy:.4f}')

    # Test accuracy
    test_accuracy = calculate_accuracy(test_loader, model)
    print(f'Test Accuracy: {test_accuracy:.4f}')
    return test_accuracy

In [None]:
if __name__ == "__main__":
    # Experiment configurations
    k_values = [0.9, 0.8, 0.6]  # K for downsampling layers
    # NOTE: an earlier version also swept an `m` value, but it was never
    # passed to main(), so each (k1, k2) configuration was trained twice and
    # reported under two different `m` labels. The no-op sweep is removed so
    # every reported result corresponds to a distinct configuration.

    # Dataset configurations: dataset name -> number of classes
    datasets = {
        'DD': 2,  # Binary Classification
        'ENZYMES': 6  # 6-Class Classification
    }

    results = {}

    # Running experiments for both datasets
    for dataset_name, num_classes in datasets.items():
        print(f"\nTraining on {dataset_name} Dataset...")
        for k1 in k_values:
            for k2 in k_values:
                print(f"Running with k1={k1}, k2={k2}")
                test_accuracy = main(dataset_name, num_classes, epochs=10, k1=k1, k2=k2)
                results[f'{dataset_name}_k1={k1}_k2={k2}'] = test_accuracy

    # Print all results
    print("\nFinal Results:")
    for config, accuracy in results.items():
        print(f"{config}: Test Accuracy = {accuracy:.4f}")


Training on DD Dataset...
Running with k1=0.9, k2=0.9, m=6
Epoch 1/10, Loss: 2.1321
Validation Accuracy: 54.7009
Epoch 2/10, Loss: 2.0385
Validation Accuracy: 58.9744
Epoch 3/10, Loss: 2.0211
Validation Accuracy: 58.9744
Epoch 4/10, Loss: 1.9973
Validation Accuracy: 58.9744
Epoch 5/10, Loss: 1.9777
Validation Accuracy: 59.8291
Epoch 6/10, Loss: 1.9578
Validation Accuracy: 66.6667
Epoch 7/10, Loss: 1.9335
Validation Accuracy: 59.8291
Epoch 8/10, Loss: 1.9122
Validation Accuracy: 64.9573
Epoch 9/10, Loss: 1.9014
Validation Accuracy: 72.6496
Epoch 10/10, Loss: 1.8905
Validation Accuracy: 69.2308
Test Accuracy: 65.5462
Running with k1=0.9, k2=0.9, m=3
Epoch 1/10, Loss: 2.0759
Validation Accuracy: 63.2479
Epoch 2/10, Loss: 2.0312
Validation Accuracy: 63.2479
Epoch 3/10, Loss: 2.0311
Validation Accuracy: 63.2479
Epoch 4/10, Loss: 2.0041
Validation Accuracy: 63.2479
Epoch 5/10, Loss: 1.9857
Validation Accuracy: 63.2479
Epoch 6/10, Loss: 1.9846
Validation Accuracy: 63.2479
Epoch 7/10, Loss: 1

Downloading https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip
Processing...
Done!


Epoch 1/10, Loss: 2.6929
Validation Accuracy: 16.6667
Epoch 2/10, Loss: 2.6922
Validation Accuracy: 16.6667
Epoch 3/10, Loss: 2.6907
Validation Accuracy: 21.6667
Epoch 4/10, Loss: 2.6815
Validation Accuracy: 20.0000
Epoch 5/10, Loss: 2.6742
Validation Accuracy: 20.0000
Epoch 6/10, Loss: 2.6599
Validation Accuracy: 20.0000
Epoch 7/10, Loss: 2.6417
Validation Accuracy: 25.0000
Epoch 8/10, Loss: 2.6343
Validation Accuracy: 16.6667
Epoch 9/10, Loss: 2.6326
Validation Accuracy: 26.6667
Epoch 10/10, Loss: 2.6336
Validation Accuracy: 25.0000
Test Accuracy: 18.3333
Running with k1=0.9, k2=0.9, m=3
Epoch 1/10, Loss: 2.6876
Validation Accuracy: 13.3333
Epoch 2/10, Loss: 2.6870
Validation Accuracy: 13.3333
Epoch 3/10, Loss: 2.6824
Validation Accuracy: 13.3333
Epoch 4/10, Loss: 2.6655
Validation Accuracy: 13.3333
Epoch 5/10, Loss: 2.6618
Validation Accuracy: 13.3333
Epoch 6/10, Loss: 2.6435
Validation Accuracy: 18.3333
Epoch 7/10, Loss: 2.6312
Validation Accuracy: 11.6667
Epoch 8/10, Loss: 2.6255
