In [None]:
import numpy as np
import tinygrad.nn
import torch
import torch.nn
import torch.nn.functional as F
from tinygrad import Tensor
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv as PygConv
from torch_geometric.transforms import NormalizeFeatures

from tinygraph.nn import GCNConv
from tinygraph.nn.datasets import cora

In [None]:
# Load Cora through tinygraph's dataset helper. The notebook reads data.x, data.edge_index,
# data.y and the train/val/test masks from this object.
data = cora()

In [None]:
class GNN:
    """Two-layer graph convolutional network built from tinygraph's ``GCNConv``.

    Args:
        in_features: Size of each input node feature vector.
        hidden_features: Width of the hidden GCN layer.
        out_features: Number of output classes.
        dropout: Dropout probability applied after the first layer's ReLU
            (tinygrad only applies dropout while ``Tensor.training`` is True).
    """

    def __init__(self, in_features, hidden_features, out_features, dropout=0):
        self.conv1 = GCNConv(in_features, hidden_features)
        self.conv2 = GCNConv(hidden_features, out_features)
        self.dropout = dropout

    def __call__(self, x, edge_index):
        """Return per-node class log-probabilities.

        ``log_softmax`` is returned rather than ``softmax``. The training loss uses
        ``Tensor.cross_entropy``, which applies its own log-softmax to its input.
        Feeding it probabilities would apply softmax twice and flatten the gradients.
        log-softmax is unchanged by that second normalization, so the loss becomes the
        proper negative log-likelihood. ``argmax``-based accuracy is unaffected.
        """
        x = self.conv1(x, edge_index).relu().dropout(p=self.dropout)
        return self.conv2(x, edge_index).log_softmax(axis=1)

In [None]:
# Seed tinygrad's global RNG so dropout masks are reproducible across kernel restarts.
# This presumably also fixes the initial weights, assuming tinygraph's GCNConv initializes
# them with tinygrad's random Tensor functions (TODO confirm).
Tensor.manual_seed(0)

in_features = data.x.shape[1]  # width of each node feature vector
out_features = len(set(data.y.tolist()))  # number of distinct class labels
model = GNN(in_features, 16, out_features, dropout=0.5)
optimizer = tinygrad.nn.optim.AdamW(tinygrad.nn.state.get_parameters(model), lr=0.01, weight_decay=5e-4)

In [None]:
def _mask(mask):
    """Convert a per-node boolean mask into a tinygrad Tensor of the selected node indices.

    The train/test code selects rows with ``out[index_tensor]``. It presumably uses index
    tensors because tinygrad does not support boolean-mask indexing.
    NOTE(review): assumes ``mask`` is 1-D (one flag per node); TODO confirm.
    """
    return Tensor(np.flatnonzero(mask.numpy()))


train_mask, val_mask, test_mask = (_mask(m) for m in (data.train_mask, data.val_mask, data.test_mask))

In [None]:
def train():
    """Run one full-batch optimization step of the tinygrad model on the training nodes.

    Returns:
        The mean cross-entropy loss over the training nodes, as a Python float.
    """
    # Training mode: enables dropout, and tinygrad's optimizer requires it to step.
    Tensor.training = True
    optimizer.zero_grad()

    train_scores = model(data.x, data.edge_index)[train_mask]
    train_loss = train_scores.cross_entropy(data.y[train_mask])
    train_loss.backward()

    optimizer.step()
    return train_loss.item()

In [None]:
def test():
    """Measure the tinygrad model's accuracy on the train, validation and test nodes.

    Returns:
        ``[train_acc, val_acc, test_acc]``: the fraction of correctly classified nodes per split.
    """
    Tensor.training = False  # evaluation mode: dropout is disabled
    predicted = model(data.x, data.edge_index).argmax(axis=1)
    return [
        (predicted[split].eq(data.y[split]).sum() / split.shape[0]).item()
        for split in (train_mask, val_mask, test_mask)
    ]

