In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import torch.nn as nn
from torch_geometric.utils import remove_self_loops, degree

# Load the Cora dataset through Planetoid (stored under /tmp/Cora) and take
# its first entry as the graph used throughout the script.
dataset = Planetoid(root='/tmp/Cora', name='cora')
data = dataset[0]

# Prefer CUDA when it is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)
data = data.to(device)  # place the graph on the selected device (GPU or CPU)

def gcn_conv(h, edge_index):
    """Propagate node features along the edges with degree-normalised weights.

    Self-loops are removed first. Every remaining edge ``(src, dst)`` gets the
    weight ``deg[src] ** -0.5 * deg[dst] ** -0.5``, where ``deg`` is computed
    by ``degree`` over the destination indices. Nodes whose degree is zero
    contribute a weight of 0 instead of ``inf``.

    Args:
        h: Node feature tensor; ``h.size(0)`` is used as the number of nodes.
        edge_index: Edge index tensor that unpacks into ``(src, dst)`` rows.

    Returns:
        The product of the transposed weighted adjacency with ``h``, i.e. each
        destination row accumulates the weighted features of its sources.
    """
    num_nodes = h.size(0)
    edges, _ = remove_self_loops(edge_index)
    row, col = edges

    # Normalise once per node, then gather per edge; this yields the same
    # values as gathering the degrees first and normalising per edge.
    inv_sqrt_deg = degree(col, num_nodes=num_nodes).pow(-0.5)
    inv_sqrt_deg.masked_fill_(inv_sqrt_deg == float('inf'), 0)
    weights = inv_sqrt_deg[row] * inv_sqrt_deg[col]

    # Entries are stored at (src, dst); transposing makes row `dst` collect
    # the features of its incoming neighbours.
    adj_t = torch.sparse_coo_tensor(edges, weights, (num_nodes, num_nodes)).t()
    return torch.matmul(adj_t, h)


def gcn_conv_low_filter(h, edge_index):
    """Apply ``(I - A_norm)`` to the node features ``h``.

    ``A_norm`` is the degree-normalised, self-loop-free propagation operator
    implemented by ``gcn_conv``; the result is the difference between each
    node's own features and its propagated features.

    NOTE(review): ``I - A_norm`` is usually described as a Laplacian
    (high-pass) operator, which does not match the "low" in the function
    name -- confirm the intended naming with the author.

    Args:
        h: Node feature tensor; one row per node.
        edge_index: Edge index tensor that unpacks into ``(src, dst)`` rows.

    Returns:
        A tensor with the same shape as ``h`` holding ``h - A_norm @ h``.
    """
    # (I - A) @ h == h - A @ h, so neither the dense N x N identity nor the
    # dense N x N difference has to be materialised on every call. Reusing
    # gcn_conv also keeps the normalisation defined in exactly one place.
    return h - gcn_conv(h, edge_index)

# Define GCN (teacher) model
class GCN(nn.Module):
    """Two-layer graph convolutional network used as the distillation teacher.

    Args:
        in_feats: Input feature size passed to the first ``GCNConv``.
        hidden_feats: Output size of the first layer / input of the second.
        out_feats: Output size of the second ``GCNConv``.
    """

    def __init__(self, in_feats, hidden_feats, out_feats):
        super().__init__()
        self.conv1 = GCNConv(in_feats, hidden_feats)
        self.conv2 = GCNConv(hidden_feats, out_feats)

    def forward(self, data):
        """Return the second layer's output computed from ``data.x``.

        ``data`` must expose ``x`` and ``edge_index``. No softmax is applied
        to the returned tensor.
        """
        graph = data.edge_index
        hidden = self.conv1(data.x, graph)
        hidden = F.relu(hidden)
        return self.conv2(hidden, graph)

# Define MLP (student) model
class MLP(nn.Module):
    """Graph-free student network built only from linear layers.

    Two stages (``fc1`` and ``fc2``) are each followed by a pair of linear
    heads: ``l1``/``l2`` feed the main path, while ``l1_low``/``l2_low`` are
    side outputs that are only returned, never fed forward.

    Args:
        in_feats: Size of the input features in ``data.x``.
        hidden_feats: Width of the first stage and its two heads.
        out_feats: Width of the second stage and its two heads.
    """

    def __init__(self, in_feats, hidden_feats, out_feats):
        super().__init__()
        # Layers are created in the order fc1, l1, l1_low, fc2, l2, l2_low so
        # parameter names and seeded initialisation stay stable.
        self.fc1 = nn.Linear(in_feats, hidden_feats)
        self.l1 = nn.Linear(hidden_feats, hidden_feats)
        self.l1_low = nn.Linear(hidden_feats, hidden_feats)
        self.fc2 = nn.Linear(hidden_feats, out_feats)
        self.l2 = nn.Linear(out_feats, out_feats)
        self.l2_low = nn.Linear(out_feats, out_feats)

    def forward(self, data):
        """Run the student on ``data.x``.

        Returns:
            A tuple ``(out, hidden_feat)`` where ``out`` is the ``l2`` output
            and ``hidden_feat`` is the list
            ``[relu(fc1), l1, l1_low, fc2, l2, l2_low]`` of intermediate
            tensors (``hidden_feat[4]`` is ``out`` itself).
        """
        # First stage: shared ReLU features, then the two parallel heads.
        stage1 = F.relu(self.fc1(data.x))
        stage1_main = self.l1(stage1)
        stage1_low = self.l1_low(stage1)

        # Second stage consumes only the main head; no ReLU or dropout is
        # applied between the stages.
        stage2 = self.fc2(stage1_main)
        out = self.l2(stage2)
        stage2_low = self.l2_low(stage2)

        return out, [stage1, stage1_main, stage1_low, stage2, out, stage2_low]

