In [7]:
import os.path as osp
import time

import torch
import torch.nn.functional as F

from torch_geometric.datasets import Entities
from torch_geometric.nn import RGATConv


from import_data import import_data

# Load the project-specific "Euro28" graph via the local `import_data` helper.
# NOTE(review): the code below treats `data` as a PyG `Data` object (it calls
# .to() and reads .x, .edge_index, .edge_type, .train_idx/.train_y and
# .test_idx/.test_y) — confirm this matches what import_data returns.
dataset = "Euro28"
data, results = import_data(dataset)

# NOTE(review): presumably column 2 of `results` holds per-node labels; it is
# not used anywhere in this cell — verify against import_data.
labels = results[:, 2]


class RGAT(torch.nn.Module):
    """Two-layer relational graph attention network with a linear classifier head.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Width of both RGATConv layers' outputs.
        out_channels: Number of output classes.
        num_relations: Number of distinct edge (relation) types.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_relations):
        super().__init__()
        # Registration order (conv1, conv2, lin) fixes the parameter order
        # seen by the optimizer and state_dict.
        self.conv1 = RGATConv(in_channels=in_channels,
                              out_channels=hidden_channels,
                              num_relations=num_relations)
        self.conv2 = RGATConv(in_channels=hidden_channels,
                              out_channels=hidden_channels,
                              num_relations=num_relations)
        self.lin = torch.nn.Linear(in_features=hidden_channels,
                                   out_features=out_channels)

    def forward(self, x, edge_index, edge_type):
        """Return per-node log-probabilities over the output classes."""
        hidden = F.relu(self.conv1(x, edge_index, edge_type))
        hidden = F.relu(self.conv2(hidden, edge_index, edge_type))
        logits = self.lin(hidden)
        # Log-probabilities pair with the NLL loss used during training.
        return logits.log_softmax(dim=-1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)

# `dataset` is only the name string passed to import_data, not a PyG dataset
# object, so it has no `num_classes` / `num_relations` attributes. Derive the
# model dimensions from the loaded graph instead.
# NOTE(review): assumes labels and edge types are 0-based integer ids —
# confirm against import_data.
in_channels = data.x.size(-1)
hidden_channels = 16
num_classes = int(torch.cat([data.train_y, data.test_y]).max()) + 1
num_relations = int(data.edge_type.max()) + 1

model = RGAT(in_channels, hidden_channels, num_classes, num_relations).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)


def train():
    """Take one full-batch optimisation step on the training nodes.

    Returns:
        The training negative log-likelihood loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    log_probs = model(data.x, data.edge_index, data.edge_type)
    train_log_probs = log_probs[data.train_idx]
    loss = F.nll_loss(train_log_probs, data.train_y)
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test():
    """Evaluate the model without gradients.

    Returns:
        Tuple ``(train_acc, test_acc)`` of Python floats: the fraction of
        correctly predicted nodes in the training and test index sets.
    """
    model.eval()
    logits = model(data.x, data.edge_index, data.edge_type)
    predicted = logits.argmax(dim=-1)

    def _accuracy(index, target):
        # Fraction of nodes in `index` whose predicted class equals `target`.
        return (predicted[index] == target).float().mean().item()

    train_acc = _accuracy(data.train_idx, data.train_y)
    test_acc = _accuracy(data.test_idx, data.test_y)
    return train_acc, test_acc


# Train for a fixed number of epochs, logging loss/accuracy and timing each
# epoch. The timed span covers both the training step and the evaluation pass.
num_epochs = 50
times = []
for epoch in range(1, num_epochs + 1):
    # perf_counter is monotonic and high-resolution, unlike wall-clock time().
    start = time.perf_counter()
    loss = train()
    train_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Test: {test_acc:.4f}')
    times.append(time.perf_counter() - start)
# Median is robust to a slow first (warm-up) epoch.
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


Epoch: 001, Train: 0.1429, Val: 0.0580, Test: 0.0640
Epoch: 002, Train: 0.1429, Val: 0.0580, Test: 0.0640
Epoch: 003, Train: 0.1429, Val: 0.0580, Test: 0.0640
Epoch: 004, Train: 0.1786, Val: 0.0920, Test: 0.1230
Epoch: 005, Train: 0.4714, Val: 0.4060, Test: 0.4190
Epoch: 006, Train: 0.3929, Val: 0.4060, Test: 0.4190
Epoch: 007, Train: 0.3071, Val: 0.4060, Test: 0.4190
Epoch: 008, Train: 0.3714, Val: 0.4060, Test: 0.4190
Epoch: 009, Train: 0.3786, Val: 0.4060, Test: 0.4190
Epoch: 010, Train: 0.4786, Val: 0.4060, Test: 0.4190
Epoch: 011, Train: 0.5571, Val: 0.4060, Test: 0.4190
Epoch: 012, Train: 0.6286, Val: 0.4060, Test: 0.4190
Epoch: 013, Train: 0.6786, Val: 0.4140, Test: 0.3900
Epoch: 014, Train: 0.7357, Val: 0.4580, Test: 0.4540
Epoch: 015, Train: 0.7500, Val: 0.4860, Test: 0.4820
Epoch: 016, Train: 0.7571, Val: 0.5120, Test: 0.4930
Epoch: 017, Train: 0.8357, Val: 0.5620, Test: 0.5530
Epoch: 018, Train: 0.8857, Val: 0.7120, Test: 0.6950
Epoch: 019, Train: 0.9143, Val: 0.7380, Test: 