In [2]:
# Imports
import functools
import jax
import jax.numpy as jnp
import time
import jraph
import flax
import haiku as hk
import optax
import pickle
import numpy as np
import torch

from torch_geometric.data import Data, Dataset
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset

from flax import linen as nn
from flax.training import train_state
import pathlib
import csv
import time
import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [3]:
# Load ogbn-proteins and take the graph stored at index 0 together with the
# official train/valid/test index split.
dataset = PygNodePropPredDataset(name='ogbn-proteins',
                                 root='/data101/makinen/ogbn/')
data = dataset[0]
splitted_idx = dataset.get_idx_split()

data.node_species = None   # clear the per-node species attribute
data.y = data.y.float()    # float labels, as required by a BCE-with-logits loss

In [4]:
from torch_geometric.loader import RandomNodeLoader
from torch_geometric.utils import scatter


# Node features: for every node, sum the attributes of all edges whose `col`
# entry (second row of `edge_index`) is that node.
row, col = data.edge_index
data.x = scatter(data.edge_attr, col, reduce='sum', dim_size=data.num_nodes)

# Turn each split's index list into a boolean node mask stored on `data`.
for split in ('train', 'valid', 'test'):
    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    mask[splitted_idx[split]] = True
    data[f'{split}_mask'] = mask

# NOTE(review): `train_reader` / `test_reader` are not used by the training
# code below, which builds its own `train_loader` / `test_loader` with
# different partition counts and worker settings.
train_reader = RandomNodeLoader(data, shuffle=True, num_parts=200,
                                num_workers=0)
test_reader = RandomNodeLoader(data, num_workers=0, num_parts=5)

In [None]:
import torch
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch.nn import LayerNorm, Linear, ReLU
from tqdm import tqdm

from torch_geometric.loader import RandomNodeLoader
from torch_geometric.nn import DeepGCNLayer, GENConv
from torch_geometric.utils import scatter

# Reload ogbn-proteins from scratch so this cell does not rely on state left
# behind by the earlier cells.
dataset = PygNodePropPredDataset(name='ogbn-proteins',
                                 root='/data101/makinen/ogbn/')
data = dataset[0]
splitted_idx = dataset.get_idx_split()
data.node_species = None   # clear the per-node species attribute
data.y = data.y.float()    # float labels for BCEWithLogitsLoss below

# Node features: for every node, sum the attributes of all edges whose `col`
# entry (second row of `edge_index`) is that node.
row, col = data.edge_index
data.x = scatter(data.edge_attr, col, reduce='sum', dim_size=data.num_nodes)

# Keep one boolean mask per split on `data`; train()/test() read these masks
# back from every sampled partition.
for split in ('train', 'valid', 'test'):
    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    mask[splitted_idx[split]] = True
    data[f'{split}_mask'] = mask

# Training iterates 40 shuffled node partitions per epoch; evaluation iterates
# 5 partitions without shuffling.
train_loader = RandomNodeLoader(data, shuffle=True, num_parts=40,
                                num_workers=5)
test_loader = RandomNodeLoader(data, num_workers=5, num_parts=5)