# Training hyperparameters
num_epochs = 100    # number of teacher training epochs
temperature = 2.0   # softmax temperature used for the distillation targets
alpha = 0.7         # weight of the soft-target loss; (1 - alpha) weights the hard-label loss
beta = 0.1          # weight of the two layer-matching MSE losses
gamma = 0.01        # weight of the orthogonality loss
teacher_save_path = "gcn_teacher.pth"
student_save_path = "mlp_student.pth"

# Teacher (GCN) and student (MLP) share the same layer sizes.
teacher_model = GCN(
    in_feats=dataset.num_features,
    hidden_feats=16,
    out_feats=dataset.num_classes,
).to(device)
student_model = MLP(
    in_feats=dataset.num_features,
    hidden_feats=16,
    out_feats=dataset.num_classes,
).to(device)

# One Adam optimizer per model; only the weight decay differs.
optimizer_teacher = torch.optim.Adam(
    teacher_model.parameters(),
    lr=0.01,
    weight_decay=0.01,
)
optimizer_student = torch.optim.Adam(
    student_model.parameters(),
    lr=0.01,
    weight_decay=0.001,
)

# --- Step 1: Train the Teacher (GCN) ---
def train_teacher():
    """Perform one full-graph optimisation step on the teacher model.

    Uses the module-level ``teacher_model``, ``optimizer_teacher`` and
    ``data``; the cross-entropy loss is taken only over the nodes selected
    by ``data.train_mask``.
    """
    teacher_model.train()
    optimizer_teacher.zero_grad()

    train_nodes = data.train_mask
    logits = teacher_model(data)
    train_loss = F.cross_entropy(logits[train_nodes], data.y[train_nodes])

    train_loss.backward()
    optimizer_teacher.step()

def evaluate_teacher():
    """Return the teacher's accuracy on the nodes selected by ``data.test_mask``.

    Uses the module-level ``teacher_model`` and ``data`` and leaves the model
    in eval mode.

    Returns:
        Fraction of test nodes whose arg-max prediction equals ``data.y``.
    """
    teacher_model.eval()
    # Pure inference: disable autograd so no graph is built or kept alive.
    with torch.no_grad():
        pred = teacher_model(data).argmax(dim=1)
    test_nodes = data.test_mask
    correct = (pred[test_nodes] == data.y[test_nodes]).sum()
    return int(correct) / int(test_nodes.sum())

# Training the teacher: one optimisation step per epoch, with the test
# accuracy reported after every step.
for epoch in range(num_epochs):
    train_teacher()
    acc = evaluate_teacher()
    print(f'Epoch {epoch}, Teacher Model Accuracy: {acc:.4f}')

# Save trained GCN teacher model
torch.save(teacher_model.state_dict(), teacher_save_path)

# --- Step 2: Distill Teacher Knowledge to Student (MLP) ---
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU can still be reloaded on a CPU-only machine.
teacher_model.load_state_dict(torch.load(teacher_save_path, map_location=device))
teacher_model.eval()  # the teacher is only used for inference from here on

def calc_orth_loss(out1, out2):
    """Penalise row-wise alignment between ``out1`` and ``out2``.

    For each row the squared dot product is divided by the product of the two
    row norms (clamped to at least ``1e-6``); the mean over rows is returned.

    NOTE(review): the denominator is ``|a| * |b|``, not its square, so the
    value is the squared cosine scaled by the norm product rather than a
    scale-free squared cosine -- confirm this is intended.

    Args:
        out1: 2-D tensor whose rows are compared with the rows of ``out2``.
        out2: Tensor with the same shape as ``out1``.

    Returns:
        Scalar tensor with the mean penalty.
    """
    squared_inner = torch.sum(out1 * out2, dim=1) ** 2
    norm_product = torch.clamp(out1.norm(dim=1) * out2.norm(dim=1), min=1e-6)
    return torch.mean(squared_inner / norm_product)

