In [None]:
!pip install torch torchvision torchaudio
!pip install torch_geometric

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
import torch.utils.data

# Seed PyTorch's global RNG once, up front, so the torch-based randomness used
# later in the notebook (dataset split, weight initialisation, loader shuffling,
# random augmentations) is repeatable under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Load the TU benchmark dataset; it is downloaded into ./data on first use.
dataset_name = "NCI1"
dataset = TUDataset(root='./data', name=dataset_name)

# Quick sanity summary of what was loaded.
print(f"Dataset: {dataset_name}")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of node features: {dataset.num_node_features}")
print(f"Number of classes: {dataset.num_classes}")


Downloading https://www.chrsmrrs.com/graphkerneldatasets/NCI1.zip
Processing...


Dataset: NCI1
Number of graphs: 4110
Number of node features: 37
Number of classes: 2


Done!


In [None]:
from torch.utils.data import random_split

# 80% / 10% / 10% split. The test split absorbs the rounding remainder so the
# three sizes always sum to len(dataset).
num_train = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1)
num_test = len(dataset) - num_train - num_val

# A dedicated, explicitly seeded generator makes this cell idempotent: re-running
# it yields the same partition no matter how much global RNG state was consumed
# by other cells, so reported accuracies refer to a fixed test set.
SPLIT_SEED = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [num_train, num_val, num_test],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print(f"Train graphs: {len(train_dataset)}, Val graphs: {len(val_dataset)}, Test graphs: {len(test_dataset)}")


Train graphs: 3288, Val graphs: 411, Test graphs: 411


In [None]:
# Mini-batch size shared by all three loaders.
batch_size = 32

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (val_dataset, test_dataset)
)

# Peek at the first training batch to confirm how graphs are collated:
# `batch.batch` maps every node to the index of the graph it belongs to.
batch = next(iter(train_loader), None)
if batch is not None:
    print(f"Batch Size: {batch.batch.max().item() + 1}")
    print(f"Num Nodes in batch: {batch.num_nodes}")


Batch Size: 32
Num Nodes in batch: 841


In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    """Two-layer graph convolutional encoder.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2.

        No pooling and no final activation are applied here; callers pool the
        returned tensor themselves when they need graph-level embeddings.
        """
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)


