## IMPORT DATA

In [59]:
import torch
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch.nn import LayerNorm, Linear, ReLU
from tqdm import tqdm

from torch_geometric.loader import RandomNodeLoader

In [60]:
import os.path as osp
import sys

import matplotlib.pyplot as plt
import torch
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch_geometric.loader import DataLoader

# Download and process data at './dataset/ogbg_molhiv/'
ogbnproteins_dataset = PygNodePropPredDataset('ogbn-proteins', root='../data')
ogbnarxiv_dataset = PygNodePropPredDataset(name='ogbn-arxiv', root= '../data')

## Prepare DataLoader

In [67]:
## DataLoader For ogbnproteins datasett
from torch_geometric.utils import scatter
from torch_geometric.loader import RandomNodeLoader
splitted_idx = ogbnproteins_dataset.get_idx_split()
data = ogbnproteins_dataset[0]

data.y = data.y.to(torch.float)

# Initialize features of nodes by aggregating edge features.
row, col = data.edge_index
data.x = scatter(data.edge_attr, col, dim_size=data.num_nodes, reduce='sum')

# Set split indices to masks.
for split in ['train', 'valid', 'test']:
    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    mask[splitted_idx[split]] = True
    data[f'{split}_mask'] = mask
    
train_loader = RandomNodeLoader(data, num_parts=40, shuffle=True,
                                num_workers=5)
test_loader = RandomNodeLoader(data, num_parts=5, num_workers=5)

In [68]:
data.x[data.train_mask].shape,data.y[data.train_mask].shape, data.x.size(-1)

(torch.Size([86619, 8]), torch.Size([86619, 112]), 8)

In [69]:
data['train_mask']

tensor([ True,  True,  True,  ..., False, False, False])

## Prepare Models

### GCN MODEL

In [63]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from ogb.nodeproppred import Evaluator
from tqdm import tqdm

    
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv1(x, edge_index).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(data.x.size(-1), 64, data.y.size(-1)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)
criterion = torch.nn.BCEWithLogitsLoss()
evaluator = Evaluator('ogbn-proteins')


def train(epoch):
    model.train()

   

    total_loss = total_examples = 0
    for data in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        total_loss += float(loss) * int(data.train_mask.sum())
        total_examples += int(data.train_mask.sum())



    return total_loss / total_examples


@torch.no_grad()
def test():
    model.eval()

    y_true = {'train': [], 'valid': [], 'test': []}
    y_pred = {'train': [], 'valid': [], 'test': []}



    for data in test_loader:
        data = data.to(device)
        out = model(data.x, data.edge_index)

        for split in y_true.keys():
            mask = data[f'{split}_mask']
            y_true[split].append(data.y[mask].cpu())
            y_pred[split].append(out[mask].cpu())



    train_rocauc = evaluator.eval({
        'y_true': torch.cat(y_true['train'], dim=0),
        'y_pred': torch.cat(y_pred['train'], dim=0),
    })['rocauc']

    valid_rocauc = evaluator.eval({
        'y_true': torch.cat(y_true['valid'], dim=0),
        'y_pred': torch.cat(y_pred['valid'], dim=0),
    })['rocauc']

    test_rocauc = evaluator.eval({
        'y_true': torch.cat(y_true['test'], dim=0),
        'y_pred': torch.cat(y_pred['test'], dim=0),
    })['rocauc']

    return train_rocauc, valid_rocauc, test_rocauc


print(f'Training And Test ')
for epoch in tqdm(range(1, 101)):

    loss = train(epoch)
   
    train_rocauc, valid_rocauc, test_rocauc = test()
    print(f'Loss: {loss:.4f}, Train: {train_rocauc:.4f}, '
          f'Val: {valid_rocauc:.4f}, Test: {test_rocauc:.4f}')