def train_student(soft_labels):
    """Perform one full-graph optimisation step on the student model.

    The loss combines four terms, weighted by the module-level ``alpha``,
    ``beta`` and ``gamma``:

    * KL divergence between the student's temperature-softened predictions
      and the teacher's soft labels, over all nodes;
    * cross-entropy against the ground-truth labels on ``data.train_mask``;
    * MSE between graph-propagated hidden features (``gcn_conv`` /
      ``gcn_conv_low_filter``) and the student's corresponding head outputs;
    * an orthogonality penalty between the two heads of each stage.

    Args:
        soft_labels: Teacher class probabilities, one row per node, already
            softened by the caller as ``softmax(teacher_logits / temperature)``.
    """
    student_model.train()
    optimizer_student.zero_grad()
    out, hiddens = student_model(data)
    h1, h1c, h1_low, h2, h2c, h2_low = hiddens

    edge_index = data.edge_index

    # The student's heads should reproduce what propagating the shared
    # features over the graph would give.
    layer_loss1 = (
        F.mse_loss(gcn_conv(h1, edge_index), h1c)
        + F.mse_loss(gcn_conv(h2, edge_index), h2c)
    )
    layer_loss2 = (
        F.mse_loss(gcn_conv_low_filter(h1, edge_index), h1_low)
        + F.mse_loss(gcn_conv_low_filter(h2, edge_index), h2_low)
    )

    orth_loss = calc_orth_loss(h1c, h1_low) + calc_orth_loss(h2c, h2_low)

    # KL divergence for soft targets (distillation loss), computed over ALL
    # nodes, not only the training nodes. `soft_labels` is already a
    # probability distribution, so it must be used as-is: dividing it by the
    # temperature again would make each row sum to 1 / temperature and the
    # term would no longer be a KL divergence.
    # NOTE(review): distillation losses are often additionally scaled by
    # temperature ** 2; that is deliberately not added here so the configured
    # loss weights keep their meaning.
    loss_soft = F.kl_div(
        F.log_softmax(out / temperature, dim=1),
        soft_labels,
        reduction='batchmean'
    )

    # Cross-entropy loss for ground truth labels (training nodes only)
    loss_hard = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])

    # Combined loss
    loss = (
        alpha * loss_soft
        + (1 - alpha) * loss_hard
        + beta * (layer_loss1 + layer_loss2)
        + gamma * orth_loss
    )
    loss.backward()
    optimizer_student.step()

def evaluate_student():
    """Return the student's accuracy on the nodes selected by ``data.test_mask``.

    Uses the module-level ``student_model`` and ``data`` and leaves the model
    in eval mode. The student's list of hidden features is ignored.

    Returns:
        Fraction of test nodes whose arg-max prediction equals ``data.y``.
    """
    student_model.eval()
    # Pure inference: disable autograd so no graph is built or kept alive.
    with torch.no_grad():
        out, _ = student_model(data)
    pred = out.argmax(dim=1)
    test_nodes = data.test_mask
    correct = (pred[test_nodes] == data.y[test_nodes]).sum()
    return int(correct) / int(test_nodes.sum())

# Distillation: Train student model (MLP) using teacher's soft targets
num_student_epochs = 5000

# The teacher is not updated during distillation (only optimizer_student
# steps), so its soft labels are identical on every epoch. Compute them once
# up front instead of re-running the teacher in each iteration.
teacher_model.eval()
with torch.no_grad():
    teacher_logits = teacher_model(data)
    soft_labels = F.softmax(teacher_logits / temperature, dim=1)  # Temperature scaling

for epoch in range(num_student_epochs):
    # Train the student with soft labels and ground truth
    train_student(soft_labels)
    acc = evaluate_student()
    print(f'Epoch {epoch}, Student Model Accuracy: {acc:.4f}')

# Save trained MLP student model
torch.save(student_model.state_dict(), student_save_path)

cpu
Epoch 0, Teacher Model Accuracy: 0.3220
Epoch 1, Teacher Model Accuracy: 0.4070
Epoch 2, Teacher Model Accuracy: 0.4470
Epoch 3, Teacher Model Accuracy: 0.4730
Epoch 4, Teacher Model Accuracy: 0.4810
Epoch 5, Teacher Model Accuracy: 0.5010
Epoch 6, Teacher Model Accuracy: 0.5310
Epoch 7, Teacher Model Accuracy: 0.5570
Epoch 8, Teacher Model Accuracy: 0.5920
Epoch 9, Teacher Model Accuracy: 0.6310
Epoch 10, Teacher Model Accuracy: 0.6640
Epoch 11, Teacher Model Accuracy: 0.6970
Epoch 12, Teacher Model Accuracy: 0.7160
Epoch 13, Teacher Model Accuracy: 0.7350
Epoch 14, Teacher Model Accuracy: 0.7490
Epoch 15, Teacher Model Accuracy: 0.7610
Epoch 16, Teacher Model Accuracy: 0.7780
Epoch 17, Teacher Model Accuracy: 0.7860
Epoch 18, Teacher Model Accuracy: 0.7890
Epoch 19, Teacher Model Accuracy: 0.7970
Epoch 20, Teacher Model Accuracy: 0.8010
Epoch 21, Teacher Model Accuracy: 0.8070
Epoch 22, Teacher Model Accuracy: 0.8100
Epoch 23, Teacher Model Accuracy: 0.8130
Epoch 24, Teacher Mode

KeyboardInterrupt: 