In [None]:
class ProjectionMLP(torch.nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) used as the projection head.

    Args:
        in_dim: Size of the input embedding.
        hidden_dim: Width of the hidden layer.
        out_dim: Size of the projected output.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        # Layers are created in forward order so seeded initialisation is stable.
        self.fc1 = torch.nn.Linear(in_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``; the output is left un-activated."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


In [None]:
import random
from torch_geometric.utils import dropout_adj

def node_dropping(data, drop_prob=0.2):
    """Randomly remove nodes, and every edge touching them, from a graph or batch.

    Each node is dropped independently with probability ``drop_prob``.
    Surviving nodes are renumbered consecutively in their original order and
    ``edge_index`` is rewritten to use the new numbering.

    Note:
        ``data`` is modified in place and also returned; pass ``data.clone()``
        to keep the original. Only ``x``, ``edge_index`` and ``batch`` are
        updated; any other attribute on ``data`` is left untouched.

    Args:
        data: Object exposing ``num_nodes``, ``x``, ``edge_index`` (shape
            ``[2, num_edges]``) and, optionally, a per-node ``batch`` vector.
        drop_prob: Probability of dropping each individual node.

    Returns:
        The same ``data`` object, mutated.
    """
    num_nodes = data.num_nodes
    # Work on the graph's own device so the result is usable on GPU as well.
    device = data.edge_index.device

    keep_mask = torch.rand(num_nodes, device=device) > drop_prob
    kept_idx = torch.where(keep_mask)[0]

    # Look-up table old index -> new index. Dropped nodes hold -1, which is
    # never read because edges touching them are filtered out first.
    new_index = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
    new_index[kept_idx] = torch.arange(kept_idx.numel(), device=device)

    # Keep an edge only if both endpoints survive, then relabel in one shot.
    # Unlike building a Python list, this preserves the [2, 0] shape when no
    # edge survives.
    src, dst = data.edge_index[0], data.edge_index[1]
    edge_mask = keep_mask[src] & keep_mask[dst]
    data.edge_index = new_index[data.edge_index[:, edge_mask]]

    data.x = data.x[kept_idx]

    # `batch` may be absent or present-but-None depending on the container.
    batch = getattr(data, 'batch', None)
    if batch is not None:
        data.batch = batch[kept_idx]

    return data

def edge_perturbation(data, perturb_prob=0.2):
    """Randomly remove edges from ``data.edge_index``.

    Note:
        ``data`` is modified in place and also returned; pass ``data.clone()``
        to keep the original. Only ``edge_index`` is rewritten.

    Args:
        data: Object exposing an ``edge_index`` tensor.
        perturb_prob: Drop probability forwarded as ``p`` to the PyG helper.

    Returns:
        The same ``data`` object, mutated.
    """
    try:
        # `dropout_adj` is deprecated in recent PyG releases in favour of
        # `dropout_edge`; prefer the supported helper when it is available.
        from torch_geometric.utils import dropout_edge
    except ImportError:
        # Older PyG without `dropout_edge`: fall back to the legacy helper.
        data.edge_index, _ = dropout_adj(data.edge_index, p=perturb_prob)
    else:
        data.edge_index, _ = dropout_edge(data.edge_index, p=perturb_prob)
    return data


In [None]:
def info_nce_loss(z1, z2, tau=0.5):
    """InfoNCE contrastive loss between two batches of embeddings.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form the positive pair; every
    other row of ``z2`` serves as a negative for row ``i`` of ``z1``. The loss
    is one-directional (``z1`` queries against ``z2`` keys).

    Args:
        z1: 2-D tensor of embeddings, one row per sample.
        z2: 2-D tensor of embeddings aligned row-by-row with ``z1``.
        tau: Temperature dividing the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over the rows of ``z1``.
    """
    # L2-normalise so the dot products below are cosine similarities.
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    similarity_matrix = torch.mm(z1, z2.T) / tau  # (batch_size, batch_size)

    # The positive for row i sits on the diagonal. Build the targets on the
    # same device as the logits; a default CPU tensor breaks GPU training.
    labels = torch.arange(z1.shape[0], device=similarity_matrix.device)

    return F.cross_entropy(similarity_matrix, labels)


In [None]:
from torch_geometric.nn import global_mean_pool

def train_graphcl(loader, gcn, projection, optimizer, epochs=100):
    """Contrastive pre-training of ``gcn`` and ``projection`` on two augmented views.

    For every batch, one view is produced by ``node_dropping`` and the other by
    ``edge_perturbation`` (both with probability 0.2). Each view is encoded by
    ``gcn``, mean-pooled per graph, projected, and the two sets of projections
    are compared with ``info_nce_loss``.

    Args:
        loader: Iterable yielding batches with ``x``, ``edge_index`` and
            ``batch`` attributes and a ``clone()`` method.
        gcn: Encoder called as ``gcn(x, edge_index)``.
        projection: Projection head applied to the pooled embeddings.
        optimizer: Optimizer stepping the parameters to be trained.
        epochs: Number of passes over ``loader``.

    Returns:
        None. Every 10th epoch the mean loss over that epoch's batches is printed.
    """
    gcn.train()
    projection.train()

    for epoch in range(epochs):
        loss_sum = 0.0
        num_batches = 0

        for data in loader:
            optimizer.zero_grad()

            data_aug1 = node_dropping(data.clone(), drop_prob=0.2)
            data_aug2 = edge_perturbation(data.clone(), perturb_prob=0.2)

            h1 = gcn(data_aug1.x, data_aug1.edge_index)
            h2 = gcn(data_aug2.x, data_aug2.edge_index)

            # Number of graphs in the *un-augmented* batch. Passing it as `size`
            # pins the number of pooled rows, so both views stay aligned
            # row-by-row even if node dropping removes every node of the last
            # graph(s) in the batch.
            batch_vec = getattr(data, 'batch', None)
            if batch_vec is not None and batch_vec.numel() > 0:
                num_graphs = int(batch_vec.max()) + 1
            else:
                num_graphs = None

            h1 = global_mean_pool(h1, getattr(data_aug1, 'batch', None), size=num_graphs)
            h2 = global_mean_pool(h2, getattr(data_aug2, 'batch', None), size=num_graphs)

            z1 = projection(h1)
            z2 = projection(h2)

            loss = info_nce_loss(z1, z2)

            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            num_batches += 1

        # Report the epoch mean rather than the (noisy) last mini-batch loss;
        # skip logging entirely when the loader produced no batches.
        if epoch % 10 == 0 and num_batches > 0:
            print(f"Epoch {epoch}, Loss: {loss_sum / num_batches:.4f}")


In [None]:
# Encoder: node features -> 128-d hidden -> 128-d output.
gcn = GCN(
    in_channels=dataset.num_features,
    hidden_channels=128,
    out_channels=128,
)

# Projection head consumes the 128-d encoder output and projects it to 64-d.
projection_head = ProjectionMLP(in_dim=128, hidden_dim=64, out_dim=64)

# A single Adam optimizer updates the encoder and the projection head jointly.
optimizer_graphcl = torch.optim.Adam(
    [*gcn.parameters(), *projection_head.parameters()],
    lr=0.01,
)

train_graphcl(train_loader, gcn, projection_head, optimizer_graphcl, epochs=100)




Epoch 0, Loss: 2.1677
Epoch 10, Loss: 2.0122
Epoch 20, Loss: 1.8729
Epoch 30, Loss: 1.8580
Epoch 40, Loss: 2.2538
Epoch 50, Loss: 2.0988
Epoch 60, Loss: 1.8379
Epoch 70, Loss: 1.9233
Epoch 80, Loss: 1.8436
Epoch 90, Loss: 1.7996


In [None]:
import torch.nn as nn

class ClassifierMLP(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) that outputs class logits.

    Args:
        in_dim: Size of the input embedding.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_dim, hidden_dim, num_classes):
        super().__init__()
        # Layers are created in forward order so seeded initialisation is stable.
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) logits ``fc2(relu(fc1(x)))``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