Training And Test 



  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [02:23<3:56:34, 143.38s/it][A

Loss: 0.9122, Train: 0.5007, Val: 0.3939, Test: 0.4202



  2%|▏         | 2/100 [05:05<4:12:09, 154.38s/it][A

Loss: 0.3953, Train: 0.5631, Val: 0.4793, Test: 0.4189



  3%|▎         | 3/100 [08:07<4:29:39, 166.80s/it][A

Loss: 0.3625, Train: 0.6236, Val: 0.5737, Test: 0.4823



  4%|▍         | 4/100 [5:25:50<203:23:26, 7627.15s/it][A

Loss: 0.3487, Train: 0.6285, Val: 0.5899, Test: 0.4983



  5%|▌         | 5/100 [9:24:52<265:10:09, 10048.52s/it][A

Loss: 0.3417, Train: 0.6255, Val: 0.5584, Test: 0.4515



  6%|▌         | 6/100 [9:45:57<184:23:57, 7062.10s/it] [A

Loss: 0.3373, Train: 0.6570, Val: 0.6155, Test: 0.4834



  7%|▋         | 7/100 [9:48:35<124:07:31, 4804.85s/it][A

Loss: 0.3334, Train: 0.6663, Val: 0.6619, Test: 0.5498



  8%|▊         | 8/100 [9:51:24<85:04:20, 3328.92s/it] [A

Loss: 0.3325, Train: 0.6770, Val: 0.6644, Test: 0.5333



  9%|▉         | 9/100 [9:53:51<59:00:34, 2334.44s/it][A

Loss: 0.3328, Train: 0.6823, Val: 0.6833, Test: 0.5488



 10%|█         | 10/100 [9:57:08<41:51:47, 1674.53s/it][A

Loss: 0.3301, Train: 0.6001, Val: 0.5932, Test: 0.5233



 11%|█         | 11/100 [9:59:59<30:01:05, 1214.22s/it][A

Loss: 0.3303, Train: 0.6511, Val: 0.6285, Test: 0.5063



 12%|█▏        | 12/100 [10:03:52<22:23:03, 915.72s/it][A

Loss: 0.3281, Train: 0.6474, Val: 0.6526, Test: 0.5407



 13%|█▎        | 13/100 [10:07:43<17:07:12, 708.42s/it][A

Loss: 0.3302, Train: 0.6661, Val: 0.6627, Test: 0.5317



 14%|█▍        | 14/100 [10:10:49<13:09:10, 550.59s/it][A

Loss: 0.3281, Train: 0.6834, Val: 0.6854, Test: 0.5707



 15%|█▌        | 15/100 [10:13:41<10:18:28, 436.58s/it][A

Loss: 0.3274, Train: 0.6713, Val: 0.6812, Test: 0.5415



 16%|█▌        | 16/100 [10:16:31<8:18:37, 356.16s/it] [A

Loss: 0.3278, Train: 0.6342, Val: 0.5985, Test: 0.4966



 17%|█▋        | 17/100 [10:19:55<7:09:28, 310.47s/it][A

Loss: 0.3292, Train: 0.6567, Val: 0.6479, Test: 0.5268



 18%|█▊        | 18/100 [10:23:21<6:21:26, 279.11s/it][A

Loss: 0.3257, Train: 0.6626, Val: 0.6646, Test: 0.5229



 19%|█▉        | 19/100 [10:26:38<5:43:16, 254.28s/it][A

Loss: 0.3279, Train: 0.6615, Val: 0.6403, Test: 0.5237



 20%|██        | 20/100 [10:29:37<5:09:04, 231.81s/it][A

Loss: 0.3266, Train: 0.6760, Val: 0.6759, Test: 0.5369



 21%|██        | 21/100 [10:32:34<4:43:31, 215.33s/it][A

Loss: 0.3287, Train: 0.6617, Val: 0.6572, Test: 0.5376



 22%|██▏       | 22/100 [10:35:31<4:24:56, 203.81s/it][A

Loss: 0.3282, Train: 0.6742, Val: 0.6771, Test: 0.5417



 23%|██▎       | 23/100 [10:38:38<4:15:02, 198.74s/it][A

Loss: 0.3257, Train: 0.6633, Val: 0.6483, Test: 0.5113



 24%|██▍       | 24/100 [10:41:36<4:04:01, 192.65s/it][A

Loss: 0.3265, Train: 0.6529, Val: 0.6272, Test: 0.4949



 25%|██▌       | 25/100 [10:44:27<3:52:36, 186.08s/it][A

Loss: 0.3265, Train: 0.6777, Val: 0.6953, Test: 0.5455



 26%|██▌       | 26/100 [10:47:41<3:52:35, 188.59s/it][A

Loss: 0.3253, Train: 0.6565, Val: 0.6476, Test: 0.5164



 27%|██▋       | 27/100 [10:50:54<3:50:49, 189.72s/it][A

Loss: 0.3252, Train: 0.6486, Val: 0.6378, Test: 0.5186



 28%|██▊       | 28/100 [10:53:52<3:43:34, 186.31s/it][A

Loss: 0.3264, Train: 0.6680, Val: 0.6629, Test: 0.5131



 29%|██▉       | 29/100 [10:56:44<3:35:13, 181.88s/it][A

Loss: 0.3273, Train: 0.6609, Val: 0.6373, Test: 0.4975



 30%|███       | 30/100 [10:59:33<3:27:40, 178.01s/it][A

Loss: 0.3242, Train: 0.6688, Val: 0.6730, Test: 0.5446



 31%|███       | 31/100 [11:02:38<3:27:10, 180.16s/it][A

Loss: 0.3237, Train: 0.6432, Val: 0.6203, Test: 0.4887



 32%|███▏      | 32/100 [11:05:40<3:25:00, 180.89s/it][A

Loss: 0.3253, Train: 0.6831, Val: 0.6957, Test: 0.5465



 33%|███▎      | 33/100 [11:08:33<3:19:20, 178.51s/it][A

Loss: 0.3255, Train: 0.6588, Val: 0.6528, Test: 0.5074



 34%|███▍      | 34/100 [11:11:28<3:15:04, 177.35s/it][A

Loss: 0.3259, Train: 0.6586, Val: 0.6608, Test: 0.5362



 35%|███▌      | 35/100 [11:14:22<3:10:55, 176.24s/it][A

Loss: 0.3249, Train: 0.6873, Val: 0.6914, Test: 0.5605



 36%|███▌      | 36/100 [11:17:00<3:02:21, 170.97s/it][A

Loss: 0.3270, Train: 0.6826, Val: 0.6803, Test: 0.5346



 37%|███▋      | 37/100 [11:19:55<3:00:49, 172.22s/it][A

Loss: 0.3248, Train: 0.6977, Val: 0.6963, Test: 0.5502



 38%|███▊      | 38/100 [11:22:50<2:58:33, 172.80s/it][A

Loss: 0.3257, Train: 0.6754, Val: 0.6876, Test: 0.5546



 39%|███▉      | 39/100 [11:25:59<3:00:50, 177.88s/it][A

Loss: 0.3241, Train: 0.6577, Val: 0.6515, Test: 0.5125



 40%|████      | 40/100 [11:28:45<2:54:15, 174.26s/it][A

Loss: 0.3263, Train: 0.6473, Val: 0.6144, Test: 0.4821



 41%|████      | 41/100 [11:31:51<2:54:41, 177.66s/it][A

Loss: 0.3248, Train: 0.6569, Val: 0.6524, Test: 0.5275



 42%|████▏     | 42/100 [11:34:47<2:51:14, 177.14s/it][A

Loss: 0.3228, Train: 0.6699, Val: 0.6674, Test: 0.5152



 43%|████▎     | 43/100 [11:37:32<2:44:48, 173.49s/it][A

Loss: 0.3259, Train: 0.6435, Val: 0.6098, Test: 0.4822



 44%|████▍     | 44/100 [11:40:39<2:45:55, 177.77s/it][A

Loss: 0.3255, Train: 0.6595, Val: 0.6559, Test: 0.5278



 45%|████▌     | 45/100 [11:43:38<2:43:09, 178.00s/it][A

Loss: 0.3252, Train: 0.6811, Val: 0.6866, Test: 0.5517



 46%|████▌     | 46/100 [11:46:56<2:45:32, 183.94s/it][A

Loss: 0.3236, Train: 0.6607, Val: 0.6559, Test: 0.5255



 47%|████▋     | 47/100 [11:49:50<2:39:52, 181.00s/it][A

Loss: 0.3248, Train: 0.6469, Val: 0.6357, Test: 0.5040



 48%|████▊     | 48/100 [11:52:37<2:33:10, 176.73s/it][A

Loss: 0.3258, Train: 0.6621, Val: 0.6411, Test: 0.5135



 49%|████▉     | 49/100 [11:55:43<2:32:33, 179.48s/it][A

Loss: 0.3253, Train: 0.6683, Val: 0.6740, Test: 0.5470



 50%|█████     | 50/100 [11:58:57<2:33:17, 183.95s/it][A

Loss: 0.3236, Train: 0.6825, Val: 0.6932, Test: 0.5681



 51%|█████     | 51/100 [12:02:06<2:31:26, 185.44s/it][A

Loss: 0.3259, Train: 0.6827, Val: 0.6809, Test: 0.5383



 52%|█████▏    | 52/100 [12:05:03<2:26:14, 182.81s/it][A

Loss: 0.3252, Train: 0.6666, Val: 0.6756, Test: 0.5418



 53%|█████▎    | 53/100 [12:08:05<2:23:03, 182.63s/it][A

Loss: 0.3253, Train: 0.6683, Val: 0.6722, Test: 0.5397



 54%|█████▍    | 54/100 [12:10:58<2:17:46, 179.71s/it][A

Loss: 0.3232, Train: 0.6250, Val: 0.5976, Test: 0.4734



 55%|█████▌    | 55/100 [12:13:42<2:11:23, 175.20s/it][A

Loss: 0.3256, Train: 0.6768, Val: 0.6773, Test: 0.5337



 56%|█████▌    | 56/100 [12:16:41<2:09:14, 176.25s/it][A

Loss: 0.3264, Train: 0.6805, Val: 0.6829, Test: 0.5400



 57%|█████▋    | 57/100 [12:19:51<2:09:21, 180.51s/it][A

Loss: 0.3231, Train: 0.6727, Val: 0.6806, Test: 0.5641



 58%|█████▊    | 58/100 [12:23:00<2:08:05, 183.00s/it][A

Loss: 0.3265, Train: 0.6698, Val: 0.6733, Test: 0.5414



 59%|█████▉    | 59/100 [12:25:59<2:04:10, 181.73s/it][A

Loss: 0.3237, Train: 0.6501, Val: 0.6425, Test: 0.5141



 60%|██████    | 60/100 [12:28:50<1:59:01, 178.53s/it][A

Loss: 0.3244, Train: 0.6693, Val: 0.6685, Test: 0.5170



 61%|██████    | 61/100 [12:31:56<1:57:30, 180.79s/it][A

Loss: 0.3264, Train: 0.6821, Val: 0.6945, Test: 0.5399



 62%|██████▏   | 62/100 [12:34:49<1:53:02, 178.49s/it][A

Loss: 0.3267, Train: 0.6803, Val: 0.6883, Test: 0.5640



 63%|██████▎   | 63/100 [12:38:09<1:53:58, 184.82s/it][A

Loss: 0.3243, Train: 0.6590, Val: 0.6455, Test: 0.5194



 64%|██████▍   | 64/100 [12:41:14<1:50:55, 184.88s/it][A

Loss: 0.3260, Train: 0.6763, Val: 0.6771, Test: 0.5442



 65%|██████▌   | 65/100 [12:44:11<1:46:32, 182.65s/it][A

Loss: 0.3253, Train: 0.6637, Val: 0.6706, Test: 0.5511



 66%|██████▌   | 66/100 [12:47:53<1:50:12, 194.48s/it][A

Loss: 0.3235, Train: 0.6363, Val: 0.6239, Test: 0.5208



 67%|██████▋   | 67/100 [12:51:21<1:49:04, 198.32s/it][A

Loss: 0.3246, Train: 0.6738, Val: 0.6817, Test: 0.5346



 68%|██████▊   | 68/100 [12:54:10<1:41:09, 189.68s/it][A

Loss: 0.3244, Train: 0.6789, Val: 0.7006, Test: 0.5586



 69%|██████▉   | 69/100 [12:57:28<1:39:17, 192.17s/it][A

Loss: 0.3266, Train: 0.6811, Val: 0.6879, Test: 0.5514



 70%|███████   | 70/100 [13:01:27<1:43:06, 206.22s/it][A

Loss: 0.3241, Train: 0.6447, Val: 0.6107, Test: 0.4827



 71%|███████   | 71/100 [13:04:37<1:37:21, 201.43s/it][A

Loss: 0.3251, Train: 0.6533, Val: 0.6270, Test: 0.4978



 72%|███████▏  | 72/100 [13:07:30<1:29:55, 192.68s/it][A

Loss: 0.3242, Train: 0.6419, Val: 0.6169, Test: 0.4924



 73%|███████▎  | 73/100 [13:10:37<1:25:55, 190.94s/it][A

Loss: 0.3230, Train: 0.6573, Val: 0.6450, Test: 0.5246



 74%|███████▍  | 74/100 [13:13:28<1:20:09, 184.98s/it][A

Loss: 0.3241, Train: 0.6572, Val: 0.6593, Test: 0.5302



 75%|███████▌  | 75/100 [13:16:22<1:15:46, 181.86s/it][A

Loss: 0.3256, Train: 0.6601, Val: 0.6719, Test: 0.5511



 76%|███████▌  | 76/100 [13:19:43<1:14:57, 187.41s/it][A

Loss: 0.3254, Train: 0.6688, Val: 0.6825, Test: 0.5373



 77%|███████▋  | 77/100 [13:23:22<1:15:28, 196.88s/it][A

Loss: 0.3252, Train: 0.6707, Val: 0.6761, Test: 0.5483



 78%|███████▊  | 78/100 [13:26:40<1:12:23, 197.44s/it][A

Loss: 0.3254, Train: 0.6834, Val: 0.6975, Test: 0.5543



 79%|███████▉  | 79/100 [13:30:10<1:10:22, 201.08s/it][A

Loss: 0.3240, Train: 0.6852, Val: 0.6956, Test: 0.5503



 80%|████████  | 80/100 [13:33:07<1:04:34, 193.74s/it][A

Loss: 0.3252, Train: 0.6687, Val: 0.6810, Test: 0.5506



 81%|████████  | 81/100 [13:36:00<59:28, 187.80s/it]  [A

Loss: 0.3241, Train: 0.6580, Val: 0.6562, Test: 0.5088



 82%|████████▏ | 82/100 [13:39:25<57:49, 192.77s/it][A

Loss: 0.3246, Train: 0.6495, Val: 0.6276, Test: 0.5064



 83%|████████▎ | 83/100 [13:42:25<53:30, 188.88s/it][A

Loss: 0.3242, Train: 0.6744, Val: 0.6968, Test: 0.5449



 84%|████████▍ | 84/100 [13:45:30<50:07, 187.95s/it][A

Loss: 0.3244, Train: 0.6778, Val: 0.6864, Test: 0.5498



 85%|████████▌ | 85/100 [13:48:53<48:05, 192.35s/it][A

Loss: 0.3263, Train: 0.6806, Val: 0.6805, Test: 0.5336



 86%|████████▌ | 86/100 [13:51:54<44:04, 188.92s/it][A

Loss: 0.3235, Train: 0.6714, Val: 0.6846, Test: 0.5569



 87%|████████▋ | 87/100 [13:54:50<40:04, 184.93s/it][A

Loss: 0.3246, Train: 0.6712, Val: 0.6704, Test: 0.5440



 88%|████████▊ | 88/100 [13:57:53<36:53, 184.48s/it][A

Loss: 0.3235, Train: 0.6724, Val: 0.6825, Test: 0.5351



 89%|████████▉ | 89/100 [14:00:52<33:31, 182.82s/it][A

Loss: 0.3249, Train: 0.6480, Val: 0.6570, Test: 0.5347



 90%|█████████ | 90/100 [14:04:09<31:11, 187.16s/it][A

Loss: 0.3249, Train: 0.6553, Val: 0.6520, Test: 0.5250



 91%|█████████ | 91/100 [14:07:19<28:12, 188.04s/it][A

Loss: 0.3242, Train: 0.6739, Val: 0.6816, Test: 0.5583



 92%|█████████▏| 92/100 [14:10:29<25:09, 188.66s/it][A

Loss: 0.3234, Train: 0.6773, Val: 0.6868, Test: 0.5455



 93%|█████████▎| 93/100 [14:13:54<22:34, 193.44s/it][A

Loss: 0.3244, Train: 0.6509, Val: 0.6304, Test: 0.5097



 94%|█████████▍| 94/100 [14:16:49<18:46, 187.78s/it][A

Loss: 0.3238, Train: 0.6446, Val: 0.6166, Test: 0.4919



 95%|█████████▌| 95/100 [14:19:34<15:04, 180.94s/it][A

Loss: 0.3251, Train: 0.6536, Val: 0.6423, Test: 0.5136



 96%|█████████▌| 96/100 [14:22:26<11:53, 178.47s/it][A

Loss: 0.3258, Train: 0.6805, Val: 0.6815, Test: 0.5590



 97%|█████████▋| 97/100 [14:25:29<08:59, 179.89s/it][A

Loss: 0.3260, Train: 0.6489, Val: 0.6408, Test: 0.5345



 98%|█████████▊| 98/100 [14:28:19<05:53, 176.63s/it][A

Loss: 0.3237, Train: 0.6534, Val: 0.6424, Test: 0.5163



 99%|█████████▉| 99/100 [14:31:22<02:58, 178.62s/it][A

Loss: 0.3256, Train: 0.6605, Val: 0.6583, Test: 0.5390



100%|██████████| 100/100 [14:34:31<00:00, 524.71s/it][A

Loss: 0.3252, Train: 0.6505, Val: 0.6417, Test: 0.5106





### Node2vec Model

In [133]:
import os.path as osp
import sys
from tqdm import tqdm

from torch_geometric.nn import Node2Vec
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Node2Vec(
    data.edge_index,
    embedding_dim=data.y.size(-1),
    walk_length=20,
    context_size=10,
    walks_per_node=10,
    num_negative_samples=1,
    p=1.0,
    q=1.0,
    sparse=True,
).to(device)

num_workers = 4 if sys.platform == 'linux' else 0
loader = model.loader(batch_size=128, shuffle=True, num_workers=num_workers)
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

def train():
    model.train()
    total_loss = 0
    for pos_rw, neg_rw in loader:
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)



@torch.no_grad()
def test():
    model.eval()
    z = model()
    y_true = {'train': [], 'valid': [], 'test': []}
    y_pred = {'train': [], 'valid': [], 'test': []}

 
   
    for split in y_true.keys():
        mask = data[f'{split}_mask']
        y_true[split].append(data.y[mask].cpu())
        y_pred[split].append(z[mask].cpu())


    train_rocauc = evaluator.eval({
        'y_true': torch.cat(y_true['train'], dim=0),
        'y_pred': torch.cat(y_pred['train'], dim=0),
    })['rocauc']

    valid_rocauc = evaluator.eval({
        'y_true': torch.cat(y_true['valid'], dim=0),
        'y_pred': torch.cat(y_pred['valid'], dim=0),
    })['rocauc']

    test_rocauc = evaluator.eval({
        'y_true': torch.cat(y_true['test'], dim=0),
        'y_pred': torch.cat(y_pred['test'], dim=0),
    })['rocauc']

    return train_rocauc, valid_rocauc, test_rocauc




print(f'Training And Test ')
for epoch in tqdm(range(1, 21)):

    loss = train()
   
    train_rocauc, valid_rocauc, test_rocauc = test()
    print(f'Loss: {loss:.4f}, Train: {train_rocauc:.4f}, '
          f'Val: {valid_rocauc:.4f}, Test: {test_rocauc:.4f}')

Training And Test 


  5%|▌         | 1/20 [06:58<2:12:25, 418.20s/it]

Loss: 4.0124, Train: 0.4996, Val: 0.5008, Test: 0.5001


 10%|█         | 2/20 [14:42<2:13:35, 445.29s/it]

Loss: 1.2488, Train: 0.4991, Val: 0.5014, Test: 0.5002


 15%|█▌        | 3/20 [22:42<2:10:36, 461.00s/it]

Loss: 1.0819, Train: 0.4991, Val: 0.5021, Test: 0.5007


 20%|██        | 4/20 [30:41<2:04:52, 468.31s/it]

Loss: 1.0556, Train: 0.4992, Val: 0.5021, Test: 0.5013


 25%|██▌       | 5/20 [38:45<1:58:28, 473.93s/it]

Loss: 1.0482, Train: 0.4992, Val: 0.5014, Test: 0.5012


 30%|███       | 6/20 [46:46<1:51:09, 476.38s/it]

Loss: 1.0456, Train: 0.4991, Val: 0.5013, Test: 0.4995


 35%|███▌      | 7/20 [54:50<1:43:42, 478.65s/it]

Loss: 1.0447, Train: 0.4989, Val: 0.5010, Test: 0.5003


 40%|████      | 8/20 [1:02:52<1:35:59, 479.95s/it]

Loss: 1.0445, Train: 0.4986, Val: 0.5007, Test: 0.5010


 45%|████▌     | 9/20 [1:10:49<1:27:48, 478.94s/it]

Loss: 1.0446, Train: 0.4985, Val: 0.5002, Test: 0.5007


 50%|█████     | 10/20 [1:18:42<1:19:31, 477.17s/it]

Loss: 1.0444, Train: 0.4981, Val: 0.5003, Test: 0.4999


 55%|█████▌    | 11/20 [1:26:43<1:11:44, 478.26s/it]

Loss: 1.0445, Train: 0.4978, Val: 0.5000, Test: 0.5002


 60%|██████    | 12/20 [1:34:40<1:03:43, 477.89s/it]

Loss: 1.0444, Train: 0.4975, Val: 0.5001, Test: 0.4994


 65%|██████▌   | 13/20 [1:42:38<55:45, 477.95s/it]  

Loss: 1.0441, Train: 0.4973, Val: 0.5002, Test: 0.4996


 70%|███████   | 14/20 [1:50:36<47:47, 477.88s/it]

Loss: 1.0436, Train: 0.4971, Val: 0.4999, Test: 0.4999


 75%|███████▌  | 15/20 [1:58:37<39:54, 478.93s/it]

Loss: 1.0434, Train: 0.4967, Val: 0.5000, Test: 0.4992


 80%|████████  | 16/20 [2:06:55<32:18, 484.60s/it]

Loss: 1.0433, Train: 0.4966, Val: 0.4997, Test: 0.5000


 85%|████████▌ | 17/20 [2:14:55<24:10, 483.34s/it]

Loss: 1.0426, Train: 0.4962, Val: 0.4983, Test: 0.4981


 90%|█████████ | 18/20 [2:22:55<16:04, 482.33s/it]

Loss: 1.0424, Train: 0.4968, Val: 0.4979, Test: 0.4971


 95%|█████████▌| 19/20 [2:30:54<08:01, 481.33s/it]

Loss: 1.0423, Train: 0.4966, Val: 0.4985, Test: 0.4970


100%|██████████| 20/20 [2:38:57<00:00, 476.87s/it]

Loss: 1.0418, Train: 0.4956, Val: 0.4983, Test: 0.4969





## MLP model

In [99]:
from sklearn.neural_network import MLPClassifier
import numpy as np

X_train, X_valid, X_test = data.x[splitted_idx['train']], data.x[splitted_idx['valid']], data.x[splitted_idx['test']]

y_train, y_valid, y_test = data.y[splitted_idx['train']], data.y[splitted_idx['valid']], data.y[splitted_idx['test']]
X_train = torch.concat((X_train,X_valid))
y_train = torch.concat((y_train,y_valid))
model = MLPClassifier()
model.fit(X_train, y_train)

In [102]:
from sklearn.metrics import roc_auc_score
ypred_train = model.predict(data.x[splitted_idx['train']])
ypred_valid = model.predict(X_valid)
ypred_test = model.predict(X_test)

In [105]:
score_train = roc_auc_score(ypred_train, data.y[splitted_idx['train']])
score_valid = roc_auc_score(ypred_valid, y_valid)
score_test = roc_auc_score(ypred_test, y_test)
print(f"Train score {score_train} , Valid score {score_valid}, Test score {score_test}")

Train score 0.789416474889874 , Valid score 0.7830178677932498, Test score 0.5429647230357008


## GraphSAGE model

In [132]:
from torch_geometric.nn import GraphSAGE
model = GraphSAGE(
    data.num_node_features,
    hidden_channels=data.y.size(-1),
    num_layers=2,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train(epoch):
    model.train()

   

    total_loss = total_examples = 0
    for data in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        total_loss += float(loss) * int(data.train_mask.sum())
        total_examples += int(data.train_mask.sum())



    return total_loss / total_examples

@torch.no_grad()
def test():
    model.eval()

    y_true = {'train': [], 'valid': [], 'test': []}
    y_pred = {'train': [], 'valid': [], 'test': []}



    for data in test_loader:
        data = data.to(device)
        out = model(data.x, data.edge_index)

        for split in y_true.keys():
            mask = data[f'{split}_mask']
            y_true[split].append(data.y[mask].cpu())
            y_pred[split].append(out[mask].cpu())



    train_rocauc = evaluator.eval({
        'y_true': torch.cat(y_true['train'], dim=0),
        'y_pred': torch.cat(y_pred['train'], dim=0),
    })['rocauc']

    valid_rocauc = evaluator.eval({
        'y_true': torch.cat(y_true['valid'], dim=0),
        'y_pred': torch.cat(y_pred['valid'], dim=0),
    })['rocauc']

    test_rocauc = evaluator.eval({
        'y_true': torch.cat(y_true['test'], dim=0),
        'y_pred': torch.cat(y_pred['test'], dim=0),
    })['rocauc']

    return train_rocauc, valid_rocauc, test_rocauc


print(f'Training And Test ')
for epoch in tqdm(range(1, 101)):

    loss = train(epoch)
   
    train_rocauc, valid_rocauc, test_rocauc = test()
    print(f'Loss: {loss:.4f}, Train: {train_rocauc:.4f}, '
          f'Val: {valid_rocauc:.4f}, Test: {test_rocauc:.4f}')

Training And Test 


  1%|          | 1/100 [01:06<1:50:31, 66.99s/it]

Loss: 2.1480, Train: 0.6647, Val: 0.5807, Test: 0.5966


  2%|▏         | 2/100 [02:20<1:55:23, 70.65s/it]

Loss: 0.3783, Train: 0.6549, Val: 0.5993, Test: 0.6208


  3%|▎         | 3/100 [03:30<1:54:10, 70.62s/it]

Loss: 0.3375, Train: 0.6715, Val: 0.6231, Test: 0.6334


  4%|▍         | 4/100 [04:44<1:54:39, 71.66s/it]

Loss: 0.3251, Train: 0.7005, Val: 0.6441, Test: 0.6300


  5%|▌         | 5/100 [06:06<1:59:44, 75.63s/it]

Loss: 0.3191, Train: 0.7061, Val: 0.6607, Test: 0.6487


  6%|▌         | 6/100 [07:43<2:09:56, 82.94s/it]

Loss: 0.3134, Train: 0.7315, Val: 0.6776, Test: 0.6501


  7%|▋         | 7/100 [09:21<2:16:05, 87.80s/it]

Loss: 0.3095, Train: 0.7538, Val: 0.7047, Test: 0.6714


  8%|▊         | 8/100 [10:57<2:18:41, 90.45s/it]

Loss: 0.3044, Train: 0.7531, Val: 0.7107, Test: 0.6753


  9%|▉         | 9/100 [12:30<2:18:20, 91.22s/it]

Loss: 0.2989, Train: 0.7495, Val: 0.7051, Test: 0.6767


 10%|█         | 10/100 [14:01<2:16:44, 91.16s/it]

Loss: 0.2984, Train: 0.7560, Val: 0.7144, Test: 0.6804


 11%|█         | 11/100 [15:32<2:14:53, 90.94s/it]

Loss: 0.2962, Train: 0.7659, Val: 0.7195, Test: 0.6910


 12%|█▏        | 12/100 [17:04<2:13:54, 91.30s/it]

Loss: 0.2958, Train: 0.7647, Val: 0.7120, Test: 0.6801


 13%|█▎        | 13/100 [18:41<2:15:10, 93.22s/it]

Loss: 0.2936, Train: 0.7710, Val: 0.7210, Test: 0.6920


 14%|█▍        | 14/100 [20:16<2:14:13, 93.64s/it]

Loss: 0.2927, Train: 0.7676, Val: 0.7149, Test: 0.6863


 15%|█▌        | 15/100 [21:54<2:14:36, 95.02s/it]

Loss: 0.2928, Train: 0.7752, Val: 0.7207, Test: 0.7015


 16%|█▌        | 16/100 [23:30<2:13:32, 95.39s/it]

Loss: 0.2905, Train: 0.7718, Val: 0.7160, Test: 0.6938


 17%|█▋        | 17/100 [25:08<2:12:49, 96.01s/it]

Loss: 0.2881, Train: 0.7761, Val: 0.7213, Test: 0.6986


 18%|█▊        | 18/100 [26:46<2:11:54, 96.52s/it]

Loss: 0.2863, Train: 0.7782, Val: 0.7189, Test: 0.6962


 19%|█▉        | 19/100 [28:23<2:10:42, 96.82s/it]

Loss: 0.2866, Train: 0.7770, Val: 0.7208, Test: 0.6972


 20%|██        | 20/100 [29:54<2:06:35, 94.95s/it]

Loss: 0.2854, Train: 0.7919, Val: 0.7394, Test: 0.7028


 21%|██        | 21/100 [31:21<2:02:09, 92.78s/it]

Loss: 0.2849, Train: 0.7895, Val: 0.7336, Test: 0.7075


 22%|██▏       | 22/100 [32:40<1:55:11, 88.60s/it]

Loss: 0.2868, Train: 0.7814, Val: 0.7261, Test: 0.6961


 23%|██▎       | 23/100 [33:57<1:49:16, 85.14s/it]

Loss: 0.2843, Train: 0.7966, Val: 0.7357, Test: 0.7007


 24%|██▍       | 24/100 [35:16<1:45:24, 83.22s/it]

Loss: 0.2829, Train: 0.7901, Val: 0.7339, Test: 0.7024


 25%|██▌       | 25/100 [36:33<1:41:42, 81.37s/it]

Loss: 0.2835, Train: 0.8000, Val: 0.7349, Test: 0.7007


 26%|██▌       | 26/100 [37:50<1:38:43, 80.05s/it]

Loss: 0.2827, Train: 0.7872, Val: 0.7259, Test: 0.6986


 27%|██▋       | 27/100 [39:10<1:37:27, 80.10s/it]

Loss: 0.2822, Train: 0.7952, Val: 0.7426, Test: 0.7133


 28%|██▊       | 28/100 [40:36<1:37:56, 81.62s/it]

Loss: 0.2822, Train: 0.8003, Val: 0.7440, Test: 0.7167


 29%|██▉       | 29/100 [41:57<1:36:40, 81.69s/it]

Loss: 0.2809, Train: 0.7995, Val: 0.7438, Test: 0.7101


 30%|███       | 30/100 [43:17<1:34:30, 81.01s/it]

Loss: 0.2818, Train: 0.8062, Val: 0.7498, Test: 0.7227


 31%|███       | 31/100 [44:35<1:32:18, 80.27s/it]

Loss: 0.2798, Train: 0.8029, Val: 0.7486, Test: 0.7225


 32%|███▏      | 32/100 [45:55<1:30:47, 80.11s/it]

Loss: 0.2801, Train: 0.7999, Val: 0.7548, Test: 0.7271


 33%|███▎      | 33/100 [47:18<1:30:32, 81.09s/it]

Loss: 0.2803, Train: 0.7965, Val: 0.7524, Test: 0.7229


 34%|███▍      | 34/100 [48:42<1:29:50, 81.68s/it]

Loss: 0.2795, Train: 0.8051, Val: 0.7493, Test: 0.7261


 35%|███▌      | 35/100 [50:03<1:28:23, 81.60s/it]

Loss: 0.2784, Train: 0.8101, Val: 0.7495, Test: 0.7307


 36%|███▌      | 36/100 [51:26<1:27:32, 82.06s/it]

Loss: 0.2775, Train: 0.8094, Val: 0.7544, Test: 0.7339


 37%|███▋      | 37/100 [52:47<1:25:50, 81.76s/it]

Loss: 0.2776, Train: 0.8096, Val: 0.7617, Test: 0.7376


 38%|███▊      | 38/100 [54:13<1:25:52, 83.10s/it]

Loss: 0.2772, Train: 0.8163, Val: 0.7624, Test: 0.7372


 39%|███▉      | 39/100 [55:37<1:24:38, 83.26s/it]

Loss: 0.2767, Train: 0.8133, Val: 0.7603, Test: 0.7395


 40%|████      | 40/100 [56:58<1:22:30, 82.51s/it]

Loss: 0.2771, Train: 0.8149, Val: 0.7619, Test: 0.7381


 41%|████      | 41/100 [58:16<1:19:44, 81.09s/it]

Loss: 0.2769, Train: 0.8116, Val: 0.7583, Test: 0.7369


 42%|████▏     | 42/100 [59:34<1:17:41, 80.38s/it]

Loss: 0.2762, Train: 0.8109, Val: 0.7578, Test: 0.7401


 43%|████▎     | 43/100 [1:00:57<1:17:04, 81.13s/it]

Loss: 0.2776, Train: 0.8071, Val: 0.7485, Test: 0.7313


 44%|████▍     | 44/100 [1:02:20<1:16:12, 81.66s/it]

Loss: 0.2777, Train: 0.8096, Val: 0.7496, Test: 0.7340


 45%|████▌     | 45/100 [1:03:43<1:15:08, 81.98s/it]

Loss: 0.2763, Train: 0.8140, Val: 0.7594, Test: 0.7360


 46%|████▌     | 46/100 [1:04:49<1:09:30, 77.24s/it]

Loss: 0.2762, Train: 0.8089, Val: 0.7509, Test: 0.7367


 47%|████▋     | 47/100 [1:06:10<1:09:09, 78.30s/it]

Loss: 0.2765, Train: 0.8101, Val: 0.7565, Test: 0.7353


 48%|████▊     | 48/100 [1:07:33<1:09:13, 79.88s/it]

Loss: 0.2756, Train: 0.8099, Val: 0.7576, Test: 0.7379


 49%|████▉     | 49/100 [1:08:57<1:08:58, 81.15s/it]

Loss: 0.2765, Train: 0.8152, Val: 0.7641, Test: 0.7446


 50%|█████     | 50/100 [1:10:23<1:08:50, 82.60s/it]

Loss: 0.2758, Train: 0.8168, Val: 0.7623, Test: 0.7432


 51%|█████     | 51/100 [1:12:08<1:12:54, 89.28s/it]

Loss: 0.2761, Train: 0.8126, Val: 0.7600, Test: 0.7384


 52%|█████▏    | 52/100 [1:13:47<1:13:43, 92.15s/it]

Loss: 0.2758, Train: 0.8066, Val: 0.7615, Test: 0.7396


 53%|█████▎    | 53/100 [1:15:09<1:09:41, 88.98s/it]

Loss: 0.2760, Train: 0.8110, Val: 0.7586, Test: 0.7361


 54%|█████▍    | 54/100 [1:16:28<1:06:02, 86.15s/it]

Loss: 0.2754, Train: 0.8199, Val: 0.7688, Test: 0.7475


 55%|█████▌    | 55/100 [1:17:48<1:03:13, 84.30s/it]

Loss: 0.2750, Train: 0.8165, Val: 0.7675, Test: 0.7451


 56%|█████▌    | 56/100 [1:19:08<1:00:44, 82.83s/it]

Loss: 0.2749, Train: 0.8190, Val: 0.7682, Test: 0.7425


 57%|█████▋    | 57/100 [1:20:25<58:11, 81.19s/it]  

Loss: 0.2751, Train: 0.8178, Val: 0.7670, Test: 0.7432


 58%|█████▊    | 58/100 [1:21:43<56:13, 80.33s/it]

Loss: 0.2755, Train: 0.8163, Val: 0.7646, Test: 0.7422


 59%|█████▉    | 59/100 [1:23:17<57:38, 84.36s/it]

Loss: 0.2763, Train: 0.8162, Val: 0.7621, Test: 0.7384


 60%|██████    | 60/100 [1:25:33<1:06:29, 99.74s/it]

Loss: 0.2744, Train: 0.8170, Val: 0.7663, Test: 0.7405


 61%|██████    | 61/100 [1:27:07<1:03:42, 98.00s/it]

Loss: 0.2761, Train: 0.8167, Val: 0.7650, Test: 0.7427


 62%|██████▏   | 62/100 [1:28:31<59:31, 93.98s/it]  

Loss: 0.2745, Train: 0.8197, Val: 0.7611, Test: 0.7397


 63%|██████▎   | 63/100 [1:29:55<56:00, 90.81s/it]

Loss: 0.2751, Train: 0.8165, Val: 0.7584, Test: 0.7368


 64%|██████▍   | 64/100 [1:31:14<52:22, 87.28s/it]

Loss: 0.2772, Train: 0.8163, Val: 0.7615, Test: 0.7419


 65%|██████▌   | 65/100 [1:32:35<49:49, 85.41s/it]

Loss: 0.2744, Train: 0.8161, Val: 0.7536, Test: 0.7335


 66%|██████▌   | 66/100 [1:33:54<47:24, 83.66s/it]

Loss: 0.2755, Train: 0.8129, Val: 0.7605, Test: 0.7399


 67%|██████▋   | 67/100 [1:35:15<45:34, 82.88s/it]

Loss: 0.2762, Train: 0.8093, Val: 0.7514, Test: 0.7349


 68%|██████▊   | 68/100 [1:36:38<44:11, 82.85s/it]

Loss: 0.2738, Train: 0.8205, Val: 0.7655, Test: 0.7428


 69%|██████▉   | 69/100 [1:38:11<44:21, 85.85s/it]

Loss: 0.2744, Train: 0.8172, Val: 0.7608, Test: 0.7365


 70%|███████   | 70/100 [1:39:29<41:47, 83.58s/it]

Loss: 0.2758, Train: 0.8146, Val: 0.7515, Test: 0.7328


 71%|███████   | 71/100 [1:40:49<39:48, 82.35s/it]

Loss: 0.2771, Train: 0.8169, Val: 0.7589, Test: 0.7392


 72%|███████▏  | 72/100 [1:42:05<37:37, 80.62s/it]

Loss: 0.2741, Train: 0.8196, Val: 0.7691, Test: 0.7464


 73%|███████▎  | 73/100 [1:43:26<36:20, 80.77s/it]

Loss: 0.2738, Train: 0.8119, Val: 0.7565, Test: 0.7377


 74%|███████▍  | 74/100 [1:44:44<34:37, 79.90s/it]

Loss: 0.2772, Train: 0.8148, Val: 0.7569, Test: 0.7348


 75%|███████▌  | 75/100 [1:45:52<31:48, 76.36s/it]

Loss: 0.2753, Train: 0.8181, Val: 0.7639, Test: 0.7420


 76%|███████▌  | 76/100 [1:47:01<29:33, 73.89s/it]

Loss: 0.2739, Train: 0.8191, Val: 0.7612, Test: 0.7417


 77%|███████▋  | 77/100 [1:48:10<27:52, 72.70s/it]

Loss: 0.2762, Train: 0.8115, Val: 0.7589, Test: 0.7408


 78%|███████▊  | 78/100 [1:49:24<26:46, 73.04s/it]

Loss: 0.2739, Train: 0.8124, Val: 0.7635, Test: 0.7447


 79%|███████▉  | 79/100 [1:50:59<27:47, 79.39s/it]

Loss: 0.2737, Train: 0.8174, Val: 0.7550, Test: 0.7349


 80%|████████  | 80/100 [1:52:26<27:17, 81.87s/it]

Loss: 0.2732, Train: 0.8172, Val: 0.7479, Test: 0.7329


 81%|████████  | 81/100 [1:53:58<26:53, 84.91s/it]

Loss: 0.2737, Train: 0.8159, Val: 0.7577, Test: 0.7376


 82%|████████▏ | 82/100 [1:55:31<26:08, 87.14s/it]

Loss: 0.2729, Train: 0.8206, Val: 0.7731, Test: 0.7503


 83%|████████▎ | 83/100 [1:57:04<25:13, 89.05s/it]

Loss: 0.2747, Train: 0.8186, Val: 0.7604, Test: 0.7410


 84%|████████▍ | 84/100 [1:58:35<23:53, 89.60s/it]

Loss: 0.2731, Train: 0.8193, Val: 0.7682, Test: 0.7435


 85%|████████▌ | 85/100 [2:00:20<23:33, 94.26s/it]

Loss: 0.2742, Train: 0.8184, Val: 0.7634, Test: 0.7427


 86%|████████▌ | 86/100 [2:02:06<22:47, 97.67s/it]

Loss: 0.2740, Train: 0.8161, Val: 0.7547, Test: 0.7359


 87%|████████▋ | 87/100 [2:03:55<21:55, 101.17s/it]

Loss: 0.2743, Train: 0.8191, Val: 0.7544, Test: 0.7310


 88%|████████▊ | 88/100 [2:05:44<20:41, 103.43s/it]

Loss: 0.2735, Train: 0.8195, Val: 0.7623, Test: 0.7401


 89%|████████▉ | 89/100 [2:07:00<17:28, 95.30s/it] 

Loss: 0.2728, Train: 0.8160, Val: 0.7586, Test: 0.7389


 90%|█████████ | 90/100 [2:08:16<14:54, 89.43s/it]

Loss: 0.2739, Train: 0.8182, Val: 0.7618, Test: 0.7413


 91%|█████████ | 91/100 [2:09:44<13:21, 89.02s/it]

Loss: 0.2735, Train: 0.8057, Val: 0.7520, Test: 0.7405


 92%|█████████▏| 92/100 [2:11:30<12:33, 94.19s/it]

Loss: 0.2744, Train: 0.8210, Val: 0.7626, Test: 0.7468


 93%|█████████▎| 93/100 [2:13:22<11:35, 99.43s/it]

Loss: 0.2724, Train: 0.8221, Val: 0.7604, Test: 0.7427


 94%|█████████▍| 94/100 [2:15:13<10:18, 103.11s/it]

Loss: 0.2739, Train: 0.8242, Val: 0.7763, Test: 0.7517


 95%|█████████▌| 95/100 [2:17:11<08:57, 107.57s/it]

Loss: 0.2727, Train: 0.8132, Val: 0.7578, Test: 0.7383


 96%|█████████▌| 96/100 [2:19:17<07:32, 113.09s/it]

Loss: 0.2740, Train: 0.8199, Val: 0.7672, Test: 0.7431


 97%|█████████▋| 97/100 [2:21:11<05:40, 113.34s/it]

Loss: 0.2722, Train: 0.8160, Val: 0.7662, Test: 0.7434


 98%|█████████▊| 98/100 [2:23:09<03:49, 114.67s/it]

Loss: 0.2727, Train: 0.8214, Val: 0.7801, Test: 0.7495


 99%|█████████▉| 99/100 [2:25:02<01:54, 114.11s/it]

Loss: 0.2717, Train: 0.8200, Val: 0.7616, Test: 0.7380


100%|██████████| 100/100 [2:26:54<00:00, 88.15s/it] 

Loss: 0.2723, Train: 0.8208, Val: 0.7634, Test: 0.7424



