In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch.optim import Optimizer
import math

# DEFINING THE CUSTOM ADAM OPTIMIZER
class CustomAdam(Optimizer):
    """
    A custom implementation of the Adam Optimizer.
    Formula:
    m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    theta = theta - lr * m_t / (sqrt(v_t) + epsilon)
    """
    def __init__(self, params, lr=0.01, betas=(0.9, 0.999), eps=1e-8):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = dict(lr=lr, betas=betas, eps=eps)
        super(CustomAdam, self).__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values (Momentum)
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values (Variance)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Update Momentum (First Moment)
                # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # Update Variance (Second Moment)
                # v_t = beta2 * v_{t-1} + (1 - beta2) * (g_t)^2
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias Correction 
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                step_size = group['lr'] / bias_correction1

                # Update Weights
                # p = p - step_size * (m_t / denom)
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

# DATASET SETUP
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=T.NormalizeFeatures())
data = dataset[0]

# Link Prediction Split
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=True,
    add_negative_train_samples=False
)
train_data, val_data, test_data = transform(data)

# MODEL ARCHITECTURE
class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Increased hidden channels slightly for better capacity
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        # Added dropout to prevent overfitting now that the optimizer is stronger
        x = F.dropout(x, p=0.5, training=self.training)
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

model = Net(in_channels=dataset.num_features, hidden_channels=128, out_channels=64)

# USING THE CUSTOM ADAM OPTIMIZER
# Set LR, learning rate to 0.01 (Adam handles the decay internally)
optimizer = CustomAdam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()

# TRAINING LOOP
def train():
    model.train()
    optimizer.zero_grad()
    
    z = model.encode(train_data.x, train_data.edge_index)

    # Negative Sampling
    pos_edge_index = train_data.edge_label_index
    neg_edge_index = torch.randint(0, train_data.num_nodes, pos_edge_index.size(), dtype=torch.long)
    
    edge_label_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
    # Labels: 1 for real, 0 for fake
    edge_label = torch.cat([torch.ones(pos_edge_index.size(1)), 
                            torch.zeros(neg_edge_index.size(1))], dim=0)

    out = model.decode(z, edge_label_index)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test(data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).sigmoid()
    preds = (out > 0.5).float()
    correct = (preds == data.edge_label).sum().item()
    return correct / data.edge_label.size(0)

# EXECUTION
print("Training with Custom Adam Optimizer...")
for epoch in range(1, 101):
    loss = train()
    val_acc = test(val_data)
    test_acc = test(test_data)
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

print("---")
print(f"Final Link Prediction Classification Accuracy: {test_acc:.4f}")

Training with Custom Adam Optimizer...
Epoch: 010, Loss: 0.6884, Val Acc: 0.5000, Test Acc: 0.5000
Epoch: 020, Loss: 0.6533, Val Acc: 0.5085, Test Acc: 0.5066
Epoch: 030, Loss: 0.5663, Val Acc: 0.6803, Test Acc: 0.6945
Epoch: 040, Loss: 0.5253, Val Acc: 0.6935, Test Acc: 0.7021
Epoch: 050, Loss: 0.5106, Val Acc: 0.6850, Test Acc: 0.7059
Epoch: 060, Loss: 0.4822, Val Acc: 0.7021, Test Acc: 0.7249
Epoch: 070, Loss: 0.4747, Val Acc: 0.7173, Test Acc: 0.7324
Epoch: 080, Loss: 0.4644, Val Acc: 0.7230, Test Acc: 0.7324
Epoch: 090, Loss: 0.4626, Val Acc: 0.7287, Test Acc: 0.7495
Epoch: 100, Loss: 0.4551, Val Acc: 0.7343, Test Acc: 0.7505
---
Final Link Prediction Classification Accuracy: 0.7505
