In [1]:
# Imports
import functools
import jax
import jax.numpy as jnp
import time
import jraph
import flax
import haiku as hk
import optax
import pickle
import numpy as np
import torch

from torch_geometric.data import Data, Dataset
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset

from flax import linen as nn
from flax.training import train_state
import pathlib
import csv
import time
import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


In [5]:
import torch
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch.nn import LayerNorm, Linear, ReLU
from tqdm import tqdm

from torch_geometric.loader import RandomNodeLoader
from torch_geometric.nn import DeepGCNLayer, GENConv
from torch_geometric.utils import scatter

# Load ogbn-proteins together with its predefined train/valid/test index split.
dataset = PygNodePropPredDataset(name='ogbn-proteins',
                                 root='/data101/makinen/ogbn/')
splitted_idx = dataset.get_idx_split()
data = dataset[0]
data.node_species = None
# The labels are cast to float because they are later fed to BCEWithLogitsLoss.
data.y = data.y.float()

# Initialize node features by summing edge features, grouped by the second
# row of edge_index.
row, col = data.edge_index[0], data.edge_index[1]
data.x = scatter(data.edge_attr, col, reduce='sum', dim_size=data.num_nodes)

# Convert each split's index tensor into a boolean node mask stored on `data`
# as train_mask / valid_mask / test_mask.
for split in ('train', 'valid', 'test'):
    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    mask[splitted_idx[split]] = True
    data[f'{split}_mask'] = mask

# Training uses 40 shuffled random node partitions; evaluation uses 5
# unshuffled ones.
train_loader = RandomNodeLoader(
    data,
    num_parts=40,
    shuffle=True,
    num_workers=5,
)
test_loader = RandomNodeLoader(
    data,
    num_parts=5,
    shuffle=False,
    num_workers=5,
)


class DeeperGCN(torch.nn.Module):
    """Stack of DeepGCNLayer blocks built around GENConv, with linear
    encoders for node and edge features and a linear output head.

    The input, edge and output widths are read from the module-level
    ``data`` object, so ``data`` must exist before the model is constructed.

    Args:
        hidden_channels: Width of the node/edge embeddings used by every layer.
        num_layers: Number of DeepGCNLayer blocks to create.
    """

    def __init__(self, hidden_channels, num_layers):
        super().__init__()

        node_dim = data.x.size(-1)
        edge_dim = data.edge_attr.size(-1)
        target_dim = data.y.size(-1)

        self.node_encoder = Linear(node_dim, hidden_channels)
        self.edge_encoder = Linear(edge_dim, hidden_channels)

        # Layers are numbered from 1 so that `depth % 3` matches the value
        # handed to `ckpt_grad` for each block.
        self.layers = torch.nn.ModuleList([
            self._make_layer(hidden_channels, depth)
            for depth in range(1, num_layers + 1)
        ])

        self.lin = Linear(hidden_channels, target_dim)

    @staticmethod
    def _make_layer(hidden_channels, depth):
        """Build one 'res+' DeepGCNLayer (GENConv + LayerNorm + ReLU)."""
        return DeepGCNLayer(
            conv=GENConv(hidden_channels, hidden_channels, aggr='softmax',
                         t=1.0, learn_t=True, num_layers=2, norm='layer'),
            norm=LayerNorm(hidden_channels, elementwise_affine=True),
            act=ReLU(inplace=True),
            block='res+',
            dropout=0.1,
            # Zero for every third layer, non-zero otherwise.
            ckpt_grad=depth % 3,
        )

    def forward(self, x, edge_index, edge_attr):
        """Return one row of logits per node.

        Args:
            x: Node features, passed through ``node_encoder``.
            edge_index: Graph connectivity forwarded to every layer.
            edge_attr: Edge features, passed through ``edge_encoder``.
        """
        h = self.node_encoder(x)
        e = self.edge_encoder(edge_attr)

        # Only the convolution of the first block runs up front; that block's
        # norm and activation are applied after all remaining blocks.
        first = self.layers[0]
        h = first.conv(h, edge_index, e)

        for block in self.layers[1:]:
            h = block(h, edge_index, e)

        h = first.norm(h)
        h = first.act(h)
        h = F.dropout(h, p=0.1, training=self.training)

        return self.lin(h)
    
# Model size.
HIDDEN_CHANNELS = 64
NUM_LAYERS = 14

# Use CUDA when it is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = DeeperGCN(hidden_channels=HIDDEN_CHANNELS, num_layers=NUM_LAYERS)
model = model.to(device)

# Adam optimizer, binary cross-entropy on logits, and the OGB evaluator
# used by test() to compute ROC-AUC.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()
evaluator = Evaluator(name='ogbn-proteins')