class DeeperGCN(torch.nn.Module):
    """DeeperGCN-style multi-output node model built from GENConv layers.

    Node and edge features are projected to ``hidden_channels`` by linear
    encoders. They then pass through ``num_layers`` ``DeepGCNLayer`` blocks,
    each wrapping a softmax-aggregation ``GENConv``. A final linear head maps
    the result to one logit per output task.

    Args:
        hidden_channels: Embedding width used by the encoders and every layer.
        num_layers: Number of ``DeepGCNLayer`` blocks. Must be at least 1,
            because ``forward`` reuses the first layer's conv/norm/act.
        in_channels: Node feature size. If ``None``, it is read from the
            notebook's global ``data.x`` (the original behavior).
        edge_dim: Edge feature size. If ``None``, it is read from the global
            ``data.edge_attr``.
        out_channels: Number of output logits. If ``None``, it is read from
            the global ``data.y``.
        dropout: Dropout probability used inside each ``DeepGCNLayer`` and
            before the output head (previously hard-coded to 0.1 in both).

    Raises:
        ValueError: If ``num_layers < 1``.
    """

    def __init__(self, hidden_channels, num_layers, in_channels=None,
                 edge_dim=None, out_channels=None, dropout=0.1):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')

        # Only fall back to the globally loaded graph when sizes are not
        # given, so the model can also be built without that notebook state.
        if in_channels is None:
            in_channels = data.x.size(-1)
        if edge_dim is None:
            edge_dim = data.edge_attr.size(-1)
        if out_channels is None:
            out_channels = data.y.size(-1)
        self.dropout = dropout

        self.node_encoder = Linear(in_channels, hidden_channels)
        self.edge_encoder = Linear(edge_dim, hidden_channels)

        self.layers = torch.nn.ModuleList()
        for i in range(1, num_layers + 1):
            conv = GENConv(hidden_channels, hidden_channels, aggr='softmax',
                           t=1.0, learn_t=True, num_layers=2, norm='layer')
            norm = LayerNorm(hidden_channels, elementwise_affine=True)
            act = ReLU(inplace=True)

            # Gradient checkpointing is enabled on two of every three layers
            # (i % 3 != 0), trading recomputation for activation memory.
            layer = DeepGCNLayer(conv, norm, act, block='res+',
                                 dropout=dropout, ckpt_grad=bool(i % 3))
            self.layers.append(layer)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Return raw (pre-sigmoid) logits for every node of the given graph.

        Args:
            x: Node features; the last dimension must equal ``in_channels``.
            edge_index: Graph connectivity, passed unchanged to every conv.
            edge_attr: Edge features; the last dimension must equal
                ``edge_dim``.
        """
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)

        # 'res+' ordering. The first conv runs directly on the embedding. The
        # remaining DeepGCNLayers are assumed to apply
        # norm -> act -> dropout -> conv -> residual internally. The first
        # layer's norm/act then close the stack before the head.
        x = self.layers[0].conv(x, edge_index, edge_attr)

        for layer in self.layers[1:]:
            x = layer(x, edge_index, edge_attr)

        x = self.layers[0].act(self.layers[0].norm(x))
        x = F.dropout(x, p=self.dropout, training=self.training)

        return self.lin(x)


# Model / optimisation hyper-parameters.
SEED = 0
HIDDEN_CHANNELS = 64
NUM_LAYERS = 28
LEARNING_RATE = 0.01

# Seed torch before the model is built. This makes weight initialisation and
# later random draws (dropout, loader shuffling) repeatable across runs, up to
# any nondeterministic GPU kernels.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DeeperGCN(hidden_channels=HIDDEN_CHANNELS,
                  num_layers=NUM_LAYERS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Labels were cast to float above; each output is scored as an independent
# binary logit.
criterion = torch.nn.BCEWithLogitsLoss()
evaluator = Evaluator('ogbn-proteins')


def train(epoch):
    """Run one training epoch over ``train_loader`` and return the mean loss.

    Each batch is one node partition produced by ``train_loader``. Only nodes
    in that partition's ``train_mask`` contribute to the loss.

    Args:
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        The loss averaged over all training nodes seen this epoch, with each
        batch weighted by its number of training nodes. Returns NaN if no
        batch contained any training node.
    """
    model.train()

    total_loss = 0.0
    total_examples = 0
    # Using the bar as a context manager closes it even if the epoch is
    # interrupted, so later bars are not corrupted.
    with tqdm(total=len(train_loader),
              desc=f'Training epoch: {epoch:04d}') as pbar:
        for batch in train_loader:
            batch = batch.to(device)
            train_mask = batch.train_mask
            num_train = int(train_mask.sum())
            if num_train == 0:
                # A mean-reduced BCE over an empty selection is NaN, and
                # NaN * 0 would poison the running total, so skip this batch.
                pbar.update(1)
                continue

            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.edge_attr)
            loss = criterion(out[train_mask], batch.y[train_mask])
            loss.backward()
            optimizer.step()

            total_loss += float(loss) * num_train
            total_examples += num_train
            pbar.update(1)

    return total_loss / total_examples if total_examples else float('nan')


@torch.no_grad()
def test(epoch=None):
    """Evaluate the model on every split and return its ROC-AUC scores.

    The model runs over each partition of ``test_loader``. Labels and
    predictions of the nodes in each split's mask are collected, and each
    split is then scored with the OGB evaluator.

    Args:
        epoch: Optional epoch number, used only for the progress-bar label.
            It is passed explicitly rather than read from a global, so
            ``test()`` also works on a fresh kernel.

    Returns:
        Tuple ``(train_rocauc, valid_rocauc, test_rocauc)``.
    """
    model.eval()

    splits = ('train', 'valid', 'test')
    y_true = {split: [] for split in splits}
    y_pred = {split: [] for split in splits}

    desc = 'Evaluating' if epoch is None else f'Evaluating epoch: {epoch:04d}'
    with tqdm(total=len(test_loader), desc=desc) as pbar:
        for batch in test_loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.edge_attr)

            for split in splits:
                mask = batch[f'{split}_mask']
                y_true[split].append(batch.y[mask].cpu())
                y_pred[split].append(out[mask].cpu())

            pbar.update(1)

    rocauc = {
        split: evaluator.eval({
            'y_true': torch.cat(y_true[split], dim=0),
            'y_pred': torch.cat(y_pred[split], dim=0),
        })['rocauc']
        for split in splits
    }
    return rocauc['train'], rocauc['valid'], rocauc['test']


NUM_EPOCHS = 1000

for epoch in range(1, NUM_EPOCHS + 1):
    loss = train(epoch)
    # Pass the epoch explicitly so test() does not depend on this loop's
    # global variable.
    train_rocauc, valid_rocauc, test_rocauc = test(epoch)
    print(f'Loss: {loss:.4f}, Train: {train_rocauc:.4f}, '
          f'Val: {valid_rocauc:.4f}, Test: {test_rocauc:.4f}')

Training epoch: 0001: 100%|██████████| 40/40 [00:53<00:00,  1.35s/it]
Evaluating epoch: 0001: 100%|██████████| 5/5 [00:36<00:00,  7.21s/it]


Loss: 0.3760, Train: 0.6770, Val: 0.5920, Test: 0.5667


Training epoch: 0002: 100%|██████████| 40/40 [00:54<00:00,  1.35s/it]
Evaluating epoch: 0002: 100%|██████████| 5/5 [00:34<00:00,  6.85s/it]


Loss: 0.3273, Train: 0.7118, Val: 0.6967, Test: 0.6436


Training epoch: 0003: 100%|██████████| 40/40 [00:54<00:00,  1.35s/it]
Evaluating epoch: 0003: 100%|██████████| 5/5 [00:34<00:00,  6.92s/it]


Loss: 0.3157, Train: 0.7329, Val: 0.7118, Test: 0.6762


Training epoch: 0004: 100%|██████████| 40/40 [00:53<00:00,  1.35s/it]
Evaluating epoch: 0004: 100%|██████████| 5/5 [00:34<00:00,  6.92s/it]


Loss: 0.3088, Train: 0.7513, Val: 0.7375, Test: 0.6886


Training epoch: 0005: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0005: 100%|██████████| 5/5 [00:34<00:00,  6.92s/it]


Loss: 0.3052, Train: 0.7643, Val: 0.7458, Test: 0.6963


Training epoch: 0006: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0006: 100%|██████████| 5/5 [00:34<00:00,  6.92s/it]


Loss: 0.2990, Train: 0.7808, Val: 0.7678, Test: 0.6955


Training epoch: 0007: 100%|██████████| 40/40 [00:53<00:00,  1.33s/it]
Evaluating epoch: 0007: 100%|██████████| 5/5 [00:34<00:00,  6.90s/it]


Loss: 0.2935, Train: 0.7856, Val: 0.7698, Test: 0.6932


Training epoch: 0008: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0008: 100%|██████████| 5/5 [00:34<00:00,  6.94s/it]


Loss: 0.2924, Train: 0.7853, Val: 0.7758, Test: 0.7016


Training epoch: 0009: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0009: 100%|██████████| 5/5 [00:34<00:00,  6.92s/it]


Loss: 0.2881, Train: 0.7969, Val: 0.7844, Test: 0.7105


Training epoch: 0010: 100%|██████████| 40/40 [00:54<00:00,  1.35s/it]
Evaluating epoch: 0010: 100%|██████████| 5/5 [00:34<00:00,  6.93s/it]


Loss: 0.2869, Train: 0.7975, Val: 0.7850, Test: 0.7161


Training epoch: 0011: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0011: 100%|██████████| 5/5 [00:34<00:00,  6.90s/it]


Loss: 0.2829, Train: 0.8003, Val: 0.7903, Test: 0.7171


Training epoch: 0012: 100%|██████████| 40/40 [00:54<00:00,  1.36s/it]
Evaluating epoch: 0012: 100%|██████████| 5/5 [00:34<00:00,  6.85s/it]


Loss: 0.2827, Train: 0.8036, Val: 0.7835, Test: 0.7232


Training epoch: 0013: 100%|██████████| 40/40 [00:54<00:00,  1.35s/it]
Evaluating epoch: 0013: 100%|██████████| 5/5 [00:34<00:00,  6.90s/it]


Loss: 0.2807, Train: 0.8050, Val: 0.7879, Test: 0.7248


Training epoch: 0014: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0014: 100%|██████████| 5/5 [00:34<00:00,  6.92s/it]


Loss: 0.2765, Train: 0.8167, Val: 0.7910, Test: 0.7266


Training epoch: 0015: 100%|██████████| 40/40 [00:54<00:00,  1.35s/it]
Evaluating epoch: 0015: 100%|██████████| 5/5 [00:34<00:00,  6.91s/it]


Loss: 0.2739, Train: 0.8170, Val: 0.7926, Test: 0.7034


Training epoch: 0016: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0016: 100%|██████████| 5/5 [00:34<00:00,  6.89s/it]


Loss: 0.2747, Train: 0.8137, Val: 0.7951, Test: 0.7278


Training epoch: 0017: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0017: 100%|██████████| 5/5 [00:35<00:00,  7.08s/it]


Loss: 0.2725, Train: 0.8128, Val: 0.7887, Test: 0.7092


Training epoch: 0018: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0018: 100%|██████████| 5/5 [00:34<00:00,  6.91s/it]


Loss: 0.2703, Train: 0.8259, Val: 0.7893, Test: 0.7240


Training epoch: 0019: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0019:   0%|          | 0/5 [00:00<?, ?it/s]

In [7]:
# Show the mean training loss from the most recent completed train() call.
# NOTE(review): this reads leftover kernel state; if the training loop was
# interrupted or re-run, it need not match the saved log above.
loss

0.25946213826560355

In [8]:
# Show the last value bound by the training loop's `for epoch in ...`.
# NOTE(review): this reads leftover kernel state; it may not match the last
# epoch printed in the saved output above.
epoch

29