In [None]:
from torch_geometric.nn import global_mean_pool

def fine_tune(loader, gcn, classifier, optimizer, epochs=100):
    """Train ``classifier`` on mean-pooled embeddings from a frozen ``gcn``.

    The encoder is put in eval mode and run under ``torch.no_grad()``, so no
    gradient reaches it; only parameters held by ``optimizer`` that take part
    in the classifier's forward pass are updated.

    Args:
        loader: Iterable yielding batches with ``x``, ``edge_index``, ``batch``
            and label tensor ``y``.
        gcn: Encoder called as ``gcn(x, edge_index)``.
        classifier: Module mapping pooled embeddings to class logits.
        optimizer: Optimizer for the classifier's parameters.
        epochs: Number of passes over ``loader``.

    Returns:
        None. Every 10th epoch the mean loss over that epoch's batches is printed.
    """
    gcn.eval()
    classifier.train()

    for epoch in range(epochs):
        loss_sum = 0.0
        num_batches = 0

        for data in loader:
            optimizer.zero_grad()

            # Frozen feature extraction: no autograd graph is built for the encoder.
            with torch.no_grad():
                h = gcn(data.x, data.edge_index)
                h = global_mean_pool(h, data.batch)

            logits = classifier(h)
            loss = F.cross_entropy(logits, data.y)

            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            num_batches += 1

        # Report the epoch mean rather than the (noisy) last mini-batch loss;
        # skip logging entirely when the loader produced no batches.
        if epoch % 10 == 0 and num_batches > 0:
            print(f"Epoch {epoch}, Loss: {loss_sum / num_batches:.4f}")


In [None]:
def evaluate(loader, gcn, classifier, split_name="Test"):
    """Compute and print classification accuracy over ``loader``.

    Args:
        loader: Iterable yielding batches with ``x``, ``edge_index``, ``batch``
            and label tensor ``y``.
        gcn: Encoder called as ``gcn(x, edge_index)``.
        classifier: Module mapping pooled embeddings to class logits.
        split_name: Label used in the printed message (e.g. ``"Val"``). Defaults
            to ``"Test"`` so existing calls print exactly what they did before.

    Returns:
        Accuracy as a float in ``[0, 1]``, or ``nan`` when ``loader`` yields no
        samples.
    """
    gcn.eval()
    classifier.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for data in loader:
            h = gcn(data.x, data.edge_index)
            h = global_mean_pool(h, data.batch)
            logits = classifier(h)

            pred = logits.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            total += data.y.size(0)

    # Accuracy over zero samples is undefined; report NaN instead of raising
    # ZeroDivisionError.
    acc = correct / total if total > 0 else float("nan")
    print(f"{split_name} Accuracy: {acc:.4f}")
    return acc


In [None]:

# Classification head; in_dim matches the 128-d pooled output of the encoder.
classifier = ClassifierMLP(
    in_dim=128,
    hidden_dim=64,
    num_classes=dataset.num_classes,
)

# Only the classifier's parameters are handed to this optimizer.
optimizer = torch.optim.Adam(
    classifier.parameters(),
    lr=0.01,
)

# Train the head on the training split, then score it on the held-out test split.
fine_tune(train_loader, gcn, classifier, optimizer, epochs=100)

evaluate(test_loader, gcn, classifier)


Epoch 0, Loss: 0.5947
Epoch 10, Loss: 0.7257
Epoch 20, Loss: 0.8944
Epoch 30, Loss: 0.6368
Epoch 40, Loss: 0.6214
Epoch 50, Loss: 0.6099
Epoch 60, Loss: 0.4759
Epoch 70, Loss: 0.4935
Epoch 80, Loss: 0.6641
Epoch 90, Loss: 0.6243
Test Accuracy: 0.6350


0.635036496350365