def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``.

    Args:
        epoch: Integer epoch number; only used to label the progress bar.

    Returns:
        The mean training loss, where each batch loss is weighted by the
        number of training nodes in that batch.

    Raises:
        ZeroDivisionError: If no batch contained any training node.
    """
    model.train()

    total_loss = 0.0
    total_examples = 0

    # The context manager closes the bar even when the loop is interrupted
    # (e.g. KeyboardInterrupt), so no stale bar is left behind.
    with tqdm(total=len(train_loader), position=0) as pbar:
        pbar.set_description(f'Training epoch: {epoch:04d}')

        for data in train_loader:
            optimizer.zero_grad()
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr)
            # The loss is computed on the training nodes of this partition only.
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()

            # Count the training nodes once and reuse the value.
            num_train = int(data.train_mask.sum())
            total_loss += float(loss) * num_train
            total_examples += num_train

            pbar.update(1)

    return total_loss / total_examples


@torch.no_grad()
def test(epoch=None, edge_scale=0.5):
    """Evaluate the model on every partition of ``test_loader``.

    Uses the module-level ``model``, ``test_loader``, ``device`` and
    ``evaluator``. Before each forward pass the partition's edge features
    are multiplied by ``edge_scale`` and its node features are rebuilt by
    summing the scaled edge features, so evaluation inputs differ from the
    ones used in ``train``.
    NOTE(review): the reason for the rescaling is not stated in this file —
    confirm it is an intended perturbation experiment.

    Args:
        epoch: Integer epoch number used to label the progress bar. When
            omitted, a module-level ``epoch`` is used if one exists (this
            keeps bare ``test()`` calls inside the epoch loop working);
            otherwise the bar is labelled without an epoch number instead of
            raising ``NameError``.
        edge_scale: Factor applied in place to each partition's edge
            features. Defaults to 0.5, the value that was previously
            hard-coded.

    Returns:
        Tuple ``(train_rocauc, valid_rocauc, test_rocauc)``.
    """
    if epoch is None:
        epoch = globals().get('epoch')
    if epoch is None:
        description = 'Evaluating'
    else:
        description = f'Evaluating epoch: {epoch:04d}'

    model.eval()

    splits = ('train', 'valid', 'test')
    y_true = {split: [] for split in splits}
    y_pred = {split: [] for split in splits}

    # The context manager closes the bar even if evaluation is interrupted.
    with tqdm(total=len(test_loader), position=0) as pbar:
        pbar.set_description(description)

        for data in test_loader:
            data.edge_attr *= edge_scale

            # Rebuild node features from the rescaled edge features.
            _, col = data.edge_index
            data.x = scatter(data.edge_attr, col, dim_size=data.num_nodes,
                             reduce='sum')
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr)

            for split in splits:
                mask = data[f'{split}_mask']
                y_true[split].append(data.y[mask].cpu())
                y_pred[split].append(out[mask].cpu())

            pbar.update(1)

    # One evaluator call per split, in train / valid / test order.
    rocauc = {
        split: evaluator.eval({
            'y_true': torch.cat(y_true[split], dim=0),
            'y_pred': torch.cat(y_pred[split], dim=0),
        })['rocauc']
        for split in splits
    }

    return rocauc['train'], rocauc['valid'], rocauc['test']


# Epochs are numbered 1..28. `test()` is called without arguments and picks
# up the loop variable `epoch` from module scope for its progress-bar label.
for epoch in range(1, 29):
    loss = train(epoch)
    (train_rocauc,
     valid_rocauc,
     test_rocauc) = test()
    print(
        f'Loss: {loss:.4f}, Train: {train_rocauc:.4f}, '
        f'Val: {valid_rocauc:.4f}, Test: {test_rocauc:.4f}'
    )

Training epoch: 0001: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0001: 100%|██████████| 5/5 [00:23<00:00,  4.78s/it]


Loss: 0.3732, Train: 0.6789, Val: 0.5988, Test: 0.5428


Training epoch: 0002: 100%|██████████| 40/40 [00:54<00:00,  1.35s/it]
Evaluating epoch: 0002: 100%|██████████| 5/5 [00:23<00:00,  4.76s/it]


Loss: 0.3210, Train: 0.7263, Val: 0.7335, Test: 0.6734


Training epoch: 0003: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0003: 100%|██████████| 5/5 [00:23<00:00,  4.77s/it]


Loss: 0.3103, Train: 0.7399, Val: 0.7300, Test: 0.6849


Training epoch: 0004: 100%|██████████| 40/40 [00:54<00:00,  1.35s/it]
Evaluating epoch: 0004: 100%|██████████| 5/5 [00:23<00:00,  4.76s/it]


Loss: 0.3052, Train: 0.7519, Val: 0.7508, Test: 0.6955


Training epoch: 0005: 100%|██████████| 40/40 [00:53<00:00,  1.34s/it]
Evaluating epoch: 0005: 100%|██████████| 5/5 [00:23<00:00,  4.77s/it]


Loss: 0.3005, Train: 0.7603, Val: 0.7521, Test: 0.6976


Training epoch: 0001:   0%|          | 0/40 [11:30<?, ?it/s]
Evaluating epoch: 0002:   0%|          | 0/5 [09:00<?, ?it/s]


KeyboardInterrupt: 

In [6]:
# Save only the model's state_dict (its parameters and buffers).
# NOTE(review): despite its name, `outdir` is handed to torch.save as the
# target file path, not used as a directory.
# NOTE(review): the file name is tagged "nc_24" while HIDDEN_CHANNELS is 64 —
# confirm the tag describes the model that is actually being saved.
outdir = '/data101/makinen/graph_fishnets/models/nullnet_nc_24_nlyr_14'

torch.save(obj=model.state_dict(), f=outdir)

In [8]:
# Display the current value of `epoch`, the loop variable of the training loop.
epoch

29