In [None]:
class PygGNN(torch.nn.Module):
    """Two-layer PyTorch Geometric GCN with the same layout as the tinygraph ``GNN``.

    Args:
        in_features: Size of each input node feature vector.
        hidden_channels: Width of the hidden GCN layer.
        out_features: Number of output classes.
        dropout: Dropout probability applied after the first layer's ReLU (training mode only).
    """

    def __init__(self, in_features, hidden_channels, out_features, dropout=0):
        super().__init__()
        self.conv1 = PygConv(in_features, hidden_channels)
        self.conv2 = PygConv(hidden_channels, out_features)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return per-node class log-probabilities.

        ``F.log_softmax`` is returned rather than ``F.softmax``. ``F.cross_entropy``
        expects unnormalized logits and applies log-softmax itself, so passing it
        probabilities would normalize twice and weaken the gradients. log-softmax is
        unchanged by that second normalization, so the loss becomes the proper negative
        log-likelihood. ``argmax``-based accuracy is unaffected.
        """
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return F.log_softmax(self.conv2(x, edge_index), dim=1)

In [None]:
# Seed torch so the PyG model's dropout masks are reproducible across kernel restarts.
torch.manual_seed(0)
pyg_model = PygGNN(in_features, 16, out_features, dropout=0.5)

# Start the PyG model from the tinygraph model's weights so the two training runs are comparable.
# NOTE(review): assumes tinygraph's GCNConv.weight uses the same (out_features, in_features)
# layout as PyG's GCNConv.lin.weight; TODO confirm.
with torch.no_grad():
    pyg_model.conv1.lin.weight.copy_(torch.tensor(model.conv1.weight.numpy()))
    pyg_model.conv2.lin.weight.copy_(torch.tensor(model.conv2.weight.numpy()))
# NOTE(review): biases are not synced. PyG initializes GCNConv.bias to zeros; whether tinygraph's
# GCNConv exposes a compatible `bias` (shape and init) is unverified, so copying it is left out.

# Use AdamW (decoupled weight decay) to match the tinygrad.nn.optim.AdamW optimizer of the
# tinygraph model. torch.optim.Adam would instead add weight_decay to the gradient as an L2 penalty.
pyg_optimizer = torch.optim.AdamW(pyg_model.parameters(), lr=0.01, weight_decay=5e-4)

In [None]:
# Load Cora again through PyG (downloaded into /tmp/Cora on first use) and take its single graph.
# NormalizeFeatures row-normalizes each node's features to sum to 1.
# NOTE(review): assumes tinygraph's cora() applies equivalent feature preprocessing; confirm,
# otherwise the two models are trained on different inputs.
pyg_data = Planetoid(root="/tmp/Cora", name="Cora", transform=NormalizeFeatures())[0]

In [None]:
def pyg_train():
    """Run one full-batch optimization step of the PyG model on the training nodes.

    Returns:
        The mean cross-entropy loss over the training nodes, as a Python float.
    """
    train_nodes = pyg_data.train_mask
    pyg_model.train()  # training mode: enables dropout
    pyg_optimizer.zero_grad()

    scores = pyg_model(pyg_data.x, pyg_data.edge_index)
    batch_loss = F.cross_entropy(scores[train_nodes], pyg_data.y[train_nodes])
    batch_loss.backward()

    pyg_optimizer.step()
    return batch_loss.item()

In [None]:
@torch.no_grad()
def pyg_test():
    """Measure the PyG model's accuracy on the train, validation and test nodes.

    Runs under ``torch.no_grad()`` so evaluation does not build an autograd graph.

    Returns:
        ``[train_acc, val_acc, test_acc]``: the fraction of correctly classified nodes per split.
    """
    pyg_model.eval()  # evaluation mode: dropout is disabled
    # argmax over the class dimension is computed once and shared by all three splits.
    pred = pyg_model(pyg_data.x, pyg_data.edge_index).argmax(dim=1)
    accuracies = []
    for mask in (pyg_data.train_mask, pyg_data.val_mask, pyg_data.test_mask):
        correct = int((pred[mask] == pyg_data.y[mask]).sum())
        accuracies.append(correct / int(mask.sum()))
    return accuracies

In [None]:
# Train the tinygraph model for 200 full-batch epochs, evaluating every 10 epochs.
# The log line uses the same format as the PyG training loop so the two runs line up for comparison.
for epoch in range(1, 201):
    loss = train()
    if epoch % 10 == 0:
        train_acc, val_acc, test_acc = test()
        print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}")

In [None]:
# Train the PyG model for 200 full-batch epochs and log metrics every 10th epoch.
for epoch in range(1, 201):
    loss = pyg_train()
    if epoch % 10:
        continue  # only evaluate and log on every 10th epoch
    train_acc, val_acc, test_acc = pyg_test()
    print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}")