## Import Libraries

In [3]:
### Author: Yuvasri Raghavan (2022)

In [4]:
# argparse for user friendly parsing of arguments
import argparse

#machine learning python framework
import torch

#PyTorch library for graph networks
from torch_geometric.nn import Node2Vec

# OGB Dataloader to load Drug Interaction Network Data
from ogb.linkproppred import PygLinkPropPredDataset

## Node2Vec

### Saving Embeddings

In [2]:
#Creating a function to save embeddings in Node2Vec

def save_embedding_todevice(model, path='embedding.pt'):
    """Write the model's embedding matrix to disk as a CPU tensor.

    Args:
        model: Object exposing ``embedding.weight`` (here, the Node2Vec
            model built in this notebook).
        path: Destination passed to ``torch.save``. Defaults to
            ``'embedding.pt'``, the file name the link-prediction ``main``
            later reads with ``torch.load``.
    """
    # detach() drops the autograd link; cpu() ensures the stored tensor is
    # not tied to a CUDA device.
    torch.save(model.embedding.weight.detach().cpu(), path)

### Load Data

In [2]:
# Load the OGB drug-drug interaction dataset (ogbl-ddi); it holds one graph.
dataset = PygLinkPropPredDataset(name='ogbl-ddi')
data = dataset[0]

# Retrieve the train/valid/test edge split that ships with the dataset
# (presumably OGB's protein-target split -- confirm against the OGB docs).
split_edge = dataset.get_edge_split()

# Sample as many training edges as there are validation edges; this subset
# ('eval_train') is what train-set metrics are computed on. A dedicated
# seeded generator keeps the sample identical across kernel restarts without
# reseeding the global RNG (same seed value as the link-prediction main()).
eval_train_rng = torch.Generator().manual_seed(12345)
idx = torch.randperm(split_edge['train']['edge'].size(0),
                     generator=eval_train_rng)
idx = idx[:split_edge['valid']['edge'].size(0)]
split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]}

### Node2Vec Embedding Model

In [3]:
def main():
    """Train Node2Vec embeddings on ogbl-ddi and checkpoint them to disk.

    Hyper-parameters are read from the command line with argparse. The
    embedding matrix is written through ``save_embedding_todevice`` every
    ``--save_steps`` batches and once more at the end of every epoch.
    """
    # Parser for the Node2Vec hyper-parameters.
    parser = argparse.ArgumentParser(description='OGBL-DDI (Node2Vec)')
    # Accepted and ignored so parse_args() does not reject it; presumably the
    # `-f <connection file>` argument present when run inside a Jupyter kernel.
    parser.add_argument('-f')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--embedding_dim', type=int, default=128)
    parser.add_argument('--walk_length', type=int, default=40)
    parser.add_argument('--context_size', type=int, default=20)
    parser.add_argument('--walks_per_node', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--log_steps', type=int, default=1)
    # Checkpoint interval in batches (previously a hard-coded 100);
    # a value <= 0 disables the mid-epoch checkpoint.
    parser.add_argument('--save_steps', type=int, default=100)
    args = parser.parse_args()

    # Use the requested GPU if CUDA is available, otherwise fall back to CPU.
    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # Load the OGB DDI dataset.
    dataset = PygLinkPropPredDataset(name='ogbl-ddi')
    data = dataset[0]

    # Define the Node2Vec model. sparse=True goes together with the
    # SparseAdam optimizer created below.
    model = Node2Vec(data.edge_index, args.embedding_dim, args.walk_length,
                     args.context_size, args.walks_per_node,
                     sparse=True).to(device)

    # Each batch from the loader is a (positive walks, negative walks) pair.
    loader = model.loader(batch_size=args.batch_size, shuffle=True,
                          num_workers=4)

    # SparseAdam optimizer over the model parameters.
    optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=args.lr)

    # Train the Node2Vec model.
    model.train()
    for epoch in range(1, args.epochs + 1):
        for i, (pos_rw, neg_rw) in enumerate(loader):
            optimizer.zero_grad()
            loss = model.loss(pos_rw.to(device), neg_rw.to(device))
            loss.backward()
            optimizer.step()

            if (i + 1) % args.log_steps == 0:
                print(f'Epoch: {epoch:02d}, Step: {i+1:03d}/{len(loader)}, '
                      f'Loss: {loss:.4f}')

            # Periodic checkpoint inside long epochs.
            if args.save_steps > 0 and (i + 1) % args.save_steps == 0:
                save_embedding_todevice(model)
        # Always checkpoint at the end of an epoch, so epochs with fewer
        # batches than save_steps are still saved.
        save_embedding_todevice(model)


if __name__ == "__main__":
    main()

Epoch: 01, Step: 001/17, Loss: 10.0726
Epoch: 01, Step: 002/17, Loss: 9.9564
Epoch: 01, Step: 003/17, Loss: 9.7899
Epoch: 01, Step: 004/17, Loss: 9.6619
Epoch: 01, Step: 005/17, Loss: 9.5450
Epoch: 01, Step: 006/17, Loss: 9.3609
Epoch: 01, Step: 007/17, Loss: 9.2638
Epoch: 01, Step: 008/17, Loss: 9.1543
Epoch: 01, Step: 009/17, Loss: 8.9926
Epoch: 01, Step: 010/17, Loss: 8.8368
Epoch: 01, Step: 011/17, Loss: 8.7348
Epoch: 01, Step: 012/17, Loss: 8.6059
Epoch: 01, Step: 013/17, Loss: 8.4894
Epoch: 01, Step: 014/17, Loss: 8.3420
Epoch: 01, Step: 015/17, Loss: 8.2165
Epoch: 01, Step: 016/17, Loss: 8.0934
Epoch: 01, Step: 017/17, Loss: 7.9631
Epoch: 02, Step: 001/17, Loss: 7.8535
Epoch: 02, Step: 002/17, Loss: 7.7301
Epoch: 02, Step: 003/17, Loss: 7.6062
Epoch: 02, Step: 004/17, Loss: 7.4901
Epoch: 02, Step: 005/17, Loss: 7.3756
Epoch: 02, Step: 006/17, Loss: 7.2684
Epoch: 02, Step: 007/17, Loss: 7.1581
Epoch: 02, Step: 008/17, Loss: 7.0486
Epoch: 02, Step: 009/17, Loss: 6.9217
Epoch: 02, 

Epoch: 13, Step: 013/17, Loss: 1.3503
Epoch: 13, Step: 014/17, Loss: 1.3455
Epoch: 13, Step: 015/17, Loss: 1.3461
Epoch: 13, Step: 016/17, Loss: 1.3456
Epoch: 13, Step: 017/17, Loss: 1.3376
Epoch: 14, Step: 001/17, Loss: 1.3331
Epoch: 14, Step: 002/17, Loss: 1.3360
Epoch: 14, Step: 003/17, Loss: 1.3339
Epoch: 14, Step: 004/17, Loss: 1.3318
Epoch: 14, Step: 005/17, Loss: 1.3357
Epoch: 14, Step: 006/17, Loss: 1.3250
Epoch: 14, Step: 007/17, Loss: 1.3268
Epoch: 14, Step: 008/17, Loss: 1.3165
Epoch: 14, Step: 009/17, Loss: 1.3206
Epoch: 14, Step: 010/17, Loss: 1.3204
Epoch: 14, Step: 011/17, Loss: 1.3214
Epoch: 14, Step: 012/17, Loss: 1.3209
Epoch: 14, Step: 013/17, Loss: 1.3188
Epoch: 14, Step: 014/17, Loss: 1.3161
Epoch: 14, Step: 015/17, Loss: 1.3142
Epoch: 14, Step: 016/17, Loss: 1.3171
Epoch: 14, Step: 017/17, Loss: 1.3080
Epoch: 15, Step: 001/17, Loss: 1.3076
Epoch: 15, Step: 002/17, Loss: 1.3092
Epoch: 15, Step: 003/17, Loss: 1.3071
Epoch: 15, Step: 004/17, Loss: 1.3103
Epoch: 15, S

Epoch: 26, Step: 008/17, Loss: 1.2180
Epoch: 26, Step: 009/17, Loss: 1.2195
Epoch: 26, Step: 010/17, Loss: 1.2169
Epoch: 26, Step: 011/17, Loss: 1.2166
Epoch: 26, Step: 012/17, Loss: 1.2137
Epoch: 26, Step: 013/17, Loss: 1.2198
Epoch: 26, Step: 014/17, Loss: 1.2228
Epoch: 26, Step: 015/17, Loss: 1.2166
Epoch: 26, Step: 016/17, Loss: 1.2210
Epoch: 26, Step: 017/17, Loss: 1.2192
Epoch: 27, Step: 001/17, Loss: 1.2153
Epoch: 27, Step: 002/17, Loss: 1.2174
Epoch: 27, Step: 003/17, Loss: 1.2182
Epoch: 27, Step: 004/17, Loss: 1.2180
Epoch: 27, Step: 005/17, Loss: 1.2147
Epoch: 27, Step: 006/17, Loss: 1.2174
Epoch: 27, Step: 007/17, Loss: 1.2173
Epoch: 27, Step: 008/17, Loss: 1.2163
Epoch: 27, Step: 009/17, Loss: 1.2192
Epoch: 27, Step: 010/17, Loss: 1.2146
Epoch: 27, Step: 011/17, Loss: 1.2179
Epoch: 27, Step: 012/17, Loss: 1.2157
Epoch: 27, Step: 013/17, Loss: 1.2163
Epoch: 27, Step: 014/17, Loss: 1.2160
Epoch: 27, Step: 015/17, Loss: 1.2167
Epoch: 27, Step: 016/17, Loss: 1.2151
Epoch: 27, S

Epoch: 39, Step: 003/17, Loss: 1.2044
Epoch: 39, Step: 004/17, Loss: 1.2034
Epoch: 39, Step: 005/17, Loss: 1.2042
Epoch: 39, Step: 006/17, Loss: 1.2067
Epoch: 39, Step: 007/17, Loss: 1.2083
Epoch: 39, Step: 008/17, Loss: 1.2055
Epoch: 39, Step: 009/17, Loss: 1.2055
Epoch: 39, Step: 010/17, Loss: 1.2076
Epoch: 39, Step: 011/17, Loss: 1.2072
Epoch: 39, Step: 012/17, Loss: 1.2064
Epoch: 39, Step: 013/17, Loss: 1.2067
Epoch: 39, Step: 014/17, Loss: 1.2025
Epoch: 39, Step: 015/17, Loss: 1.2076
Epoch: 39, Step: 016/17, Loss: 1.2056
Epoch: 39, Step: 017/17, Loss: 1.2042
Epoch: 40, Step: 001/17, Loss: 1.2065
Epoch: 40, Step: 002/17, Loss: 1.2071
Epoch: 40, Step: 003/17, Loss: 1.2065
Epoch: 40, Step: 004/17, Loss: 1.2041
Epoch: 40, Step: 005/17, Loss: 1.2043
Epoch: 40, Step: 006/17, Loss: 1.2030
Epoch: 40, Step: 007/17, Loss: 1.2030
Epoch: 40, Step: 008/17, Loss: 1.2011
Epoch: 40, Step: 009/17, Loss: 1.2077
Epoch: 40, Step: 010/17, Loss: 1.2029
Epoch: 40, Step: 011/17, Loss: 1.2026
Epoch: 40, S

Epoch: 51, Step: 015/17, Loss: 1.2055
Epoch: 51, Step: 016/17, Loss: 1.2019
Epoch: 51, Step: 017/17, Loss: 1.2007
Epoch: 52, Step: 001/17, Loss: 1.2052
Epoch: 52, Step: 002/17, Loss: 1.2014
Epoch: 52, Step: 003/17, Loss: 1.2012
Epoch: 52, Step: 004/17, Loss: 1.1990
Epoch: 52, Step: 005/17, Loss: 1.2006
Epoch: 52, Step: 006/17, Loss: 1.2016
Epoch: 52, Step: 007/17, Loss: 1.1996
Epoch: 52, Step: 008/17, Loss: 1.2011
Epoch: 52, Step: 009/17, Loss: 1.2021
Epoch: 52, Step: 010/17, Loss: 1.2008
Epoch: 52, Step: 011/17, Loss: 1.2013
Epoch: 52, Step: 012/17, Loss: 1.1980
Epoch: 52, Step: 013/17, Loss: 1.1974
Epoch: 52, Step: 014/17, Loss: 1.2044
Epoch: 52, Step: 015/17, Loss: 1.2006
Epoch: 52, Step: 016/17, Loss: 1.2027
Epoch: 52, Step: 017/17, Loss: 1.2030
Epoch: 53, Step: 001/17, Loss: 1.2017
Epoch: 53, Step: 002/17, Loss: 1.1986
Epoch: 53, Step: 003/17, Loss: 1.2047
Epoch: 53, Step: 004/17, Loss: 1.2021
Epoch: 53, Step: 005/17, Loss: 1.2044
Epoch: 53, Step: 006/17, Loss: 1.2036
Epoch: 53, S

Epoch: 64, Step: 010/17, Loss: 1.2012
Epoch: 64, Step: 011/17, Loss: 1.2049
Epoch: 64, Step: 012/17, Loss: 1.1998
Epoch: 64, Step: 013/17, Loss: 1.2035
Epoch: 64, Step: 014/17, Loss: 1.2003
Epoch: 64, Step: 015/17, Loss: 1.2022
Epoch: 64, Step: 016/17, Loss: 1.2021
Epoch: 64, Step: 017/17, Loss: 1.2053
Epoch: 65, Step: 001/17, Loss: 1.1999
Epoch: 65, Step: 002/17, Loss: 1.2006
Epoch: 65, Step: 003/17, Loss: 1.2007
Epoch: 65, Step: 004/17, Loss: 1.2000
Epoch: 65, Step: 005/17, Loss: 1.2041
Epoch: 65, Step: 006/17, Loss: 1.1987
Epoch: 65, Step: 007/17, Loss: 1.2006
Epoch: 65, Step: 008/17, Loss: 1.2019
Epoch: 65, Step: 009/17, Loss: 1.2006
Epoch: 65, Step: 010/17, Loss: 1.1985
Epoch: 65, Step: 011/17, Loss: 1.1984
Epoch: 65, Step: 012/17, Loss: 1.2002
Epoch: 65, Step: 013/17, Loss: 1.2011
Epoch: 65, Step: 014/17, Loss: 1.2016
Epoch: 65, Step: 015/17, Loss: 1.2029
Epoch: 65, Step: 016/17, Loss: 1.2015
Epoch: 65, Step: 017/17, Loss: 1.2019
Epoch: 66, Step: 001/17, Loss: 1.2003
Epoch: 66, S

Epoch: 77, Step: 005/17, Loss: 1.2015
Epoch: 77, Step: 006/17, Loss: 1.1978
Epoch: 77, Step: 007/17, Loss: 1.2018
Epoch: 77, Step: 008/17, Loss: 1.1968
Epoch: 77, Step: 009/17, Loss: 1.2040
Epoch: 77, Step: 010/17, Loss: 1.2045
Epoch: 77, Step: 011/17, Loss: 1.2017
Epoch: 77, Step: 012/17, Loss: 1.2001
Epoch: 77, Step: 013/17, Loss: 1.1980
Epoch: 77, Step: 014/17, Loss: 1.1992
Epoch: 77, Step: 015/17, Loss: 1.2006
Epoch: 77, Step: 016/17, Loss: 1.2006
Epoch: 77, Step: 017/17, Loss: 1.2023
Epoch: 78, Step: 001/17, Loss: 1.1993
Epoch: 78, Step: 002/17, Loss: 1.1966
Epoch: 78, Step: 003/17, Loss: 1.1973
Epoch: 78, Step: 004/17, Loss: 1.2020
Epoch: 78, Step: 005/17, Loss: 1.2020
Epoch: 78, Step: 006/17, Loss: 1.1976
Epoch: 78, Step: 007/17, Loss: 1.2021
Epoch: 78, Step: 008/17, Loss: 1.2017
Epoch: 78, Step: 009/17, Loss: 1.1992
Epoch: 78, Step: 010/17, Loss: 1.2044
Epoch: 78, Step: 011/17, Loss: 1.2013
Epoch: 78, Step: 012/17, Loss: 1.1999
Epoch: 78, Step: 013/17, Loss: 1.1979
Epoch: 78, S

Epoch: 89, Step: 017/17, Loss: 1.2035
Epoch: 90, Step: 001/17, Loss: 1.2008
Epoch: 90, Step: 002/17, Loss: 1.2016
Epoch: 90, Step: 003/17, Loss: 1.1990
Epoch: 90, Step: 004/17, Loss: 1.2006
Epoch: 90, Step: 005/17, Loss: 1.1977
Epoch: 90, Step: 006/17, Loss: 1.1999
Epoch: 90, Step: 007/17, Loss: 1.2030
Epoch: 90, Step: 008/17, Loss: 1.2028
Epoch: 90, Step: 009/17, Loss: 1.2013
Epoch: 90, Step: 010/17, Loss: 1.2031
Epoch: 90, Step: 011/17, Loss: 1.2011
Epoch: 90, Step: 012/17, Loss: 1.1994
Epoch: 90, Step: 013/17, Loss: 1.1976
Epoch: 90, Step: 014/17, Loss: 1.2010
Epoch: 90, Step: 015/17, Loss: 1.1986
Epoch: 90, Step: 016/17, Loss: 1.2034
Epoch: 90, Step: 017/17, Loss: 1.2003
Epoch: 91, Step: 001/17, Loss: 1.1981
Epoch: 91, Step: 002/17, Loss: 1.2006
Epoch: 91, Step: 003/17, Loss: 1.2002
Epoch: 91, Step: 004/17, Loss: 1.1991
Epoch: 91, Step: 005/17, Loss: 1.2002
Epoch: 91, Step: 006/17, Loss: 1.1992
Epoch: 91, Step: 007/17, Loss: 1.2014
Epoch: 91, Step: 008/17, Loss: 1.1991
Epoch: 91, S

## Tracker Function to Record Run details

In [None]:
# Logger code adapted from: https://github.com/snap-stanford/ogb/blob/master/examples/linkproppred/ddi/logger.py

In [2]:
class Logger_Models(object):
    """Collect (train, valid, test) scores per run and print summaries.

    ``results[run]`` is the list of 3-tuples recorded for that run, in the
    order they were added. Scores are multiplied by 100 when printed.
    "Highest" figures are maxima over a run; "Final" figures are the train
    and test scores at the entry with the best validation score.
    """

    def __init__(self, runs, info=None):
        """Create empty result lists for ``runs`` runs; ``info`` is stored as-is."""
        self.info = info
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        """Append one ``(train, valid, test)`` triple to run ``run``.

        Raises:
            AssertionError: if ``result`` does not have exactly three
                entries or ``run`` is outside ``[0, runs)``.
        """
        assert len(result) == 3
        assert run >= 0 and run < len(self.results)
        self.results[run].append(result)

    @staticmethod
    def _summarize(run_results):
        """Return (highest train, highest valid, final train, final test) in percent."""
        scores = 100 * torch.tensor(run_results)
        # "Final" values are taken at the entry with the best validation score.
        best = scores[:, 1].argmax().item()
        return (scores[:, 0].max().item(), scores[:, 1].max().item(),
                scores[best, 0].item(), scores[best, 2].item())

    def print_statistics(self, run=None):
        """Print the summary of one run, or mean ± std over all runs.

        Args:
            run: Zero-based run index. When ``None``, statistics are
                aggregated over every run that has at least one result.
        """
        if run is not None:
            if not self.results[run]:
                # Nothing recorded yet; indexing an empty tensor would fail.
                print(f'Run {run + 1:02d}: no results recorded')
                return
            train1, valid, train2, test = self._summarize(self.results[run])
            print(f'Run {run + 1:02d}:')
            print(f'Highest Train: {train1:.2f}')
            print(f'Highest Valid: {valid:.2f}')
            print(f'  Final Train: {train2:.2f}')
            print(f'   Final Test: {test:.2f}')
        else:
            # Summarize run by run instead of stacking every run into one
            # tensor, so runs with different numbers of recorded results
            # (e.g. an interrupted run) do not make the tensor build fail.
            best_results = [self._summarize(r) for r in self.results if r]
            if not best_results:
                print('All runs: no results recorded')
                return

            best_result = torch.tensor(best_results)

            print('All runs:')
            r = best_result[:, 0]
            print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}')
            r = best_result[:, 1]
            print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}')
            r = best_result[:, 2]
            print(f'  Final Train: {r.mean():.2f} ± {r.std():.2f}')
            r = best_result[:, 3]
            print(f'   Final Test: {r.mean():.2f} ± {r.std():.2f}')


## Multilayer Perceptron

In [5]:
import argparse

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.utils import negative_sampling

from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

In [13]:
# NOTE(review): no cell shown in this notebook defines `model` at notebook
# scope (the Node2Vec model is a local variable inside main()), so this cell
# presumably raises NameError after Restart Kernel -> Run All. Its execution
# count (In [13]) is also out of sequence with the surrounding cells.
# TODO confirm that ann_viz accepts the PyTorch models used here; if it does
# not, this cell should be removed or replaced.
from ann_visualizer.visualize import ann_viz;

ann_viz(model, title="My first neural network")

In [6]:
class LinkPredictor(torch.nn.Module):
    """MLP that scores a node pair from the element-wise product of its embeddings."""

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        # Layer widths: input, then the hidden width (at least once), then
        # output. The stack therefore never has fewer than two linear layers.
        widths = ([in_channels]
                  + [hidden_channels] * max(num_layers - 1, 1)
                  + [out_channels])
        self.lins = torch.nn.ModuleList([
            torch.nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ])

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialize every linear layer, first to last."""
        for layer in self.lins:
            layer.reset_parameters()

    def forward(self, x_i, x_j):
        """Return sigmoid scores for the pairs formed by ``x_i`` and ``x_j``."""
        *hidden_layers, output_layer = self.lins

        h = x_i * x_j
        for layer in hidden_layers:
            h = F.relu(layer(h))
            h = F.dropout(h, p=self.dropout, training=self.training)
        return torch.sigmoid(output_layer(h))


def train(predictor, x, edge_index, split_edge, optimizer, batch_size):
    """Run one training epoch of the link predictor on fixed embeddings ``x``.

    For every mini-batch of positive training edges, an equally sized set of
    negative edges is drawn with ``negative_sampling``. The loss is the sum of
    the mean negative log-likelihoods of the positive and negative
    predictions.

    Returns:
        The per-batch losses averaged with the positive batch sizes as weights.
    """
    predictor.train()

    train_edges = split_edge['train']['edge'].to(x.device)
    batches = DataLoader(range(train_edges.size(0)), batch_size, shuffle=True)

    loss_sum = 0
    seen = 0
    for batch_idx in batches:
        optimizer.zero_grad()

        # Positive edges: transpose to a (source row, target row) layout.
        pos_edge = train_edges[batch_idx].t()
        pos_out = predictor(x[pos_edge[0]], x[pos_edge[1]])
        pos_loss = -torch.log(pos_out + 1e-15).mean()

        # Negative edges: one sampled pair per positive edge in the batch.
        neg_edge = negative_sampling(edge_index, num_nodes=x.size(0),
                                     num_neg_samples=batch_idx.size(0),
                                     method='dense')
        neg_out = predictor(x[neg_edge[0]], x[neg_edge[1]])
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()

        loss = pos_loss + neg_loss
        loss.backward()
        optimizer.step()

        batch_size_seen = pos_out.size(0)
        loss_sum += loss.item() * batch_size_seen
        seen += batch_size_seen

    return loss_sum / seen


@torch.no_grad()
def test(predictor, x, split_edge, evaluator, batch_size):
    predictor.eval()

    pos_train_edge = split_edge['eval_train']['edge'].to(x.device)
    pos_valid_edge = split_edge['valid']['edge'].to(x.device)
    neg_valid_edge = split_edge['valid']['edge_neg'].to(x.device)
    pos_test_edge = split_edge['test']['edge'].to(x.device)
    neg_test_edge = split_edge['test']['edge_neg'].to(x.device)

    pos_train_preds = []
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size):
        edge = pos_train_edge[perm].t()
        pos_train_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()]
    pos_train_pred = torch.cat(pos_train_preds, dim=0)

    pos_valid_preds = []
    for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size):
        edge = pos_valid_edge[perm].t()
        pos_valid_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()]
    pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

    neg_valid_preds = []
    for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size):
        edge = neg_valid_edge[perm].t()
        neg_valid_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()]
    neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

    pos_test_preds = []
    for perm in DataLoader(range(pos_test_edge.size(0)), batch_size):
        edge = pos_test_edge[perm].t()
        pos_test_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()]
    pos_test_pred = torch.cat(pos_test_preds, dim=0)

    neg_test_preds = []
    for perm in DataLoader(range(neg_test_edge.size(0)), batch_size):
        edge = neg_test_edge[perm].t()
        neg_test_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()]
    neg_test_pred = torch.cat(neg_test_preds, dim=0)

    results = {}
    for K in [10, 20, 30]:
        evaluator.K = K
        train_hits = evaluator.eval({
            'y_pred_pos': pos_train_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits)

    return results


def main():
    """Train and evaluate the MLP link predictor on saved Node2Vec embeddings.

    Loads ogbl-ddi and its edge split, reads the node embeddings from
    ``embedding.pt``, then trains ``LinkPredictor`` for ``--runs`` independent
    runs. Hits@10/20/30 are recorded every ``--eval_steps`` epochs, and a
    summary is printed per run and over all runs.
    """
    parser = argparse.ArgumentParser(description='OGBL-DDI (MLP)')
    # Accepted and ignored; inside Jupyter this receives the kernel
    # connection-file path (visible as `f=...` in the printed Namespace).
    parser.add_argument('-f')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-ddi')
    data = dataset[0]
    split_edge = dataset.get_edge_split()

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['edge'].size(0))
    idx = idx[:split_edge['valid']['edge'].size(0)]
    split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]}

    # Node embeddings written by the Node2Vec training cell.
    x = torch.load('embedding.pt', map_location='cpu').to(device)

    predictor = LinkPredictor(x.size(-1), args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-ddi')
    # One tracker per metric. The class defined in this notebook is
    # `Logger_Models`; the previous `Logger_Models_Model` name does not
    # exist and raised NameError.
    loggers = {
        'Hits@10': Logger_Models(args.runs, args),
        'Hits@20': Logger_Models(args.runs, args),
        'Hits@30': Logger_Models(args.runs, args),
    }

    for run in range(args.runs):
        # Fresh weights and optimizer state for every run.
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(predictor.parameters(), lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(predictor, x, data.edge_index, split_edge, optimizer,
                         args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(predictor, x, split_edge, evaluator,
                               args.batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                # Logging happens only on evaluation epochs.
                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_hits:.2f}%, '
                              f'Valid: {100 * valid_hits:.2f}%, '
                              f'Test: {100 * test_hits:.2f}%')
                    print('---')

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()


if __name__ == "__main__":
    main()


Namespace(f='C:\\Users\\b04753yr\\AppData\\Roaming\\jupyter\\runtime\\kernel-30e704de-a8cc-40a4-8675-f15ec8e16fa1.json', device=0, log_steps=1, num_layers=3, hidden_channels=256, dropout=0.5, batch_size=65536, lr=0.01, epochs=200, eval_steps=5, runs=10)
Hits@10
Run: 01, Epoch: 05, Loss: 0.4626, Train: 9.03%, Valid: 7.82%, Test: 2.86%
Hits@20
Run: 01, Epoch: 05, Loss: 0.4626, Train: 11.33%, Valid: 9.77%, Test: 8.22%
Hits@30
Run: 01, Epoch: 05, Loss: 0.4626, Train: 14.72%, Valid: 12.75%, Test: 13.46%
---
Hits@10
Run: 01, Epoch: 10, Loss: 0.4281, Train: 9.99%, Valid: 8.56%, Test: 7.90%
Hits@20
Run: 01, Epoch: 10, Loss: 0.4281, Train: 15.64%, Valid: 13.61%, Test: 10.38%
Hits@30
Run: 01, Epoch: 10, Loss: 0.4281, Train: 20.27%, Valid: 17.76%, Test: 17.53%
---
Hits@10
Run: 01, Epoch: 15, Loss: 0.4102, Train: 13.14%, Valid: 11.15%, Test: 11.13%
Hits@20
Run: 01, Epoch: 15, Loss: 0.4102, Train: 19.22%, Valid: 16.61%, Test: 15.90%
Hits@30
Run: 01, Epoch: 15, Loss: 0.4102, Train: 22.94%, Valid: 19

Hits@10
Run: 01, Epoch: 160, Loss: 0.3728, Train: 15.83%, Valid: 13.27%, Test: 8.06%
Hits@20
Run: 01, Epoch: 160, Loss: 0.3728, Train: 23.93%, Valid: 20.44%, Test: 14.56%
Hits@30
Run: 01, Epoch: 160, Loss: 0.3728, Train: 29.84%, Valid: 25.82%, Test: 19.30%
---
Hits@10
Run: 01, Epoch: 165, Loss: 0.3722, Train: 15.81%, Valid: 13.24%, Test: 10.28%
Hits@20
Run: 01, Epoch: 165, Loss: 0.3722, Train: 20.96%, Valid: 17.78%, Test: 14.20%
Hits@30
Run: 01, Epoch: 165, Loss: 0.3722, Train: 26.89%, Valid: 23.17%, Test: 19.79%
---
Hits@10
Run: 01, Epoch: 170, Loss: 0.3712, Train: 15.64%, Valid: 13.11%, Test: 10.53%
Hits@20
Run: 01, Epoch: 170, Loss: 0.3712, Train: 24.32%, Valid: 20.75%, Test: 14.31%
Hits@30
Run: 01, Epoch: 170, Loss: 0.3712, Train: 26.66%, Valid: 22.88%, Test: 19.02%
---
Hits@10
Run: 01, Epoch: 175, Loss: 0.3729, Train: 8.22%, Valid: 6.78%, Test: 9.27%
Hits@20
Run: 01, Epoch: 175, Loss: 0.3729, Train: 19.55%, Valid: 16.42%, Test: 15.31%
Hits@30
Run: 01, Epoch: 175, Loss: 0.3729, Tra

Hits@10
Run: 02, Epoch: 115, Loss: 0.3733, Train: 14.22%, Valid: 11.92%, Test: 9.94%
Hits@20
Run: 02, Epoch: 115, Loss: 0.3733, Train: 21.89%, Valid: 18.60%, Test: 13.09%
Hits@30
Run: 02, Epoch: 115, Loss: 0.3733, Train: 23.92%, Valid: 20.49%, Test: 15.14%
---
Hits@10
Run: 02, Epoch: 120, Loss: 0.3723, Train: 15.90%, Valid: 13.36%, Test: 9.77%
Hits@20
Run: 02, Epoch: 120, Loss: 0.3723, Train: 21.83%, Valid: 18.58%, Test: 13.97%
Hits@30
Run: 02, Epoch: 120, Loss: 0.3723, Train: 27.30%, Valid: 23.55%, Test: 17.13%
---
Hits@10
Run: 02, Epoch: 125, Loss: 0.3726, Train: 14.46%, Valid: 12.09%, Test: 9.50%
Hits@20
Run: 02, Epoch: 125, Loss: 0.3726, Train: 21.27%, Valid: 17.94%, Test: 13.62%
Hits@30
Run: 02, Epoch: 125, Loss: 0.3726, Train: 25.23%, Valid: 21.52%, Test: 17.75%
---
Hits@10
Run: 02, Epoch: 130, Loss: 0.3725, Train: 16.56%, Valid: 13.94%, Test: 9.57%
Hits@20
Run: 02, Epoch: 130, Loss: 0.3725, Train: 20.11%, Valid: 17.06%, Test: 11.70%
Hits@30
Run: 02, Epoch: 130, Loss: 0.3725, Tra

Hits@10
Run: 03, Epoch: 70, Loss: 0.3777, Train: 16.93%, Valid: 14.34%, Test: 11.27%
Hits@20
Run: 03, Epoch: 70, Loss: 0.3777, Train: 24.79%, Valid: 21.31%, Test: 16.15%
Hits@30
Run: 03, Epoch: 70, Loss: 0.3777, Train: 30.10%, Valid: 26.23%, Test: 20.94%
---
Hits@10
Run: 03, Epoch: 75, Loss: 0.3769, Train: 15.57%, Valid: 13.11%, Test: 9.96%
Hits@20
Run: 03, Epoch: 75, Loss: 0.3769, Train: 24.67%, Valid: 21.26%, Test: 16.31%
Hits@30
Run: 03, Epoch: 75, Loss: 0.3769, Train: 30.95%, Valid: 27.00%, Test: 21.59%
---
Hits@10
Run: 03, Epoch: 80, Loss: 0.3773, Train: 15.84%, Valid: 13.32%, Test: 10.18%
Hits@20
Run: 03, Epoch: 80, Loss: 0.3773, Train: 24.18%, Valid: 20.77%, Test: 16.17%
Hits@30
Run: 03, Epoch: 80, Loss: 0.3773, Train: 28.03%, Valid: 24.22%, Test: 19.45%
---
Hits@10
Run: 03, Epoch: 85, Loss: 0.3774, Train: 14.26%, Valid: 11.97%, Test: 11.22%
Hits@20
Run: 03, Epoch: 85, Loss: 0.3774, Train: 22.78%, Valid: 19.46%, Test: 15.04%
Hits@30
Run: 03, Epoch: 85, Loss: 0.3774, Train: 26.94

Hits@10
Run: 04, Epoch: 25, Loss: 0.3964, Train: 12.78%, Valid: 10.82%, Test: 14.05%
Hits@20
Run: 04, Epoch: 25, Loss: 0.3964, Train: 21.86%, Valid: 18.82%, Test: 17.86%
Hits@30
Run: 04, Epoch: 25, Loss: 0.3964, Train: 26.93%, Valid: 23.45%, Test: 22.73%
---
Hits@10
Run: 04, Epoch: 30, Loss: 0.3932, Train: 14.04%, Valid: 11.84%, Test: 10.88%
Hits@20
Run: 04, Epoch: 30, Loss: 0.3932, Train: 20.52%, Valid: 17.69%, Test: 16.42%
Hits@30
Run: 04, Epoch: 30, Loss: 0.3932, Train: 23.91%, Valid: 20.74%, Test: 20.41%
---
Hits@10
Run: 04, Epoch: 35, Loss: 0.3896, Train: 12.22%, Valid: 10.33%, Test: 10.12%
Hits@20
Run: 04, Epoch: 35, Loss: 0.3896, Train: 19.36%, Valid: 16.62%, Test: 18.03%
Hits@30
Run: 04, Epoch: 35, Loss: 0.3896, Train: 23.44%, Valid: 20.26%, Test: 21.86%
---
Hits@10
Run: 04, Epoch: 40, Loss: 0.3870, Train: 13.79%, Valid: 11.60%, Test: 12.09%
Hits@20
Run: 04, Epoch: 40, Loss: 0.3870, Train: 21.39%, Valid: 18.38%, Test: 16.81%
Hits@30
Run: 04, Epoch: 40, Loss: 0.3870, Train: 25.2

Hits@10
Run: 04, Epoch: 185, Loss: 0.3727, Train: 14.85%, Valid: 12.38%, Test: 10.09%
Hits@20
Run: 04, Epoch: 185, Loss: 0.3727, Train: 20.57%, Valid: 17.30%, Test: 13.47%
Hits@30
Run: 04, Epoch: 185, Loss: 0.3727, Train: 25.16%, Valid: 21.48%, Test: 17.14%
---
Hits@10
Run: 04, Epoch: 190, Loss: 0.3718, Train: 18.58%, Valid: 15.67%, Test: 9.59%
Hits@20
Run: 04, Epoch: 190, Loss: 0.3718, Train: 23.01%, Valid: 19.58%, Test: 13.56%
Hits@30
Run: 04, Epoch: 190, Loss: 0.3718, Train: 26.83%, Valid: 23.06%, Test: 16.22%
---
Hits@10
Run: 04, Epoch: 195, Loss: 0.3731, Train: 15.84%, Valid: 13.18%, Test: 10.27%
Hits@20
Run: 04, Epoch: 195, Loss: 0.3731, Train: 21.34%, Valid: 18.00%, Test: 15.51%
Hits@30
Run: 04, Epoch: 195, Loss: 0.3731, Train: 25.30%, Valid: 21.61%, Test: 17.45%
---
Hits@10
Run: 04, Epoch: 200, Loss: 0.3723, Train: 13.19%, Valid: 10.98%, Test: 9.75%
Hits@20
Run: 04, Epoch: 200, Loss: 0.3723, Train: 21.36%, Valid: 18.06%, Test: 14.69%
Hits@30
Run: 04, Epoch: 200, Loss: 0.3723, T

Hits@10
Run: 05, Epoch: 140, Loss: 0.3724, Train: 15.76%, Valid: 13.07%, Test: 12.43%
Hits@20
Run: 05, Epoch: 140, Loss: 0.3724, Train: 22.98%, Valid: 19.57%, Test: 16.26%
Hits@30
Run: 05, Epoch: 140, Loss: 0.3724, Train: 27.23%, Valid: 23.43%, Test: 19.38%
---
Hits@10
Run: 05, Epoch: 145, Loss: 0.3717, Train: 17.50%, Valid: 14.83%, Test: 12.01%
Hits@20
Run: 05, Epoch: 145, Loss: 0.3717, Train: 23.58%, Valid: 20.20%, Test: 18.02%
Hits@30
Run: 05, Epoch: 145, Loss: 0.3717, Train: 27.32%, Valid: 23.66%, Test: 20.32%
---
Hits@10
Run: 05, Epoch: 150, Loss: 0.3731, Train: 13.05%, Valid: 10.95%, Test: 10.76%
Hits@20
Run: 05, Epoch: 150, Loss: 0.3731, Train: 20.36%, Valid: 17.16%, Test: 15.48%
Hits@30
Run: 05, Epoch: 150, Loss: 0.3731, Train: 26.59%, Valid: 22.81%, Test: 17.94%
---
Hits@10
Run: 05, Epoch: 155, Loss: 0.3727, Train: 15.12%, Valid: 12.62%, Test: 11.54%
Hits@20
Run: 05, Epoch: 155, Loss: 0.3727, Train: 22.81%, Valid: 19.52%, Test: 17.02%
Hits@30
Run: 05, Epoch: 155, Loss: 0.3727,

Hits@10
Run: 06, Epoch: 95, Loss: 0.3759, Train: 15.47%, Valid: 12.93%, Test: 11.95%
Hits@20
Run: 06, Epoch: 95, Loss: 0.3759, Train: 23.06%, Valid: 19.79%, Test: 17.16%
Hits@30
Run: 06, Epoch: 95, Loss: 0.3759, Train: 27.87%, Valid: 24.19%, Test: 19.96%
---
Hits@10
Run: 06, Epoch: 100, Loss: 0.3744, Train: 15.03%, Valid: 12.65%, Test: 12.18%
Hits@20
Run: 06, Epoch: 100, Loss: 0.3744, Train: 20.61%, Valid: 17.53%, Test: 17.90%
Hits@30
Run: 06, Epoch: 100, Loss: 0.3744, Train: 25.52%, Valid: 21.98%, Test: 20.97%
---
Hits@10
Run: 06, Epoch: 105, Loss: 0.3742, Train: 18.05%, Valid: 15.40%, Test: 14.03%
Hits@20
Run: 06, Epoch: 105, Loss: 0.3742, Train: 22.79%, Valid: 19.70%, Test: 18.02%
Hits@30
Run: 06, Epoch: 105, Loss: 0.3742, Train: 26.30%, Valid: 22.87%, Test: 23.10%
---
Hits@10
Run: 06, Epoch: 110, Loss: 0.3739, Train: 16.86%, Valid: 14.15%, Test: 14.61%
Hits@20
Run: 06, Epoch: 110, Loss: 0.3739, Train: 22.24%, Valid: 19.03%, Test: 16.26%
Hits@30
Run: 06, Epoch: 110, Loss: 0.3739, Tr

Hits@10
Run: 07, Epoch: 50, Loss: 0.3845, Train: 17.12%, Valid: 14.59%, Test: 11.68%
Hits@20
Run: 07, Epoch: 50, Loss: 0.3845, Train: 23.68%, Valid: 20.37%, Test: 17.60%
Hits@30
Run: 07, Epoch: 50, Loss: 0.3845, Train: 27.66%, Valid: 23.99%, Test: 19.85%
---
Hits@10
Run: 07, Epoch: 55, Loss: 0.3836, Train: 19.54%, Valid: 16.61%, Test: 13.66%
Hits@20
Run: 07, Epoch: 55, Loss: 0.3836, Train: 23.82%, Valid: 20.58%, Test: 16.83%
Hits@30
Run: 07, Epoch: 55, Loss: 0.3836, Train: 28.19%, Valid: 24.54%, Test: 24.15%
---
Hits@10
Run: 07, Epoch: 60, Loss: 0.3818, Train: 16.25%, Valid: 13.87%, Test: 9.79%
Hits@20
Run: 07, Epoch: 60, Loss: 0.3818, Train: 21.12%, Valid: 18.19%, Test: 14.18%
Hits@30
Run: 07, Epoch: 60, Loss: 0.3818, Train: 25.92%, Valid: 22.49%, Test: 20.50%
---
Hits@10
Run: 07, Epoch: 65, Loss: 0.3815, Train: 14.66%, Valid: 12.43%, Test: 10.45%
Hits@20
Run: 07, Epoch: 65, Loss: 0.3815, Train: 20.98%, Valid: 17.90%, Test: 16.93%
Hits@30
Run: 07, Epoch: 65, Loss: 0.3815, Train: 26.31

Hits@10
Run: 08, Epoch: 05, Loss: 0.4641, Train: 8.31%, Valid: 7.20%, Test: 2.51%
Hits@20
Run: 08, Epoch: 05, Loss: 0.4641, Train: 11.75%, Valid: 10.18%, Test: 7.90%
Hits@30
Run: 08, Epoch: 05, Loss: 0.4641, Train: 14.17%, Valid: 12.33%, Test: 11.20%
---
Hits@10
Run: 08, Epoch: 10, Loss: 0.4306, Train: 10.03%, Valid: 8.64%, Test: 8.80%
Hits@20
Run: 08, Epoch: 10, Loss: 0.4306, Train: 15.12%, Valid: 13.12%, Test: 12.77%
Hits@30
Run: 08, Epoch: 10, Loss: 0.4306, Train: 19.86%, Valid: 17.34%, Test: 17.40%
---
Hits@10
Run: 08, Epoch: 15, Loss: 0.4140, Train: 14.17%, Valid: 12.18%, Test: 10.85%
Hits@20
Run: 08, Epoch: 15, Loss: 0.4140, Train: 20.52%, Valid: 17.76%, Test: 16.13%
Hits@30
Run: 08, Epoch: 15, Loss: 0.4140, Train: 23.58%, Valid: 20.46%, Test: 19.65%
---
Hits@10
Run: 08, Epoch: 20, Loss: 0.4041, Train: 13.50%, Valid: 11.54%, Test: 9.10%
Hits@20
Run: 08, Epoch: 20, Loss: 0.4041, Train: 17.46%, Valid: 14.91%, Test: 13.07%
Hits@30
Run: 08, Epoch: 20, Loss: 0.4041, Train: 22.25%, Val

Hits@10
Run: 08, Epoch: 165, Loss: 0.3729, Train: 13.18%, Valid: 10.98%, Test: 9.48%
Hits@20
Run: 08, Epoch: 165, Loss: 0.3729, Train: 19.59%, Valid: 16.56%, Test: 13.37%
Hits@30
Run: 08, Epoch: 165, Loss: 0.3729, Train: 23.57%, Valid: 20.10%, Test: 16.10%
---
Hits@10
Run: 08, Epoch: 170, Loss: 0.3749, Train: 12.67%, Valid: 10.49%, Test: 9.73%
Hits@20
Run: 08, Epoch: 170, Loss: 0.3749, Train: 19.63%, Valid: 16.64%, Test: 12.62%
Hits@30
Run: 08, Epoch: 170, Loss: 0.3749, Train: 23.50%, Valid: 19.98%, Test: 14.74%
---
Hits@10
Run: 08, Epoch: 175, Loss: 0.3730, Train: 14.77%, Valid: 12.42%, Test: 11.13%
Hits@20
Run: 08, Epoch: 175, Loss: 0.3730, Train: 21.45%, Valid: 18.16%, Test: 13.13%
Hits@30
Run: 08, Epoch: 175, Loss: 0.3730, Train: 26.38%, Valid: 22.66%, Test: 18.18%
---
Hits@10
Run: 08, Epoch: 180, Loss: 0.3728, Train: 15.16%, Valid: 12.68%, Test: 9.52%
Hits@20
Run: 08, Epoch: 180, Loss: 0.3728, Train: 19.17%, Valid: 16.18%, Test: 12.75%
Hits@30
Run: 08, Epoch: 180, Loss: 0.3728, Tr

Hits@10
Run: 09, Epoch: 120, Loss: 0.3759, Train: 17.23%, Valid: 14.59%, Test: 10.73%
Hits@20
Run: 09, Epoch: 120, Loss: 0.3759, Train: 23.74%, Valid: 20.36%, Test: 14.39%
Hits@30
Run: 09, Epoch: 120, Loss: 0.3759, Train: 27.05%, Valid: 23.33%, Test: 17.98%
---
Hits@10
Run: 09, Epoch: 125, Loss: 0.3743, Train: 15.81%, Valid: 13.31%, Test: 9.56%
Hits@20
Run: 09, Epoch: 125, Loss: 0.3743, Train: 24.49%, Valid: 21.14%, Test: 14.37%
Hits@30
Run: 09, Epoch: 125, Loss: 0.3743, Train: 29.23%, Valid: 25.46%, Test: 17.17%
---
Hits@10
Run: 09, Epoch: 130, Loss: 0.3742, Train: 17.69%, Valid: 14.94%, Test: 10.34%
Hits@20
Run: 09, Epoch: 130, Loss: 0.3742, Train: 24.72%, Valid: 21.12%, Test: 14.82%
Hits@30
Run: 09, Epoch: 130, Loss: 0.3742, Train: 26.85%, Valid: 23.06%, Test: 18.19%
---
Hits@10
Run: 09, Epoch: 135, Loss: 0.3729, Train: 16.73%, Valid: 14.05%, Test: 10.22%
Hits@20
Run: 09, Epoch: 135, Loss: 0.3729, Train: 24.27%, Valid: 20.83%, Test: 17.45%
Hits@30
Run: 09, Epoch: 135, Loss: 0.3729, 

Hits@10
Run: 10, Epoch: 75, Loss: 0.3794, Train: 17.66%, Valid: 15.07%, Test: 12.80%
Hits@20
Run: 10, Epoch: 75, Loss: 0.3794, Train: 23.12%, Valid: 19.91%, Test: 17.94%
Hits@30
Run: 10, Epoch: 75, Loss: 0.3794, Train: 29.27%, Valid: 25.52%, Test: 21.25%
---
Hits@10
Run: 10, Epoch: 80, Loss: 0.3785, Train: 16.52%, Valid: 13.98%, Test: 12.65%
Hits@20
Run: 10, Epoch: 80, Loss: 0.3785, Train: 28.41%, Valid: 24.68%, Test: 17.03%
Hits@30
Run: 10, Epoch: 80, Loss: 0.3785, Train: 30.82%, Valid: 26.96%, Test: 22.14%
---
Hits@10
Run: 10, Epoch: 85, Loss: 0.3786, Train: 15.52%, Valid: 13.11%, Test: 12.30%
Hits@20
Run: 10, Epoch: 85, Loss: 0.3786, Train: 24.95%, Valid: 21.52%, Test: 15.69%
Hits@30
Run: 10, Epoch: 85, Loss: 0.3786, Train: 31.19%, Valid: 27.21%, Test: 20.27%
---
Hits@10
Run: 10, Epoch: 90, Loss: 0.3797, Train: 14.71%, Valid: 12.41%, Test: 13.87%
Hits@20
Run: 10, Epoch: 90, Loss: 0.3797, Train: 22.81%, Valid: 19.63%, Test: 19.10%
Hits@30
Run: 10, Epoch: 90, Loss: 0.3797, Train: 29.7

## Matrix Factorization

In [7]:
import argparse

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.utils import negative_sampling

from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

In [8]:
class LinkPredictor(torch.nn.Module):
    """MLP that scores a candidate edge from its two endpoint embeddings.

    The two endpoint vectors are combined by element-wise product, passed
    through ``Linear -> ReLU -> dropout`` stages and finally through one more
    ``Linear`` followed by a sigmoid.

    Args:
        in_channels: Size of each endpoint embedding.
        hidden_channels: Width of the intermediate layers.
        out_channels: Size of the output produced per edge.
        num_layers: Requested number of ``Linear`` layers; at least two
            layers (input and output) are always created.
        dropout: Dropout probability applied after every hidden activation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        # Consecutive entries of `widths` give the (input, output) size of
        # each Linear layer, in creation order.
        widths = ([in_channels, hidden_channels]
                  + [hidden_channels] * (num_layers - 2)
                  + [out_channels])
        self.lins = torch.nn.ModuleList([
            torch.nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ])
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the parameters of every linear layer."""
        for layer in self.lins:
            layer.reset_parameters()

    def forward(self, x_i, x_j):
        """Return the sigmoid score for each pair of rows of ``x_i``/``x_j``."""
        *body, head = self.lins
        hidden = x_i * x_j
        for layer in body:
            activated = F.relu(layer(hidden))
            hidden = F.dropout(activated, p=self.dropout,
                               training=self.training)
        return torch.sigmoid(head(hidden))


def train(predictor, x, edge_index, split_edge, optimizer, batch_size):
    """Run one optimisation epoch over the positive training edges.

    Every shuffled mini-batch of positive edges is paired with the same
    number of negative edges drawn by ``negative_sampling``. The batch loss
    is the mean negative log-likelihood of the positives plus that of the
    negatives.

    Args:
        predictor: Module mapping two batches of node embeddings to scores.
        x: Node embedding matrix, indexed by node id.
        edge_index: Edge list handed to ``negative_sampling``.
        split_edge: Mapping whose ``['train']['edge']`` entry holds the
            positive training edges, one edge per row.
        optimizer: Optimizer stepped once per mini-batch.
        batch_size: Number of positive edges per mini-batch.

    Returns:
        The batch losses averaged with each batch weighted by its number of
        positive predictions.
    """
    predictor.train()

    train_edges = split_edge['train']['edge'].to(x.device)
    batches = DataLoader(range(train_edges.size(0)), batch_size, shuffle=True)

    loss_sum = 0
    seen = 0
    for batch_ids in batches:
        optimizer.zero_grad()

        # Positive edges: scores should be pushed towards 1.
        src, dst = train_edges[batch_ids].t()
        pos_out = predictor(x[src], x[dst])
        pos_loss = -torch.log(pos_out + 1e-15).mean()

        # Negative edges: scores should be pushed towards 0.
        neg_src, neg_dst = negative_sampling(
            edge_index,
            num_nodes=x.size(0),
            num_neg_samples=batch_ids.size(0),
            method='dense',
        )
        neg_out = predictor(x[neg_src], x[neg_dst])
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()

        loss = pos_loss + neg_loss
        loss.backward()
        optimizer.step()

        batch_count = pos_out.size(0)
        loss_sum += loss.item() * batch_count
        seen += batch_count

    return loss_sum / seen


@torch.no_grad()
def test(predictor, x, split_edge, evaluator, batch_size):
    predictor.eval()

    pos_train_edge = split_edge['eval_train']['edge'].to(x.device)
    pos_valid_edge = split_edge['valid']['edge'].to(x.device)
    neg_valid_edge = split_edge['valid']['edge_neg'].to(x.device)
    pos_test_edge = split_edge['test']['edge'].to(x.device)
    neg_test_edge = split_edge['test']['edge_neg'].to(x.device)

    pos_train_preds = []
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size):
        edge = pos_train_edge[perm].t()
        pos_train_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()]
    pos_train_pred = torch.cat(pos_train_preds, dim=0)

    pos_valid_preds = []
    for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size):
        edge = pos_valid_edge[perm].t()
        pos_valid_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()]
    pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

    neg_valid_preds = []
    for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size):
        edge = neg_valid_edge[perm].t()
        neg_valid_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()]
    neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

    pos_test_preds = []
    for perm in DataLoader(range(pos_test_edge.size(0)), batch_size):
        edge = pos_test_edge[perm].t()
        pos_test_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()]
    pos_test_pred = torch.cat(pos_test_preds, dim=0)

    neg_test_preds = []
    for perm in DataLoader(range(neg_test_edge.size(0)), batch_size):
        edge = neg_test_edge[perm].t()
        neg_test_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()]
    neg_test_pred = torch.cat(neg_test_preds, dim=0)

    results = {}
    for K in [10, 20, 30]:
        evaluator.K = K
        train_hits = evaluator.eval({
            'y_pred_pos': pos_train_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits)

    return results


def main():
    """Train and evaluate the embedding + MLP link predictor on ogbl-ddi.

    Parses the hyper-parameters, loads the dataset and its edge split, then
    for each run re-initialises the node embedding and the predictor, trains
    for ``--epochs`` epochs, evaluates every ``--eval_steps`` epochs and
    prints per-run and overall statistics for Hits@10/20/30.
    """
    parser = argparse.ArgumentParser(description='OGBL-DDI (MF)')
    # NOTE(review): '-f' is never read below; presumably it is declared so
    # that parse_args() tolerates the connection-file argument a Jupyter
    # kernel is started with.
    parser.add_argument('-f')
    options = (
        ('--device', int, 0),
        ('--log_steps', int, 1),
        ('--num_layers', int, 3),
        ('--hidden_channels', int, 256),
        ('--dropout', float, 0.5),
        ('--batch_size', int, 64 * 1024),
        ('--lr', float, 0.01),
        ('--epochs', int, 200),
        ('--eval_steps', int, 5),
        ('--runs', int, 10),
    )
    for flag, kind, default in options:
        parser.add_argument(flag, type=kind, default=default)
    args = parser.parse_args()
    print(args)

    device = torch.device(
        f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

    dataset = PygLinkPropPredDataset(name='ogbl-ddi')
    data = dataset[0]
    split_edge = dataset.get_edge_split()

    # Fixed-seed random subset of the training edges, as large as the
    # validation set, used to report a training-side metric.
    torch.manual_seed(12345)
    train_edges = split_edge['train']['edge']
    num_valid = split_edge['valid']['edge'].size(0)
    idx = torch.randperm(train_edges.size(0))[:num_valid]
    split_edge['eval_train'] = {'edge': train_edges[idx]}

    emb = torch.nn.Embedding(data.num_nodes, args.hidden_channels).to(device)
    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-ddi')
    loggers = {}
    for metric in ('Hits@10', 'Hits@20', 'Hits@30'):
        loggers[metric] = Logger_Models_Model(args.runs, args)

    for run in range(args.runs):
        emb.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            [*emb.parameters(), *predictor.parameters()], lr=args.lr)

        for epoch in range(1, args.epochs + 1):
            loss = train(predictor, emb.weight, data.edge_index, split_edge,
                         optimizer, args.batch_size)

            if epoch % args.eval_steps != 0:
                continue

            results = test(predictor, emb.weight, split_edge, evaluator,
                           args.batch_size)
            for key, result in results.items():
                loggers[key].add_result(run, result)

            # Progress is only printed on epochs that were also evaluated.
            if epoch % args.log_steps != 0:
                continue

            for key, (train_hits, valid_hits, test_hits) in results.items():
                line = (f'Run: {run + 1:02d}, '
                        f'Epoch: {epoch:02d}, '
                        f'Loss: {loss:.4f}, '
                        f'Train: {100 * train_hits:.2f}%, '
                        f'Valid: {100 * valid_hits:.2f}%, '
                        f'Test: {100 * test_hits:.2f}%')
                print(key)
                print(line)
            print('---')

        for key, logger in loggers.items():
            print(key)
            logger.print_statistics(run)

    for key, logger in loggers.items():
        print(key)
        logger.print_statistics()


# Entry point: run the full matrix-factorization experiment when this code is
# executed as the main module.
if __name__ == "__main__":
    main()


Namespace(f='C:\\Users\\b04753yr\\AppData\\Roaming\\jupyter\\runtime\\kernel-30e704de-a8cc-40a4-8675-f15ec8e16fa1.json', device=0, log_steps=1, num_layers=3, hidden_channels=256, dropout=0.5, batch_size=65536, lr=0.01, epochs=200, eval_steps=5, runs=10)
Hits@10
Run: 01, Epoch: 05, Loss: 1.3863, Train: 0.02%, Valid: 0.01%, Test: 0.01%
Hits@20
Run: 01, Epoch: 05, Loss: 1.3863, Train: 0.04%, Valid: 0.03%, Test: 0.02%
Hits@30
Run: 01, Epoch: 05, Loss: 1.3863, Train: 0.05%, Valid: 0.04%, Test: 0.03%
---
Hits@10
Run: 01, Epoch: 10, Loss: 1.0024, Train: 0.68%, Valid: 0.61%, Test: 0.65%
Hits@20
Run: 01, Epoch: 10, Loss: 1.0024, Train: 1.16%, Valid: 1.01%, Test: 1.26%
Hits@30
Run: 01, Epoch: 10, Loss: 1.0024, Train: 1.47%, Valid: 1.32%, Test: 1.49%
---
Hits@10
Run: 01, Epoch: 15, Loss: 0.7196, Train: 2.61%, Valid: 2.24%, Test: 2.44%
Hits@20
Run: 01, Epoch: 15, Loss: 0.7196, Train: 3.30%, Valid: 2.91%, Test: 3.67%
Hits@30
Run: 01, Epoch: 15, Loss: 0.7196, Train: 4.07%, Valid: 3.58%, Test: 4.38%


Hits@10
Run: 01, Epoch: 160, Loss: 0.0948, Train: 53.38%, Valid: 12.26%, Test: 4.75%
Hits@20
Run: 01, Epoch: 160, Loss: 0.0948, Train: 60.75%, Valid: 16.45%, Test: 9.24%
Hits@30
Run: 01, Epoch: 160, Loss: 0.0948, Train: 75.59%, Valid: 28.59%, Test: 13.88%
---
Hits@10
Run: 01, Epoch: 165, Loss: 0.0946, Train: 48.11%, Valid: 9.35%, Test: 7.44%
Hits@20
Run: 01, Epoch: 165, Loss: 0.0946, Train: 64.66%, Valid: 18.23%, Test: 12.03%
Hits@30
Run: 01, Epoch: 165, Loss: 0.0946, Train: 75.46%, Valid: 27.47%, Test: 31.27%
---
Hits@10
Run: 01, Epoch: 170, Loss: 0.0933, Train: 37.55%, Valid: 5.38%, Test: 3.87%
Hits@20
Run: 01, Epoch: 170, Loss: 0.0933, Train: 63.98%, Valid: 17.23%, Test: 11.39%
Hits@30
Run: 01, Epoch: 170, Loss: 0.0933, Train: 75.81%, Valid: 27.16%, Test: 18.54%
---
Hits@10
Run: 01, Epoch: 175, Loss: 0.0909, Train: 46.96%, Valid: 8.14%, Test: 4.70%
Hits@20
Run: 01, Epoch: 175, Loss: 0.0909, Train: 70.01%, Valid: 20.89%, Test: 10.94%
Hits@30
Run: 01, Epoch: 175, Loss: 0.0909, Train: 

Hits@10
Run: 02, Epoch: 115, Loss: 0.0998, Train: 42.15%, Valid: 17.26%, Test: 7.13%
Hits@20
Run: 02, Epoch: 115, Loss: 0.0998, Train: 55.56%, Valid: 25.50%, Test: 18.76%
Hits@30
Run: 02, Epoch: 115, Loss: 0.0998, Train: 64.17%, Valid: 31.52%, Test: 28.14%
---
Hits@10
Run: 02, Epoch: 120, Loss: 0.0968, Train: 34.40%, Valid: 12.24%, Test: 7.50%
Hits@20
Run: 02, Epoch: 120, Loss: 0.0968, Train: 49.57%, Valid: 20.77%, Test: 22.76%
Hits@30
Run: 02, Epoch: 120, Loss: 0.0968, Train: 62.55%, Valid: 29.29%, Test: 29.27%
---
Hits@10
Run: 02, Epoch: 125, Loss: 0.0944, Train: 34.42%, Valid: 11.72%, Test: 4.28%
Hits@20
Run: 02, Epoch: 125, Loss: 0.0944, Train: 53.76%, Valid: 22.63%, Test: 14.18%
Hits@30
Run: 02, Epoch: 125, Loss: 0.0944, Train: 63.94%, Valid: 29.51%, Test: 29.39%
---
Hits@10
Run: 02, Epoch: 130, Loss: 0.0921, Train: 30.60%, Valid: 9.44%, Test: 6.73%
Hits@20
Run: 02, Epoch: 130, Loss: 0.0921, Train: 51.80%, Valid: 20.73%, Test: 19.45%
Hits@30
Run: 02, Epoch: 130, Loss: 0.0921, Trai

Hits@10
Run: 03, Epoch: 70, Loss: 0.1236, Train: 35.16%, Valid: 17.64%, Test: 3.35%
Hits@20
Run: 03, Epoch: 70, Loss: 0.1236, Train: 50.99%, Valid: 28.96%, Test: 12.65%
Hits@30
Run: 03, Epoch: 70, Loss: 0.1236, Train: 59.92%, Valid: 36.22%, Test: 19.54%
---
Hits@10
Run: 03, Epoch: 75, Loss: 0.1191, Train: 40.00%, Valid: 19.66%, Test: 6.01%
Hits@20
Run: 03, Epoch: 75, Loss: 0.1191, Train: 50.57%, Valid: 27.31%, Test: 15.68%
Hits@30
Run: 03, Epoch: 75, Loss: 0.1191, Train: 60.08%, Valid: 34.98%, Test: 25.84%
---
Hits@10
Run: 03, Epoch: 80, Loss: 0.1162, Train: 44.45%, Valid: 21.57%, Test: 5.62%
Hits@20
Run: 03, Epoch: 80, Loss: 0.1162, Train: 53.76%, Valid: 28.37%, Test: 13.07%
Hits@30
Run: 03, Epoch: 80, Loss: 0.1162, Train: 57.66%, Valid: 31.47%, Test: 21.70%
---
Hits@10
Run: 03, Epoch: 85, Loss: 0.1130, Train: 48.00%, Valid: 23.06%, Test: 6.01%
Hits@20
Run: 03, Epoch: 85, Loss: 0.1130, Train: 58.78%, Valid: 31.27%, Test: 13.39%
Hits@30
Run: 03, Epoch: 85, Loss: 0.1130, Train: 63.91%, 

Hits@10
Run: 04, Epoch: 25, Loss: 0.2694, Train: 12.44%, Valid: 9.99%, Test: 8.12%
Hits@20
Run: 04, Epoch: 25, Loss: 0.2694, Train: 22.24%, Valid: 18.33%, Test: 18.89%
Hits@30
Run: 04, Epoch: 25, Loss: 0.2694, Train: 30.77%, Valid: 25.91%, Test: 23.35%
---
Hits@10
Run: 04, Epoch: 30, Loss: 0.2238, Train: 17.11%, Valid: 13.28%, Test: 7.33%
Hits@20
Run: 04, Epoch: 30, Loss: 0.2238, Train: 28.80%, Valid: 22.93%, Test: 19.29%
Hits@30
Run: 04, Epoch: 30, Loss: 0.2238, Train: 36.88%, Valid: 29.82%, Test: 22.88%
---
Hits@10
Run: 04, Epoch: 35, Loss: 0.1939, Train: 25.12%, Valid: 18.71%, Test: 10.54%
Hits@20
Run: 04, Epoch: 35, Loss: 0.1939, Train: 33.93%, Valid: 25.94%, Test: 21.63%
Hits@30
Run: 04, Epoch: 35, Loss: 0.1939, Train: 46.83%, Valid: 36.93%, Test: 25.96%
---
Hits@10
Run: 04, Epoch: 40, Loss: 0.1762, Train: 24.79%, Valid: 16.93%, Test: 7.52%
Hits@20
Run: 04, Epoch: 40, Loss: 0.1762, Train: 37.32%, Valid: 26.90%, Test: 13.61%
Hits@30
Run: 04, Epoch: 40, Loss: 0.1762, Train: 42.01%, 

Hits@10
Run: 04, Epoch: 185, Loss: 0.0852, Train: 55.30%, Valid: 7.73%, Test: 6.98%
Hits@20
Run: 04, Epoch: 185, Loss: 0.0852, Train: 72.18%, Valid: 17.63%, Test: 16.01%
Hits@30
Run: 04, Epoch: 185, Loss: 0.0852, Train: 80.21%, Valid: 26.99%, Test: 21.74%
---
Hits@10
Run: 04, Epoch: 190, Loss: 0.0834, Train: 54.48%, Valid: 7.01%, Test: 7.38%
Hits@20
Run: 04, Epoch: 190, Loss: 0.0834, Train: 73.02%, Valid: 17.68%, Test: 13.75%
Hits@30
Run: 04, Epoch: 190, Loss: 0.0834, Train: 79.63%, Valid: 25.16%, Test: 32.60%
---
Hits@10
Run: 04, Epoch: 195, Loss: 0.0826, Train: 53.09%, Valid: 6.29%, Test: 10.85%
Hits@20
Run: 04, Epoch: 195, Loss: 0.0826, Train: 73.94%, Valid: 18.20%, Test: 18.54%
Hits@30
Run: 04, Epoch: 195, Loss: 0.0826, Train: 78.46%, Valid: 23.24%, Test: 36.97%
---
Hits@10
Run: 04, Epoch: 200, Loss: 0.0815, Train: 53.95%, Valid: 6.14%, Test: 11.25%
Hits@20
Run: 04, Epoch: 200, Loss: 0.0815, Train: 74.14%, Valid: 17.51%, Test: 20.91%
Hits@30
Run: 04, Epoch: 200, Loss: 0.0815, Train

Hits@10
Run: 05, Epoch: 140, Loss: 0.0894, Train: 31.01%, Valid: 8.07%, Test: 5.30%
Hits@20
Run: 05, Epoch: 140, Loss: 0.0894, Train: 53.15%, Valid: 18.97%, Test: 12.52%
Hits@30
Run: 05, Epoch: 140, Loss: 0.0894, Train: 67.95%, Valid: 28.72%, Test: 29.23%
---
Hits@10
Run: 05, Epoch: 145, Loss: 0.0878, Train: 38.53%, Valid: 11.43%, Test: 5.28%
Hits@20
Run: 05, Epoch: 145, Loss: 0.0878, Train: 57.28%, Valid: 21.46%, Test: 10.51%
Hits@30
Run: 05, Epoch: 145, Loss: 0.0878, Train: 72.88%, Valid: 32.25%, Test: 26.87%
---
Hits@10
Run: 05, Epoch: 150, Loss: 0.0856, Train: 28.18%, Valid: 6.38%, Test: 4.38%
Hits@20
Run: 05, Epoch: 150, Loss: 0.0856, Train: 50.40%, Valid: 16.53%, Test: 11.67%
Hits@30
Run: 05, Epoch: 150, Loss: 0.0856, Train: 69.02%, Valid: 28.36%, Test: 20.78%
---
Hits@10
Run: 05, Epoch: 155, Loss: 0.0852, Train: 36.52%, Valid: 10.16%, Test: 4.33%
Hits@20
Run: 05, Epoch: 155, Loss: 0.0852, Train: 58.66%, Valid: 21.57%, Test: 12.06%
Hits@30
Run: 05, Epoch: 155, Loss: 0.0852, Train

Hits@10
Run: 06, Epoch: 95, Loss: 0.1013, Train: 51.27%, Valid: 17.63%, Test: 5.55%
Hits@20
Run: 06, Epoch: 95, Loss: 0.1013, Train: 60.79%, Valid: 23.85%, Test: 14.86%
Hits@30
Run: 06, Epoch: 95, Loss: 0.1013, Train: 68.94%, Valid: 30.54%, Test: 21.50%
---
Hits@10
Run: 06, Epoch: 100, Loss: 0.0986, Train: 52.81%, Valid: 17.67%, Test: 6.10%
Hits@20
Run: 06, Epoch: 100, Loss: 0.0986, Train: 61.16%, Valid: 23.26%, Test: 17.36%
Hits@30
Run: 06, Epoch: 100, Loss: 0.0986, Train: 70.52%, Valid: 30.88%, Test: 25.31%
---
Hits@10
Run: 06, Epoch: 105, Loss: 0.0975, Train: 50.73%, Valid: 15.48%, Test: 7.11%
Hits@20
Run: 06, Epoch: 105, Loss: 0.0975, Train: 63.37%, Valid: 23.57%, Test: 15.70%
Hits@30
Run: 06, Epoch: 105, Loss: 0.0975, Train: 70.06%, Valid: 29.10%, Test: 26.96%
---
Hits@10
Run: 06, Epoch: 110, Loss: 0.0948, Train: 53.53%, Valid: 16.27%, Test: 6.98%
Hits@20
Run: 06, Epoch: 110, Loss: 0.0948, Train: 68.26%, Valid: 26.35%, Test: 18.53%
Hits@30
Run: 06, Epoch: 110, Loss: 0.0948, Train:

Hits@10
Run: 07, Epoch: 50, Loss: 0.1418, Train: 45.47%, Valid: 25.95%, Test: 2.98%
Hits@20
Run: 07, Epoch: 50, Loss: 0.1418, Train: 54.91%, Valid: 33.60%, Test: 8.37%
Hits@30
Run: 07, Epoch: 50, Loss: 0.1418, Train: 60.97%, Valid: 39.10%, Test: 10.75%
---
Hits@10
Run: 07, Epoch: 55, Loss: 0.1355, Train: 39.35%, Valid: 19.46%, Test: 1.61%
Hits@20
Run: 07, Epoch: 55, Loss: 0.1355, Train: 55.58%, Valid: 31.87%, Test: 7.45%
Hits@30
Run: 07, Epoch: 55, Loss: 0.1355, Train: 63.77%, Valid: 39.37%, Test: 13.52%
---
Hits@10
Run: 07, Epoch: 60, Loss: 0.1313, Train: 47.04%, Valid: 23.13%, Test: 2.26%
Hits@20
Run: 07, Epoch: 60, Loss: 0.1313, Train: 56.84%, Valid: 30.74%, Test: 9.46%
Hits@30
Run: 07, Epoch: 60, Loss: 0.1313, Train: 65.48%, Valid: 38.70%, Test: 19.25%
---
Hits@10
Run: 07, Epoch: 65, Loss: 0.1264, Train: 42.31%, Valid: 18.37%, Test: 1.82%
Hits@20
Run: 07, Epoch: 65, Loss: 0.1264, Train: 55.48%, Valid: 27.76%, Test: 6.74%
Hits@30
Run: 07, Epoch: 65, Loss: 0.1264, Train: 67.18%, Vali

Hits@10
Run: 08, Epoch: 05, Loss: 1.1602, Train: 0.12%, Valid: 0.12%, Test: 0.09%
Hits@20
Run: 08, Epoch: 05, Loss: 1.1602, Train: 0.19%, Valid: 0.20%, Test: 0.19%
Hits@30
Run: 08, Epoch: 05, Loss: 1.1602, Train: 0.27%, Valid: 0.27%, Test: 0.25%
---
Hits@10
Run: 08, Epoch: 10, Loss: 0.8390, Train: 1.49%, Valid: 1.33%, Test: 2.67%
Hits@20
Run: 08, Epoch: 10, Loss: 0.8390, Train: 2.30%, Valid: 2.03%, Test: 3.51%
Hits@30
Run: 08, Epoch: 10, Loss: 0.8390, Train: 2.62%, Valid: 2.32%, Test: 4.30%
---
Hits@10
Run: 08, Epoch: 15, Loss: 0.4829, Train: 14.17%, Valid: 12.29%, Test: 9.61%
Hits@20
Run: 08, Epoch: 15, Loss: 0.4829, Train: 20.05%, Valid: 17.57%, Test: 13.92%
Hits@30
Run: 08, Epoch: 15, Loss: 0.4829, Train: 22.62%, Valid: 20.03%, Test: 16.87%
---
Hits@10
Run: 08, Epoch: 20, Loss: 0.3295, Train: 13.67%, Valid: 11.42%, Test: 9.16%
Hits@20
Run: 08, Epoch: 20, Loss: 0.3295, Train: 21.93%, Valid: 18.47%, Test: 15.34%
Hits@30
Run: 08, Epoch: 20, Loss: 0.3295, Train: 25.81%, Valid: 21.94%, T

Hits@10
Run: 08, Epoch: 165, Loss: 0.0821, Train: 47.60%, Valid: 11.36%, Test: 11.38%
Hits@20
Run: 08, Epoch: 165, Loss: 0.0821, Train: 58.99%, Valid: 17.37%, Test: 26.95%
Hits@30
Run: 08, Epoch: 165, Loss: 0.0821, Train: 67.55%, Valid: 23.01%, Test: 33.94%
---
Hits@10
Run: 08, Epoch: 170, Loss: 0.0816, Train: 42.90%, Valid: 9.11%, Test: 6.82%
Hits@20
Run: 08, Epoch: 170, Loss: 0.0816, Train: 62.45%, Valid: 19.09%, Test: 21.82%
Hits@30
Run: 08, Epoch: 170, Loss: 0.0816, Train: 72.29%, Valid: 26.33%, Test: 30.91%
---
Hits@10
Run: 08, Epoch: 175, Loss: 0.0793, Train: 48.44%, Valid: 11.02%, Test: 10.39%
Hits@20
Run: 08, Epoch: 175, Loss: 0.0793, Train: 62.81%, Valid: 18.88%, Test: 24.59%
Hits@30
Run: 08, Epoch: 175, Loss: 0.0793, Train: 75.90%, Valid: 28.84%, Test: 34.15%
---
Hits@10
Run: 08, Epoch: 180, Loss: 0.0789, Train: 44.61%, Valid: 8.91%, Test: 9.97%
Hits@20
Run: 08, Epoch: 180, Loss: 0.0789, Train: 65.36%, Valid: 19.59%, Test: 18.03%
Hits@30
Run: 08, Epoch: 180, Loss: 0.0789, Tra

Hits@10
Run: 09, Epoch: 120, Loss: 0.0957, Train: 40.74%, Valid: 15.25%, Test: 4.71%
Hits@20
Run: 09, Epoch: 120, Loss: 0.0957, Train: 58.50%, Valid: 25.82%, Test: 10.97%
Hits@30
Run: 09, Epoch: 120, Loss: 0.0957, Train: 66.44%, Valid: 31.47%, Test: 19.46%
---
Hits@10
Run: 09, Epoch: 125, Loss: 0.0949, Train: 45.64%, Valid: 17.63%, Test: 5.03%
Hits@20
Run: 09, Epoch: 125, Loss: 0.0949, Train: 58.11%, Valid: 25.26%, Test: 11.67%
Hits@30
Run: 09, Epoch: 125, Loss: 0.0949, Train: 66.75%, Valid: 31.40%, Test: 24.33%
---
Hits@10
Run: 09, Epoch: 130, Loss: 0.0932, Train: 46.21%, Valid: 17.25%, Test: 5.68%
Hits@20
Run: 09, Epoch: 130, Loss: 0.0932, Train: 62.89%, Valid: 27.68%, Test: 15.06%
Hits@30
Run: 09, Epoch: 130, Loss: 0.0932, Train: 68.99%, Valid: 32.18%, Test: 24.42%
---
Hits@10
Run: 09, Epoch: 135, Loss: 0.0908, Train: 37.75%, Valid: 12.14%, Test: 3.48%
Hits@20
Run: 09, Epoch: 135, Loss: 0.0908, Train: 59.22%, Valid: 24.23%, Test: 13.03%
Hits@30
Run: 09, Epoch: 135, Loss: 0.0908, Tra

Hits@10
Run: 10, Epoch: 75, Loss: 0.1182, Train: 35.87%, Valid: 12.91%, Test: 4.06%
Hits@20
Run: 10, Epoch: 75, Loss: 0.1182, Train: 54.68%, Valid: 24.80%, Test: 18.79%
Hits@30
Run: 10, Epoch: 75, Loss: 0.1182, Train: 62.86%, Valid: 31.55%, Test: 25.35%
---
Hits@10
Run: 10, Epoch: 80, Loss: 0.1160, Train: 43.20%, Valid: 15.65%, Test: 6.51%
Hits@20
Run: 10, Epoch: 80, Loss: 0.1160, Train: 59.24%, Valid: 26.58%, Test: 18.98%
Hits@30
Run: 10, Epoch: 80, Loss: 0.1160, Train: 65.25%, Valid: 31.80%, Test: 26.74%
---
Hits@10
Run: 10, Epoch: 85, Loss: 0.1134, Train: 38.07%, Valid: 11.87%, Test: 5.27%
Hits@20
Run: 10, Epoch: 85, Loss: 0.1134, Train: 58.40%, Valid: 24.49%, Test: 15.29%
Hits@30
Run: 10, Epoch: 85, Loss: 0.1134, Train: 64.12%, Valid: 29.24%, Test: 21.56%
---
Hits@10
Run: 10, Epoch: 90, Loss: 0.1099, Train: 38.29%, Valid: 10.98%, Test: 9.26%
Hits@20
Run: 10, Epoch: 90, Loss: 0.1099, Train: 56.79%, Valid: 21.88%, Test: 15.50%
Hits@30
Run: 10, Epoch: 90, Loss: 0.1099, Train: 66.54%, 

## GNN

In [9]:
import argparse

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.utils import negative_sampling

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv

from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

In [10]:
class GCN(torch.nn.Module):
    """Stack of cached ``GCNConv`` layers with ReLU and dropout in between.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of the intermediate layers.
        out_channels: Size of the output node representations.
        num_layers: Requested number of convolutions; at least two (input
            and output) are always created.
        dropout: Dropout probability applied after every hidden activation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        # Consecutive entries of `widths` give the (input, output) size of
        # each convolution, in creation order.
        widths = ([in_channels, hidden_channels]
                  + [hidden_channels] * (num_layers - 2)
                  + [out_channels])
        self.convs = torch.nn.ModuleList([
            GCNConv(src_dim, dst_dim, cached=True)
            for src_dim, dst_dim in zip(widths[:-1], widths[1:])
        ])
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the parameters of every convolution."""
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, adj_t):
        """Propagate ``x`` over ``adj_t``; the last layer has no activation."""
        *inner, final = self.convs
        for layer in inner:
            activated = F.relu(layer(x, adj_t))
            x = F.dropout(activated, p=self.dropout, training=self.training)
        return final(x, adj_t)


class SAGE(torch.nn.Module):
    """Stack of ``SAGEConv`` layers with ReLU and dropout in between.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of the intermediate layers.
        out_channels: Size of the output node representations.
        num_layers: Requested number of convolutions; at least two (input
            and output) are always created.
        dropout: Dropout probability applied after every hidden activation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        # Build the stack in order: input layer, optional hidden layers,
        # output layer.
        stack = [SAGEConv(in_channels, hidden_channels)]
        stack.extend(SAGEConv(hidden_channels, hidden_channels)
                     for _ in range(num_layers - 2))
        stack.append(SAGEConv(hidden_channels, out_channels))

        self.convs = torch.nn.ModuleList(stack)
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the parameters of every convolution."""
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, adj_t):
        """Propagate ``x`` over ``adj_t``; the last layer has no activation."""
        final_idx = len(self.convs) - 1
        for depth, layer in enumerate(self.convs):
            x = layer(x, adj_t)
            if depth == final_idx:
                break
            x = F.dropout(F.relu(x), p=self.dropout, training=self.training)
        return x


class LinkPredictor(torch.nn.Module):
    """MLP that turns two node representations into a link probability.

    The inputs are multiplied element-wise, then passed through hidden
    ``Linear`` layers (each followed by ReLU and dropout) and one output
    ``Linear`` layer whose result is squashed by a sigmoid.

    Args:
        in_channels: Size of each node representation.
        hidden_channels: Width of the intermediate layers.
        out_channels: Size of the output produced per edge.
        num_layers: Requested number of ``Linear`` layers; at least two
            layers (input and output) are always created.
        dropout: Dropout probability applied after every hidden activation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        # Build the stack in order: input layer, optional hidden layers,
        # output layer.
        stack = [torch.nn.Linear(in_channels, hidden_channels)]
        stack.extend(torch.nn.Linear(hidden_channels, hidden_channels)
                     for _ in range(num_layers - 2))
        stack.append(torch.nn.Linear(hidden_channels, out_channels))

        self.lins = torch.nn.ModuleList(stack)
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the parameters of every linear layer."""
        for layer in self.lins:
            layer.reset_parameters()

    def forward(self, x_i, x_j):
        """Return the sigmoid score for each pair of rows of ``x_i``/``x_j``."""
        out = x_i * x_j
        final_idx = len(self.lins) - 1
        for depth, layer in enumerate(self.lins):
            out = layer(out)
            if depth == final_idx:
                break
            out = F.dropout(F.relu(out), p=self.dropout,
                            training=self.training)
        return torch.sigmoid(out)


def train(model, predictor, x, adj_t, split_edge, optimizer, batch_size):
    """Run one optimisation epoch of the GNN encoder and the link predictor.

    For every shuffled mini-batch of positive training edges the encoder is
    re-run on the whole graph, the batch is paired with the same number of
    negative edges drawn by ``negative_sampling``, and the summed mean
    negative log-likelihoods are back-propagated. Gradients of ``x``, the
    encoder and the predictor are each clipped to norm 1.0 before the step.

    Args:
        model: Encoder called as ``model(x, adj_t)``.
        predictor: Module mapping two batches of node representations to
            scores.
        x: Input node features, indexed by node id.
        adj_t: Sparse adjacency providing ``coo()``.
        split_edge: Mapping whose ``['train']['edge']`` entry holds the
            positive training edges, one edge per row.
        optimizer: Optimizer stepped once per mini-batch.
        batch_size: Number of positive edges per mini-batch.

    Returns:
        The batch losses averaged with each batch weighted by its number of
        positive predictions.
    """
    # Edge list for the negative sampler, rebuilt from the sparse adjacency;
    # the third item returned by coo() is not used.
    row, col, _ = adj_t.coo()
    edge_index = torch.stack([col, row], dim=0)

    model.train()
    predictor.train()

    train_edges = split_edge['train']['edge'].to(x.device)
    batches = DataLoader(range(train_edges.size(0)), batch_size, shuffle=True)

    loss_sum = 0
    seen = 0
    for batch_ids in batches:
        optimizer.zero_grad()

        h = model(x, adj_t)

        # Positive edges: scores should be pushed towards 1.
        src, dst = train_edges[batch_ids].t()
        pos_out = predictor(h[src], h[dst])
        pos_loss = -torch.log(pos_out + 1e-15).mean()

        # Negative edges: scores should be pushed towards 0.
        neg_src, neg_dst = negative_sampling(
            edge_index,
            num_nodes=x.size(0),
            num_neg_samples=batch_ids.size(0),
            method='dense',
        )
        neg_out = predictor(h[neg_src], h[neg_dst])
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()

        loss = pos_loss + neg_loss
        loss.backward()

        # Each parameter group is clipped separately, in the same order.
        for params in (x, model.parameters(), predictor.parameters()):
            torch.nn.utils.clip_grad_norm_(params, 1.0)

        optimizer.step()

        batch_count = pos_out.size(0)
        loss_sum += loss.item() * batch_count
        seen += batch_count

    return loss_sum / seen


@torch.no_grad()
def test(model, predictor, x, adj_t, split_edge, evaluator, batch_size):
    model.eval()
    predictor.eval()

    h = model(x, adj_t)

    pos_train_edge = split_edge['eval_train']['edge'].to(x.device)
    pos_valid_edge = split_edge['valid']['edge'].to(x.device)
    neg_valid_edge = split_edge['valid']['edge_neg'].to(x.device)
    pos_test_edge = split_edge['test']['edge'].to(x.device)
    neg_test_edge = split_edge['test']['edge_neg'].to(x.device)

    pos_train_preds = []
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size):
        edge = pos_train_edge[perm].t()
        pos_train_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_train_pred = torch.cat(pos_train_preds, dim=0)

    pos_valid_preds = []
    for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size):
        edge = pos_valid_edge[perm].t()
        pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

    neg_valid_preds = []
    for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size):
        edge = neg_valid_edge[perm].t()
        neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

    pos_test_preds = []
    for perm in DataLoader(range(pos_test_edge.size(0)), batch_size):
        edge = pos_test_edge[perm].t()
        pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_test_pred = torch.cat(pos_test_preds, dim=0)

    neg_test_preds = []
    for perm in DataLoader(range(neg_test_edge.size(0)), batch_size):
        edge = neg_test_edge[perm].t()
        neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_test_pred = torch.cat(neg_test_preds, dim=0)

    results = {}
    for K in [10, 20, 30]:
        evaluator.K = K
        train_hits = evaluator.eval({
            'y_pred_pos': pos_train_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits)

    return results


def main():
    """Run the ogbl-ddi link-prediction experiment with a GCN or SAGE encoder.

    Parses command-line options, loads the dataset with a sparse adjacency,
    builds the encoder, a learnable node-embedding table and the link
    predictor, then trains for ``args.runs`` independent runs. Every
    ``args.eval_steps`` epochs the Hits@10/20/30 results are recorded and
    (subject to ``args.log_steps``) printed.

    Relies on ``SAGE``, ``GCN``, ``LinkPredictor``, ``train``, ``test`` and
    ``Logger_Models_Model`` being defined by earlier notebook cells.
    """
    cli = argparse.ArgumentParser(description='OGBL-DDI (GNN)')
    # Options are registered in this exact order so the printed Namespace
    # keeps its layout. `-f` absorbs the kernel connection file that shows up
    # as `f=...json` in the printed arguments when run inside Jupyter.
    option_table = (
        ('-f', {}),
        ('--device', dict(type=int, default=0)),
        ('--log_steps', dict(type=int, default=1)),
        ('--use_sage', dict(action='store_true')),
        ('--num_layers', dict(type=int, default=2)),
        ('--hidden_channels', dict(type=int, default=256)),
        ('--dropout', dict(type=float, default=0.5)),
        ('--batch_size', dict(type=int, default=64 * 1024)),
        ('--lr', dict(type=float, default=0.005)),
        ('--epochs', dict(type=int, default=200)),
        ('--eval_steps', dict(type=int, default=5)),
        ('--runs', dict(type=int, default=10)),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    args = cli.parse_args()
    print(args)

    device = torch.device(
        f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

    dataset = PygLinkPropPredDataset(name='ogbl-ddi',
                                     transform=T.ToSparseTensor())
    data = dataset[0]
    adj_t = data.adj_t.to(device)

    split_edge = dataset.get_edge_split()

    # Fixed-seed random subset of the training edges, as large as the
    # validation set, used to report a "train" score during evaluation.
    torch.manual_seed(12345)
    train_edges = split_edge['train']['edge']
    num_valid = split_edge['valid']['edge'].size(0)
    subset = torch.randperm(train_edges.size(0))[:num_valid]
    split_edge['eval_train'] = {'edge': train_edges[subset]}

    width = args.hidden_channels
    encoder_cls = SAGE if args.use_sage else GCN
    model = encoder_cls(width, width, width, args.num_layers,
                        args.dropout).to(device)

    # One learnable feature vector per node (rows = size of the adjacency).
    emb = torch.nn.Embedding(data.adj_t.size(0), width).to(device)
    predictor = LinkPredictor(width, width, 1, args.num_layers,
                              args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-ddi')
    loggers = {
        f'Hits@{k}': Logger_Models_Model(args.runs, args)
        for k in (10, 20, 30)
    }

    for run in range(args.runs):
        # Start every run from freshly initialised weights.
        torch.nn.init.xavier_uniform_(emb.weight)
        model.reset_parameters()
        predictor.reset_parameters()
        trainable = [
            *model.parameters(),
            *emb.parameters(),
            *predictor.parameters(),
        ]
        optimizer = torch.optim.Adam(trainable, lr=args.lr)

        for epoch in range(1, args.epochs + 1):
            loss = train(model, predictor, emb.weight, adj_t, split_edge,
                         optimizer, args.batch_size)

            if epoch % args.eval_steps != 0:
                continue

            results = test(model, predictor, emb.weight, adj_t, split_edge,
                           evaluator, args.batch_size)
            for key, result in results.items():
                loggers[key].add_result(run, result)

            # Printing only ever happens on evaluation epochs.
            if epoch % args.log_steps != 0:
                continue

            for key, (train_hits, valid_hits, test_hits) in results.items():
                print(key)
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Train: {100 * train_hits:.2f}%, '
                      f'Valid: {100 * valid_hits:.2f}%, '
                      f'Test: {100 * test_hits:.2f}%')
            print('---')

        for key, logger in loggers.items():
            print(key)
            logger.print_statistics(run)

    for key, logger in loggers.items():
        print(key)
        logger.print_statistics()


# Entry point: launch the full training/evaluation experiment when this code
# is executed as the main module (which is the case when the cell is run).
if __name__ == "__main__":
    main()


Namespace(f='C:\\Users\\b04753yr\\AppData\\Roaming\\jupyter\\runtime\\kernel-30e704de-a8cc-40a4-8675-f15ec8e16fa1.json', device=0, log_steps=1, use_sage=False, num_layers=2, hidden_channels=256, dropout=0.5, batch_size=65536, lr=0.005, epochs=200, eval_steps=5, runs=10)
Hits@10
Run: 01, Epoch: 05, Loss: 0.6742, Train: 7.81%, Valid: 6.89%, Test: 3.20%
Hits@20
Run: 01, Epoch: 05, Loss: 0.6742, Train: 9.77%, Valid: 8.58%, Test: 7.13%
Hits@30
Run: 01, Epoch: 05, Loss: 0.6742, Train: 11.90%, Valid: 10.56%, Test: 10.05%
---
Hits@10
Run: 01, Epoch: 10, Loss: 0.5182, Train: 16.45%, Valid: 14.62%, Test: 5.63%
Hits@20
Run: 01, Epoch: 10, Loss: 0.5182, Train: 20.27%, Valid: 18.20%, Test: 8.22%
Hits@30
Run: 01, Epoch: 10, Loss: 0.5182, Train: 23.08%, Valid: 20.80%, Test: 11.30%
---
Hits@10
Run: 01, Epoch: 15, Loss: 0.4362, Train: 20.31%, Valid: 18.13%, Test: 18.38%
Hits@20
Run: 01, Epoch: 15, Loss: 0.4362, Train: 27.99%, Valid: 25.24%, Test: 21.39%
Hits@30
Run: 01, Epoch: 15, Loss: 0.4362, Train: 

Hits@10
Run: 01, Epoch: 160, Loss: 0.2058, Train: 32.34%, Valid: 27.50%, Test: 21.34%
Hits@20
Run: 01, Epoch: 160, Loss: 0.2058, Train: 47.36%, Valid: 40.68%, Test: 29.22%
Hits@30
Run: 01, Epoch: 160, Loss: 0.2058, Train: 56.72%, Valid: 49.27%, Test: 36.79%
---
Hits@10
Run: 01, Epoch: 165, Loss: 0.2022, Train: 40.15%, Valid: 34.28%, Test: 18.90%
Hits@20
Run: 01, Epoch: 165, Loss: 0.2022, Train: 53.94%, Valid: 46.56%, Test: 37.09%
Hits@30
Run: 01, Epoch: 165, Loss: 0.2022, Train: 62.37%, Valid: 54.20%, Test: 49.06%
---
Hits@10
Run: 01, Epoch: 170, Loss: 0.2019, Train: 35.75%, Valid: 30.42%, Test: 19.22%
Hits@20
Run: 01, Epoch: 170, Loss: 0.2019, Train: 49.39%, Valid: 42.39%, Test: 37.40%
Hits@30
Run: 01, Epoch: 170, Loss: 0.2019, Train: 57.52%, Valid: 49.69%, Test: 40.72%
---
Hits@10
Run: 01, Epoch: 175, Loss: 0.1969, Train: 49.53%, Valid: 42.20%, Test: 17.66%
Hits@20
Run: 01, Epoch: 175, Loss: 0.1969, Train: 59.60%, Valid: 51.59%, Test: 35.93%
Hits@30
Run: 01, Epoch: 175, Loss: 0.1969,

Hits@10
Run: 02, Epoch: 115, Loss: 0.2240, Train: 26.20%, Valid: 22.43%, Test: 8.46%
Hits@20
Run: 02, Epoch: 115, Loss: 0.2240, Train: 33.05%, Valid: 28.59%, Test: 21.14%
Hits@30
Run: 02, Epoch: 115, Loss: 0.2240, Train: 38.64%, Valid: 33.52%, Test: 26.77%
---
Hits@10
Run: 02, Epoch: 120, Loss: 0.2190, Train: 38.81%, Valid: 33.33%, Test: 16.81%
Hits@20
Run: 02, Epoch: 120, Loss: 0.2190, Train: 51.21%, Valid: 44.63%, Test: 31.58%
Hits@30
Run: 02, Epoch: 120, Loss: 0.2190, Train: 54.91%, Valid: 48.07%, Test: 42.54%
---
Hits@10
Run: 02, Epoch: 125, Loss: 0.2191, Train: 29.27%, Valid: 25.00%, Test: 17.07%
Hits@20
Run: 02, Epoch: 125, Loss: 0.2191, Train: 42.80%, Valid: 36.95%, Test: 35.18%
Hits@30
Run: 02, Epoch: 125, Loss: 0.2191, Train: 49.79%, Valid: 43.21%, Test: 40.97%
---
Hits@10
Run: 02, Epoch: 130, Loss: 0.2164, Train: 39.70%, Valid: 34.22%, Test: 12.25%
Hits@20
Run: 02, Epoch: 130, Loss: 0.2164, Train: 50.22%, Valid: 43.53%, Test: 29.94%
Hits@30
Run: 02, Epoch: 130, Loss: 0.2164, 

Hits@10
Run: 03, Epoch: 70, Loss: 0.2518, Train: 30.00%, Valid: 25.99%, Test: 6.30%
Hits@20
Run: 03, Epoch: 70, Loss: 0.2518, Train: 45.48%, Valid: 39.84%, Test: 19.63%
Hits@30
Run: 03, Epoch: 70, Loss: 0.2518, Train: 49.45%, Valid: 43.64%, Test: 30.07%
---
Hits@10
Run: 03, Epoch: 75, Loss: 0.2497, Train: 23.26%, Valid: 19.99%, Test: 10.09%
Hits@20
Run: 03, Epoch: 75, Loss: 0.2497, Train: 33.60%, Valid: 29.03%, Test: 18.87%
Hits@30
Run: 03, Epoch: 75, Loss: 0.2497, Train: 43.09%, Valid: 37.69%, Test: 28.81%
---
Hits@10
Run: 03, Epoch: 80, Loss: 0.2440, Train: 39.06%, Valid: 33.93%, Test: 17.61%
Hits@20
Run: 03, Epoch: 80, Loss: 0.2440, Train: 47.37%, Valid: 41.60%, Test: 30.80%
Hits@30
Run: 03, Epoch: 80, Loss: 0.2440, Train: 54.75%, Valid: 48.54%, Test: 39.37%
---
Hits@10
Run: 03, Epoch: 85, Loss: 0.2422, Train: 29.70%, Valid: 25.65%, Test: 17.94%
Hits@20
Run: 03, Epoch: 85, Loss: 0.2422, Train: 38.20%, Valid: 33.15%, Test: 26.86%
Hits@30
Run: 03, Epoch: 85, Loss: 0.2422, Train: 47.72

Hits@10
Run: 04, Epoch: 25, Loss: 0.3542, Train: 23.83%, Valid: 21.23%, Test: 10.89%
Hits@20
Run: 04, Epoch: 25, Loss: 0.3542, Train: 32.17%, Valid: 28.68%, Test: 19.57%
Hits@30
Run: 04, Epoch: 25, Loss: 0.3542, Train: 37.04%, Valid: 33.35%, Test: 24.60%
---
Hits@10
Run: 04, Epoch: 30, Loss: 0.3305, Train: 21.35%, Valid: 18.94%, Test: 6.29%
Hits@20
Run: 04, Epoch: 30, Loss: 0.3305, Train: 26.84%, Valid: 23.92%, Test: 14.24%
Hits@30
Run: 04, Epoch: 30, Loss: 0.3305, Train: 32.09%, Valid: 28.66%, Test: 20.67%
---
Hits@10
Run: 04, Epoch: 35, Loss: 0.3146, Train: 20.45%, Valid: 17.81%, Test: 7.06%
Hits@20
Run: 04, Epoch: 35, Loss: 0.3146, Train: 30.20%, Valid: 26.69%, Test: 14.84%
Hits@30
Run: 04, Epoch: 35, Loss: 0.3146, Train: 36.34%, Valid: 32.37%, Test: 21.45%
---
Hits@10
Run: 04, Epoch: 40, Loss: 0.3022, Train: 25.31%, Valid: 22.32%, Test: 6.37%
Hits@20
Run: 04, Epoch: 40, Loss: 0.3022, Train: 29.54%, Valid: 26.06%, Test: 16.69%
Hits@30
Run: 04, Epoch: 40, Loss: 0.3022, Train: 38.03%,

Hits@10
Run: 04, Epoch: 185, Loss: 0.1920, Train: 50.35%, Valid: 43.18%, Test: 14.30%
Hits@20
Run: 04, Epoch: 185, Loss: 0.1920, Train: 62.53%, Valid: 54.22%, Test: 30.07%
Hits@30
Run: 04, Epoch: 185, Loss: 0.1920, Train: 65.61%, Valid: 57.09%, Test: 38.33%
---
Hits@10
Run: 04, Epoch: 190, Loss: 0.1902, Train: 42.92%, Valid: 36.36%, Test: 18.87%
Hits@20
Run: 04, Epoch: 190, Loss: 0.1902, Train: 66.51%, Valid: 57.85%, Test: 29.76%
Hits@30
Run: 04, Epoch: 190, Loss: 0.1902, Train: 68.63%, Valid: 59.84%, Test: 40.93%
---
Hits@10
Run: 04, Epoch: 195, Loss: 0.1906, Train: 44.97%, Valid: 37.95%, Test: 25.44%
Hits@20
Run: 04, Epoch: 195, Loss: 0.1906, Train: 62.01%, Valid: 53.48%, Test: 35.62%
Hits@30
Run: 04, Epoch: 195, Loss: 0.1906, Train: 66.51%, Valid: 57.69%, Test: 48.94%
---
Hits@10
Run: 04, Epoch: 200, Loss: 0.1901, Train: 36.85%, Valid: 30.87%, Test: 17.57%
Hits@20
Run: 04, Epoch: 200, Loss: 0.1901, Train: 64.47%, Valid: 55.70%, Test: 28.33%
Hits@30
Run: 04, Epoch: 200, Loss: 0.1901,

Hits@10
Run: 05, Epoch: 140, Loss: 0.2088, Train: 27.81%, Valid: 23.55%, Test: 36.97%
Hits@20
Run: 05, Epoch: 140, Loss: 0.2088, Train: 44.92%, Valid: 38.36%, Test: 47.96%
Hits@30
Run: 05, Epoch: 140, Loss: 0.2088, Train: 56.06%, Valid: 48.42%, Test: 53.45%
---
Hits@10
Run: 05, Epoch: 145, Loss: 0.2080, Train: 30.41%, Valid: 25.89%, Test: 22.65%
Hits@20
Run: 05, Epoch: 145, Loss: 0.2080, Train: 45.48%, Valid: 39.10%, Test: 31.25%
Hits@30
Run: 05, Epoch: 145, Loss: 0.2080, Train: 51.79%, Valid: 44.89%, Test: 43.08%
---
Hits@10
Run: 05, Epoch: 150, Loss: 0.2050, Train: 39.44%, Valid: 33.61%, Test: 23.56%
Hits@20
Run: 05, Epoch: 150, Loss: 0.2050, Train: 54.65%, Valid: 47.31%, Test: 32.50%
Hits@30
Run: 05, Epoch: 150, Loss: 0.2050, Train: 60.57%, Valid: 52.79%, Test: 48.32%
---
Hits@10
Run: 05, Epoch: 155, Loss: 0.2079, Train: 33.32%, Valid: 28.31%, Test: 24.53%
Hits@20
Run: 05, Epoch: 155, Loss: 0.2079, Train: 46.48%, Valid: 39.98%, Test: 30.84%
Hits@30
Run: 05, Epoch: 155, Loss: 0.2079,

Hits@10
Run: 06, Epoch: 95, Loss: 0.2320, Train: 32.21%, Valid: 27.48%, Test: 18.25%
Hits@20
Run: 06, Epoch: 95, Loss: 0.2320, Train: 48.82%, Valid: 42.50%, Test: 31.90%
Hits@30
Run: 06, Epoch: 95, Loss: 0.2320, Train: 52.42%, Valid: 45.94%, Test: 36.76%
---
Hits@10
Run: 06, Epoch: 100, Loss: 0.2282, Train: 24.15%, Valid: 20.68%, Test: 18.14%
Hits@20
Run: 06, Epoch: 100, Loss: 0.2282, Train: 38.09%, Valid: 32.85%, Test: 25.41%
Hits@30
Run: 06, Epoch: 100, Loss: 0.2282, Train: 44.00%, Valid: 38.23%, Test: 31.78%
---
Hits@10
Run: 06, Epoch: 105, Loss: 0.2260, Train: 20.10%, Valid: 17.10%, Test: 20.92%
Hits@20
Run: 06, Epoch: 105, Loss: 0.2260, Train: 37.68%, Valid: 32.50%, Test: 30.78%
Hits@30
Run: 06, Epoch: 105, Loss: 0.2260, Train: 45.81%, Valid: 39.70%, Test: 35.07%
---
Hits@10
Run: 06, Epoch: 110, Loss: 0.2233, Train: 27.61%, Valid: 23.70%, Test: 23.67%
Hits@20
Run: 06, Epoch: 110, Loss: 0.2233, Train: 46.17%, Valid: 40.12%, Test: 32.13%
Hits@30
Run: 06, Epoch: 110, Loss: 0.2233, Tr

Hits@10
Run: 07, Epoch: 50, Loss: 0.2820, Train: 18.69%, Valid: 16.24%, Test: 5.46%
Hits@20
Run: 07, Epoch: 50, Loss: 0.2820, Train: 27.80%, Valid: 24.47%, Test: 13.71%
Hits@30
Run: 07, Epoch: 50, Loss: 0.2820, Train: 34.72%, Valid: 30.70%, Test: 20.38%
---
Hits@10
Run: 07, Epoch: 55, Loss: 0.2755, Train: 8.66%, Valid: 7.35%, Test: 7.98%
Hits@20
Run: 07, Epoch: 55, Loss: 0.2755, Train: 26.08%, Valid: 22.55%, Test: 19.88%
Hits@30
Run: 07, Epoch: 55, Loss: 0.2755, Train: 32.84%, Valid: 28.86%, Test: 27.93%
---
Hits@10
Run: 07, Epoch: 60, Loss: 0.2687, Train: 15.42%, Valid: 13.03%, Test: 5.00%
Hits@20
Run: 07, Epoch: 60, Loss: 0.2687, Train: 24.23%, Valid: 21.02%, Test: 12.95%
Hits@30
Run: 07, Epoch: 60, Loss: 0.2687, Train: 30.17%, Valid: 26.50%, Test: 20.89%
---
Hits@10
Run: 07, Epoch: 65, Loss: 0.2631, Train: 22.11%, Valid: 19.17%, Test: 5.81%
Hits@20
Run: 07, Epoch: 65, Loss: 0.2631, Train: 36.74%, Valid: 32.14%, Test: 15.52%
Hits@30
Run: 07, Epoch: 65, Loss: 0.2631, Train: 39.58%, Va

Hits@10
Run: 08, Epoch: 05, Loss: 0.6589, Train: 4.74%, Valid: 4.09%, Test: 0.39%
Hits@20
Run: 08, Epoch: 05, Loss: 0.6589, Train: 11.32%, Valid: 9.82%, Test: 2.77%
Hits@30
Run: 08, Epoch: 05, Loss: 0.6589, Train: 15.71%, Valid: 14.17%, Test: 4.63%
---
Hits@10
Run: 08, Epoch: 10, Loss: 0.5112, Train: 14.76%, Valid: 13.13%, Test: 10.73%
Hits@20
Run: 08, Epoch: 10, Loss: 0.5112, Train: 18.53%, Valid: 16.55%, Test: 15.99%
Hits@30
Run: 08, Epoch: 10, Loss: 0.5112, Train: 23.69%, Valid: 21.39%, Test: 18.58%
---
Hits@10
Run: 08, Epoch: 15, Loss: 0.4368, Train: 25.64%, Valid: 23.22%, Test: 14.03%
Hits@20
Run: 08, Epoch: 15, Loss: 0.4368, Train: 28.99%, Valid: 26.46%, Test: 20.59%
Hits@30
Run: 08, Epoch: 15, Loss: 0.4368, Train: 32.02%, Valid: 29.37%, Test: 26.75%
---
Hits@10
Run: 08, Epoch: 20, Loss: 0.3875, Train: 18.12%, Valid: 15.88%, Test: 9.77%
Hits@20
Run: 08, Epoch: 20, Loss: 0.3875, Train: 23.38%, Valid: 20.80%, Test: 17.52%
Hits@30
Run: 08, Epoch: 20, Loss: 0.3875, Train: 30.54%, Val

Hits@10
Run: 08, Epoch: 165, Loss: 0.2031, Train: 44.52%, Valid: 38.34%, Test: 30.74%
Hits@20
Run: 08, Epoch: 165, Loss: 0.2031, Train: 58.18%, Valid: 50.48%, Test: 42.67%
Hits@30
Run: 08, Epoch: 165, Loss: 0.2031, Train: 64.45%, Valid: 56.37%, Test: 48.66%
---
Hits@10
Run: 08, Epoch: 170, Loss: 0.1980, Train: 43.05%, Valid: 36.82%, Test: 30.31%
Hits@20
Run: 08, Epoch: 170, Loss: 0.1980, Train: 52.85%, Valid: 45.41%, Test: 36.81%
Hits@30
Run: 08, Epoch: 170, Loss: 0.1980, Train: 60.84%, Valid: 52.78%, Test: 44.14%
---
Hits@10
Run: 08, Epoch: 175, Loss: 0.1968, Train: 45.99%, Valid: 39.18%, Test: 20.55%
Hits@20
Run: 08, Epoch: 175, Loss: 0.1968, Train: 62.02%, Valid: 53.82%, Test: 32.90%
Hits@30
Run: 08, Epoch: 175, Loss: 0.1968, Train: 66.79%, Valid: 58.34%, Test: 43.03%
---
Hits@10
Run: 08, Epoch: 180, Loss: 0.1958, Train: 49.00%, Valid: 41.75%, Test: 30.00%
Hits@20
Run: 08, Epoch: 180, Loss: 0.1958, Train: 61.15%, Valid: 52.77%, Test: 51.63%
Hits@30
Run: 08, Epoch: 180, Loss: 0.1958,

Hits@10
Run: 09, Epoch: 120, Loss: 0.2212, Train: 29.39%, Valid: 25.09%, Test: 24.63%
Hits@20
Run: 09, Epoch: 120, Loss: 0.2212, Train: 36.91%, Valid: 31.85%, Test: 34.34%
Hits@30
Run: 09, Epoch: 120, Loss: 0.2212, Train: 49.06%, Valid: 42.55%, Test: 41.92%
---
Hits@10
Run: 09, Epoch: 125, Loss: 0.2174, Train: 32.88%, Valid: 28.28%, Test: 21.64%
Hits@20
Run: 09, Epoch: 125, Loss: 0.2174, Train: 44.75%, Valid: 38.71%, Test: 31.28%
Hits@30
Run: 09, Epoch: 125, Loss: 0.2174, Train: 48.98%, Valid: 42.53%, Test: 37.93%
---
Hits@10
Run: 09, Epoch: 130, Loss: 0.2162, Train: 29.98%, Valid: 25.70%, Test: 26.29%
Hits@20
Run: 09, Epoch: 130, Loss: 0.2162, Train: 46.69%, Valid: 40.18%, Test: 38.23%
Hits@30
Run: 09, Epoch: 130, Loss: 0.2162, Train: 53.82%, Valid: 46.76%, Test: 46.73%
---
Hits@10
Run: 09, Epoch: 135, Loss: 0.2134, Train: 33.05%, Valid: 28.28%, Test: 25.96%
Hits@20
Run: 09, Epoch: 135, Loss: 0.2134, Train: 38.28%, Valid: 32.80%, Test: 31.00%
Hits@30
Run: 09, Epoch: 135, Loss: 0.2134,

Hits@10
Run: 10, Epoch: 75, Loss: 0.2502, Train: 22.24%, Valid: 19.20%, Test: 5.09%
Hits@20
Run: 10, Epoch: 75, Loss: 0.2502, Train: 32.63%, Valid: 28.48%, Test: 12.00%
Hits@30
Run: 10, Epoch: 75, Loss: 0.2502, Train: 40.35%, Valid: 35.42%, Test: 18.38%
---
Hits@10
Run: 10, Epoch: 80, Loss: 0.2482, Train: 25.09%, Valid: 21.69%, Test: 9.08%
Hits@20
Run: 10, Epoch: 80, Loss: 0.2482, Train: 35.52%, Valid: 30.78%, Test: 19.38%
Hits@30
Run: 10, Epoch: 80, Loss: 0.2482, Train: 41.12%, Valid: 35.99%, Test: 28.55%
---
Hits@10
Run: 10, Epoch: 85, Loss: 0.2417, Train: 19.71%, Valid: 16.90%, Test: 5.40%
Hits@20
Run: 10, Epoch: 85, Loss: 0.2417, Train: 37.94%, Valid: 32.91%, Test: 10.70%
Hits@30
Run: 10, Epoch: 85, Loss: 0.2417, Train: 42.73%, Valid: 37.25%, Test: 22.58%
---
Hits@10
Run: 10, Epoch: 90, Loss: 0.2390, Train: 15.94%, Valid: 13.47%, Test: 6.38%
Hits@20
Run: 10, Epoch: 90, Loss: 0.2390, Train: 27.70%, Valid: 23.90%, Test: 13.28%
Hits@30
Run: 10, Epoch: 90, Loss: 0.2390, Train: 35.38%, 

## GraphSAGE

In [11]:
import argparse

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.utils import negative_sampling

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv

from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

In [12]:
class GCN(torch.nn.Module):
    """Stack of ``GCNConv`` layers with ReLU + dropout between them.

    Args:
        in_channels: feature size fed to the first layer.
        hidden_channels: feature size of every intermediate layer.
        out_channels: feature size produced by the last layer.
        num_layers: requested depth. A first and a last layer are always
            created, so values below 2 still yield two layers.
        dropout: dropout probability applied after each non-final layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        # Consecutive entries are the (input, output) sizes of each layer.
        widths = ([in_channels, hidden_channels]
                  + [hidden_channels] * (num_layers - 2)
                  + [out_channels])
        # Every layer is created with cached=True (see the GCNConv docs for
        # what is cached); this presumes the same graph on every call.
        self.convs = torch.nn.ModuleList([
            GCNConv(fan_in, fan_out, cached=True)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ])

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the parameters of every convolution layer."""
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, adj_t):
        """Return node representations for features ``x`` on graph ``adj_t``.

        The final layer's output is returned without activation or dropout.
        """
        *inner_layers, final_layer = self.convs
        for layer in inner_layers:
            activated = F.relu(layer(x, adj_t))
            x = F.dropout(activated, p=self.dropout, training=self.training)
        return final_layer(x, adj_t)


class SAGE(torch.nn.Module):
    """Stack of ``SAGEConv`` layers with ReLU + dropout between them.

    Args:
        in_channels: feature size fed to the first layer.
        hidden_channels: feature size of every intermediate layer.
        out_channels: feature size produced by the last layer.
        num_layers: requested depth. A first and a last layer are always
            created, so values below 2 still yield two layers.
        dropout: dropout probability applied after each non-final layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        layers = [SAGEConv(in_channels, hidden_channels)]
        layers.extend(
            SAGEConv(hidden_channels, hidden_channels)
            for _ in range(num_layers - 2)
        )
        layers.append(SAGEConv(hidden_channels, out_channels))
        self.convs = torch.nn.ModuleList(layers)

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the parameters of every convolution layer."""
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, adj_t):
        """Return node representations for features ``x`` on graph ``adj_t``.

        The final layer's output is returned without activation or dropout.
        """
        last = len(self.convs) - 1
        for depth in range(last):
            x = self.convs[depth](x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[last](x, adj_t)


class LinkPredictor(torch.nn.Module):
    """MLP that scores a node pair from the element-wise product of its
    two embeddings and squashes the result through a sigmoid.

    Args:
        in_channels: size of each node embedding.
        hidden_channels: width of the intermediate linear layers.
        out_channels: size of the output (1 for a single link score).
        num_layers: requested depth. A first and a last layer are always
            created, so values below 2 still yield two layers.
        dropout: dropout probability applied after each non-final layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        # Consecutive entries are the (input, output) sizes of each layer.
        widths = ([in_channels, hidden_channels]
                  + [hidden_channels] * (num_layers - 2)
                  + [out_channels])
        self.lins = torch.nn.ModuleList([
            torch.nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ])

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the parameters of every linear layer."""
        for layer in self.lins:
            layer.reset_parameters()

    def forward(self, x_i, x_j):
        """Return sigmoid scores for the pairs ``(x_i[k], x_j[k])``."""
        x = x_i * x_j
        *hidden_layers, output_layer = self.lins
        for layer in hidden_layers:
            activated = F.relu(layer(x))
            x = F.dropout(activated, p=self.dropout, training=self.training)
        return torch.sigmoid(output_layer(x))



def train(model, predictor, x, adj_t, split_edge, optimizer, batch_size):
    """Run one training epoch and return the example-weighted mean loss.

    For each shuffled mini-batch of positive training edges the encoder is
    re-run on the whole graph, an equal number of negative edges is sampled,
    and the sum of the positive and negative log-losses is minimised.

    Args:
        model: graph encoder called as ``model(x, adj_t)``.
        predictor: pair scorer called as ``predictor(h_src, h_dst)``; its
            outputs are used as probabilities inside ``log``.
        x: node feature tensor (its gradient is clipped too).
        adj_t: sparse adjacency exposing ``.coo()``.
        split_edge: dict holding the positive edges at ``['train']['edge']``.
        optimizer: optimizer stepping all trainable parameters.
        batch_size: number of positive edges per mini-batch.

    Returns:
        Total loss divided by the number of positive examples processed.
    """
    adj_row, adj_col, _ = adj_t.coo()
    # Existing edges, handed to the negative sampler so it can avoid them.
    known_edges = torch.stack([adj_col, adj_row], dim=0)

    model.train()
    predictor.train()

    positives = split_edge['train']['edge'].to(x.device)

    loss_sum = 0
    examples_seen = 0
    batches = DataLoader(range(positives.size(0)), batch_size, shuffle=True)
    for perm in batches:
        optimizer.zero_grad()

        h = model(x, adj_t)

        pos_pair = positives[perm].t()
        pos_out = predictor(h[pos_pair[0]], h[pos_pair[1]])
        # The 1e-15 offsets keep log() finite when a probability saturates.
        pos_loss = -torch.log(pos_out + 1e-15).mean()

        neg_pair = negative_sampling(known_edges, num_nodes=x.size(0),
                                     num_neg_samples=perm.size(0),
                                     method='dense')
        neg_out = predictor(h[neg_pair[0]], h[neg_pair[1]])
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()

        loss = pos_loss + neg_loss
        loss.backward()

        # Clip each parameter group separately to a max gradient norm of 1.
        for group in (x, model.parameters(), predictor.parameters()):
            torch.nn.utils.clip_grad_norm_(group, 1.0)

        optimizer.step()

        batch_examples = pos_out.size(0)
        loss_sum += loss.item() * batch_examples
        examples_seen += batch_examples

    return loss_sum / examples_seen


@torch.no_grad()
def test(model, predictor, x, adj_t, split_edge, evaluator, batch_size):
    model.eval()
    predictor.eval()

    h = model(x, adj_t)

    pos_train_edge = split_edge['eval_train']['edge'].to(x.device)
    pos_valid_edge = split_edge['valid']['edge'].to(x.device)
    neg_valid_edge = split_edge['valid']['edge_neg'].to(x.device)
    pos_test_edge = split_edge['test']['edge'].to(x.device)
    neg_test_edge = split_edge['test']['edge_neg'].to(x.device)

    pos_train_preds = []
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size):
        edge = pos_train_edge[perm].t()
        pos_train_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_train_pred = torch.cat(pos_train_preds, dim=0)

    pos_valid_preds = []
    for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size):
        edge = pos_valid_edge[perm].t()
        pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

    neg_valid_preds = []
    for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size):
        edge = neg_valid_edge[perm].t()
        neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

    pos_test_preds = []
    for perm in DataLoader(range(pos_test_edge.size(0)), batch_size):
        edge = pos_test_edge[perm].t()
        pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_test_pred = torch.cat(pos_test_preds, dim=0)

    neg_test_preds = []
    for perm in DataLoader(range(neg_test_edge.size(0)), batch_size):
        edge = neg_test_edge[perm].t()
        neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_test_pred = torch.cat(neg_test_preds, dim=0)

    results = {}
    for K in [10, 20, 30]:
        evaluator.K = K
        train_hits = evaluator.eval({
            'y_pred_pos': pos_train_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits)

    return results


def main():
    """Run the ogbl-ddi link-prediction experiment (GraphSAGE by default).

    Parses command-line options, loads the dataset with a sparse adjacency,
    builds the encoder, a learnable node-embedding table and the link
    predictor, then trains for ``args.runs`` independent runs. Every
    ``args.eval_steps`` epochs the Hits@10/20/30 results are recorded and
    (subject to ``args.log_steps``) printed.

    Relies on ``SAGE``, ``GCN``, ``LinkPredictor``, ``train``, ``test`` and
    ``Logger_Models_Model`` being defined by earlier notebook cells.
    """

    def _parse_bool(text):
        """Convert a command-line token such as 'false' or '1' to a bool."""
        if isinstance(text, bool):
            return text
        lowered = text.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            f'expected a boolean value, got {text!r}')

    parser = argparse.ArgumentParser(description='OGBL-DDI (GNN)')
    # `-f` absorbs the kernel connection file that shows up as `f=...json`
    # in the printed arguments when run inside Jupyter.
    parser.add_argument('-f')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    # Without a `type`, any supplied value (even "False") was kept as a
    # non-empty, truthy string, so SAGE could never be switched off. Parse it
    # as a real boolean; a bare `--use_sage` means True and the default stays
    # True.
    parser.add_argument('--use_sage', type=_parse_bool, nargs='?',
                        const=True, default=True)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-ddi',
                                     transform=T.ToSparseTensor())
    data = dataset[0]
    adj_t = data.adj_t.to(device)

    split_edge = dataset.get_edge_split()

    # We randomly pick some training samples that we want to evaluate on
    # (fixed seed; as many edges as the validation set has):
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['edge'].size(0))
    idx = idx[:split_edge['valid']['edge'].size(0)]
    split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]}

    if args.use_sage:
        model = SAGE(args.hidden_channels, args.hidden_channels,
                     args.hidden_channels, args.num_layers,
                     args.dropout).to(device)
    else:
        model = GCN(args.hidden_channels, args.hidden_channels,
                    args.hidden_channels, args.num_layers,
                    args.dropout).to(device)

    # One learnable feature vector per node (rows = size of the adjacency).
    emb = torch.nn.Embedding(data.adj_t.size(0),
                             args.hidden_channels).to(device)
    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-ddi')
    # One result logger per reported metric.
    Logger_Models_Models = {
        'Hits@10': Logger_Models_Model(args.runs, args),
        'Hits@20': Logger_Models_Model(args.runs, args),
        'Hits@30': Logger_Models_Model(args.runs, args),
    }

    for run in range(args.runs):
        # Start every run from freshly initialised weights.
        torch.nn.init.xavier_uniform_(emb.weight)
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(emb.parameters()) +
            list(predictor.parameters()), lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, emb.weight, adj_t, split_edge,
                         optimizer, args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(model, predictor, emb.weight, adj_t, split_edge,
                               evaluator, args.batch_size)
                for key, result in results.items():
                    Logger_Models_Models[key].add_result(run, result)

                # Printing only ever happens on evaluation epochs.
                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_hits:.2f}%, '
                              f'Valid: {100 * valid_hits:.2f}%, '
                              f'Test: {100 * test_hits:.2f}%')
                    print('---')

        # Per-run summary for each metric.
        for key in Logger_Models_Models.keys():
            print(key)
            Logger_Models_Models[key].print_statistics(run)

    # Summary across all runs for each metric.
    for key in Logger_Models_Models.keys():
        print(key)
        Logger_Models_Models[key].print_statistics()


# Entry point: launch the full training/evaluation experiment when this code
# is executed as the main module (which is the case when the cell is run).
if __name__ == "__main__":
    main()


Namespace(f='C:\\Users\\b04753yr\\AppData\\Roaming\\jupyter\\runtime\\kernel-30e704de-a8cc-40a4-8675-f15ec8e16fa1.json', device=0, log_steps=1, use_sage=True, num_layers=2, hidden_channels=256, dropout=0.5, batch_size=65536, lr=0.005, epochs=200, eval_steps=5, runs=10)
Hits@10
Run: 01, Epoch: 05, Loss: 0.5646, Train: 10.91%, Valid: 9.69%, Test: 3.30%
Hits@20
Run: 01, Epoch: 05, Loss: 0.5646, Train: 17.74%, Valid: 15.89%, Test: 8.36%
Hits@30
Run: 01, Epoch: 05, Loss: 0.5646, Train: 22.12%, Valid: 20.14%, Test: 11.91%
---
Hits@10
Run: 01, Epoch: 10, Loss: 0.4070, Train: 19.12%, Valid: 16.52%, Test: 7.65%
Hits@20
Run: 01, Epoch: 10, Loss: 0.4070, Train: 22.87%, Valid: 20.05%, Test: 11.71%
Hits@30
Run: 01, Epoch: 10, Loss: 0.4070, Train: 28.42%, Valid: 25.33%, Test: 18.43%
---
Hits@10
Run: 01, Epoch: 15, Loss: 0.3325, Train: 33.58%, Valid: 29.27%, Test: 9.86%
Hits@20
Run: 01, Epoch: 15, Loss: 0.3325, Train: 39.61%, Valid: 34.80%, Test: 15.72%
Hits@30
Run: 01, Epoch: 15, Loss: 0.3325, Train

Hits@10
Run: 01, Epoch: 160, Loss: 0.1667, Train: 66.48%, Valid: 55.77%, Test: 48.98%
Hits@20
Run: 01, Epoch: 160, Loss: 0.1667, Train: 72.57%, Valid: 61.91%, Test: 59.56%
Hits@30
Run: 01, Epoch: 160, Loss: 0.1667, Train: 76.14%, Valid: 65.58%, Test: 65.23%
---
Hits@10
Run: 01, Epoch: 165, Loss: 0.1665, Train: 73.28%, Valid: 63.07%, Test: 45.18%
Hits@20
Run: 01, Epoch: 165, Loss: 0.1665, Train: 77.57%, Valid: 67.48%, Test: 61.49%
Hits@30
Run: 01, Epoch: 165, Loss: 0.1665, Train: 79.38%, Valid: 69.22%, Test: 73.21%
---
Hits@10
Run: 01, Epoch: 170, Loss: 0.1637, Train: 68.70%, Valid: 57.96%, Test: 36.45%
Hits@20
Run: 01, Epoch: 170, Loss: 0.1637, Train: 74.37%, Valid: 63.85%, Test: 53.82%
Hits@30
Run: 01, Epoch: 170, Loss: 0.1637, Train: 77.68%, Valid: 67.34%, Test: 63.34%
---
Hits@10
Run: 01, Epoch: 175, Loss: 0.1629, Train: 55.89%, Valid: 45.40%, Test: 36.92%
Hits@20
Run: 01, Epoch: 175, Loss: 0.1629, Train: 74.51%, Valid: 63.92%, Test: 56.84%
Hits@30
Run: 01, Epoch: 175, Loss: 0.1629,

Hits@10
Run: 02, Epoch: 115, Loss: 0.1717, Train: 66.60%, Valid: 56.57%, Test: 46.29%
Hits@20
Run: 02, Epoch: 115, Loss: 0.1717, Train: 74.70%, Valid: 64.69%, Test: 67.60%
Hits@30
Run: 02, Epoch: 115, Loss: 0.1717, Train: 76.49%, Valid: 66.46%, Test: 76.66%
---
Hits@10
Run: 02, Epoch: 120, Loss: 0.1730, Train: 71.97%, Valid: 61.92%, Test: 37.96%
Hits@20
Run: 02, Epoch: 120, Loss: 0.1730, Train: 76.14%, Valid: 66.08%, Test: 60.42%
Hits@30
Run: 02, Epoch: 120, Loss: 0.1730, Train: 77.82%, Valid: 67.81%, Test: 73.67%
---
Hits@10
Run: 02, Epoch: 125, Loss: 0.1708, Train: 70.70%, Valid: 60.64%, Test: 34.37%
Hits@20
Run: 02, Epoch: 125, Loss: 0.1708, Train: 76.09%, Valid: 66.09%, Test: 59.37%
Hits@30
Run: 02, Epoch: 125, Loss: 0.1708, Train: 77.63%, Valid: 67.69%, Test: 72.56%
---
Hits@10
Run: 02, Epoch: 130, Loss: 0.1695, Train: 70.65%, Valid: 60.60%, Test: 54.58%
Hits@20
Run: 02, Epoch: 130, Loss: 0.1695, Train: 75.04%, Valid: 65.00%, Test: 66.62%
Hits@30
Run: 02, Epoch: 130, Loss: 0.1695,

Hits@10
Run: 03, Epoch: 70, Loss: 0.1887, Train: 69.07%, Valid: 59.64%, Test: 47.13%
Hits@20
Run: 03, Epoch: 70, Loss: 0.1887, Train: 70.92%, Valid: 61.46%, Test: 61.89%
Hits@30
Run: 03, Epoch: 70, Loss: 0.1887, Train: 73.62%, Valid: 64.19%, Test: 71.68%
---
Hits@10
Run: 03, Epoch: 75, Loss: 0.1879, Train: 66.84%, Valid: 57.51%, Test: 55.37%
Hits@20
Run: 03, Epoch: 75, Loss: 0.1879, Train: 71.31%, Valid: 61.88%, Test: 67.35%
Hits@30
Run: 03, Epoch: 75, Loss: 0.1879, Train: 73.43%, Valid: 64.01%, Test: 73.77%
---
Hits@10
Run: 03, Epoch: 80, Loss: 0.1831, Train: 67.86%, Valid: 58.43%, Test: 34.04%
Hits@20
Run: 03, Epoch: 80, Loss: 0.1831, Train: 74.05%, Valid: 64.56%, Test: 53.36%
Hits@30
Run: 03, Epoch: 80, Loss: 0.1831, Train: 76.38%, Valid: 66.87%, Test: 60.56%
---
Hits@10
Run: 03, Epoch: 85, Loss: 0.1813, Train: 64.71%, Valid: 54.88%, Test: 36.60%
Hits@20
Run: 03, Epoch: 85, Loss: 0.1813, Train: 71.55%, Valid: 61.77%, Test: 54.50%
Hits@30
Run: 03, Epoch: 85, Loss: 0.1813, Train: 74.3

Hits@10
Run: 04, Epoch: 25, Loss: 0.2587, Train: 45.10%, Valid: 38.81%, Test: 17.10%
Hits@20
Run: 04, Epoch: 25, Loss: 0.2587, Train: 51.21%, Valid: 44.40%, Test: 22.55%
Hits@30
Run: 04, Epoch: 25, Loss: 0.2587, Train: 53.62%, Valid: 46.65%, Test: 31.26%
---
Hits@10
Run: 04, Epoch: 30, Loss: 0.2421, Train: 48.08%, Valid: 41.23%, Test: 12.40%
Hits@20
Run: 04, Epoch: 30, Loss: 0.2421, Train: 55.58%, Valid: 48.05%, Test: 22.29%
Hits@30
Run: 04, Epoch: 30, Loss: 0.2421, Train: 59.20%, Valid: 51.34%, Test: 28.95%
---
Hits@10
Run: 04, Epoch: 35, Loss: 0.2262, Train: 49.76%, Valid: 42.43%, Test: 12.01%
Hits@20
Run: 04, Epoch: 35, Loss: 0.2262, Train: 59.45%, Valid: 51.62%, Test: 28.89%
Hits@30
Run: 04, Epoch: 35, Loss: 0.2262, Train: 63.70%, Valid: 55.55%, Test: 35.07%
---
Hits@10
Run: 04, Epoch: 40, Loss: 0.2201, Train: 52.06%, Valid: 44.35%, Test: 24.48%
Hits@20
Run: 04, Epoch: 40, Loss: 0.2201, Train: 60.34%, Valid: 52.19%, Test: 37.59%
Hits@30
Run: 04, Epoch: 40, Loss: 0.2201, Train: 65.6

Hits@10
Run: 04, Epoch: 185, Loss: 0.1612, Train: 62.35%, Valid: 52.15%, Test: 15.61%
Hits@20
Run: 04, Epoch: 185, Loss: 0.1612, Train: 76.13%, Valid: 65.63%, Test: 52.28%
Hits@30
Run: 04, Epoch: 185, Loss: 0.1612, Train: 78.74%, Valid: 68.36%, Test: 65.22%
---
Hits@10
Run: 04, Epoch: 190, Loss: 0.1611, Train: 70.79%, Valid: 60.61%, Test: 41.04%
Hits@20
Run: 04, Epoch: 190, Loss: 0.1611, Train: 75.32%, Valid: 65.08%, Test: 56.46%
Hits@30
Run: 04, Epoch: 190, Loss: 0.1611, Train: 78.17%, Valid: 68.05%, Test: 68.73%
---
Hits@10
Run: 04, Epoch: 195, Loss: 0.1606, Train: 72.98%, Valid: 62.42%, Test: 30.33%
Hits@20
Run: 04, Epoch: 195, Loss: 0.1606, Train: 75.48%, Valid: 65.02%, Test: 53.22%
Hits@30
Run: 04, Epoch: 195, Loss: 0.1606, Train: 78.54%, Valid: 68.14%, Test: 72.10%
---
Hits@10
Run: 04, Epoch: 200, Loss: 0.1611, Train: 70.38%, Valid: 59.79%, Test: 17.86%
Hits@20
Run: 04, Epoch: 200, Loss: 0.1611, Train: 77.89%, Valid: 67.47%, Test: 41.30%
Hits@30
Run: 04, Epoch: 200, Loss: 0.1611,

Hits@10
Run: 05, Epoch: 140, Loss: 0.1659, Train: 67.04%, Valid: 56.89%, Test: 45.50%
Hits@20
Run: 05, Epoch: 140, Loss: 0.1659, Train: 75.29%, Valid: 65.24%, Test: 61.28%
Hits@30
Run: 05, Epoch: 140, Loss: 0.1659, Train: 77.31%, Valid: 67.32%, Test: 68.46%
---
Hits@10
Run: 05, Epoch: 145, Loss: 0.1659, Train: 68.87%, Valid: 58.69%, Test: 47.00%
Hits@20
Run: 05, Epoch: 145, Loss: 0.1659, Train: 77.46%, Valid: 67.47%, Test: 65.38%
Hits@30
Run: 05, Epoch: 145, Loss: 0.1659, Train: 78.90%, Valid: 68.86%, Test: 76.14%
---
Hits@10
Run: 05, Epoch: 150, Loss: 0.1653, Train: 65.31%, Valid: 55.06%, Test: 43.67%
Hits@20
Run: 05, Epoch: 150, Loss: 0.1653, Train: 76.54%, Valid: 66.41%, Test: 61.01%
Hits@30
Run: 05, Epoch: 150, Loss: 0.1653, Train: 77.94%, Valid: 67.84%, Test: 68.15%
---
Hits@10
Run: 05, Epoch: 155, Loss: 0.1642, Train: 69.79%, Valid: 59.43%, Test: 33.61%
Hits@20
Run: 05, Epoch: 155, Loss: 0.1642, Train: 74.57%, Valid: 64.28%, Test: 50.39%
Hits@30
Run: 05, Epoch: 155, Loss: 0.1642,

Hits@10
Run: 06, Epoch: 95, Loss: 0.1774, Train: 58.40%, Valid: 48.92%, Test: 31.55%
Hits@20
Run: 06, Epoch: 95, Loss: 0.1774, Train: 70.40%, Valid: 60.60%, Test: 47.36%
Hits@30
Run: 06, Epoch: 95, Loss: 0.1774, Train: 73.98%, Valid: 64.27%, Test: 59.70%
---
Hits@10
Run: 06, Epoch: 100, Loss: 0.1769, Train: 58.29%, Valid: 48.48%, Test: 38.78%
Hits@20
Run: 06, Epoch: 100, Loss: 0.1769, Train: 70.04%, Valid: 59.98%, Test: 53.96%
Hits@30
Run: 06, Epoch: 100, Loss: 0.1769, Train: 73.87%, Valid: 63.90%, Test: 61.69%
---
Hits@10
Run: 06, Epoch: 105, Loss: 0.1747, Train: 65.08%, Valid: 55.16%, Test: 27.60%
Hits@20
Run: 06, Epoch: 105, Loss: 0.1747, Train: 71.22%, Valid: 61.30%, Test: 47.75%
Hits@30
Run: 06, Epoch: 105, Loss: 0.1747, Train: 74.79%, Valid: 64.93%, Test: 62.77%
---
Hits@10
Run: 06, Epoch: 110, Loss: 0.1740, Train: 61.56%, Valid: 51.49%, Test: 39.94%
Hits@20
Run: 06, Epoch: 110, Loss: 0.1740, Train: 69.11%, Valid: 58.98%, Test: 55.62%
Hits@30
Run: 06, Epoch: 110, Loss: 0.1740, Tr

Hits@10
Run: 07, Epoch: 50, Loss: 0.2025, Train: 62.15%, Valid: 53.66%, Test: 40.56%
Hits@20
Run: 07, Epoch: 50, Loss: 0.2025, Train: 67.98%, Valid: 59.17%, Test: 48.72%
Hits@30
Run: 07, Epoch: 50, Loss: 0.2025, Train: 69.81%, Valid: 60.94%, Test: 58.40%
---
Hits@10
Run: 07, Epoch: 55, Loss: 0.1981, Train: 65.07%, Valid: 55.97%, Test: 34.49%
Hits@20
Run: 07, Epoch: 55, Loss: 0.1981, Train: 69.96%, Valid: 60.70%, Test: 49.64%
Hits@30
Run: 07, Epoch: 55, Loss: 0.1981, Train: 72.44%, Valid: 63.17%, Test: 58.74%
---
Hits@10
Run: 07, Epoch: 60, Loss: 0.1934, Train: 57.83%, Valid: 48.78%, Test: 23.54%
Hits@20
Run: 07, Epoch: 60, Loss: 0.1934, Train: 66.20%, Valid: 56.81%, Test: 44.54%
Hits@30
Run: 07, Epoch: 60, Loss: 0.1934, Train: 70.03%, Valid: 60.64%, Test: 52.03%
---
Hits@10
Run: 07, Epoch: 65, Loss: 0.1893, Train: 63.55%, Valid: 54.33%, Test: 22.71%
Hits@20
Run: 07, Epoch: 65, Loss: 0.1893, Train: 69.56%, Valid: 60.22%, Test: 45.63%
Hits@30
Run: 07, Epoch: 65, Loss: 0.1893, Train: 72.2

Hits@10
Run: 08, Epoch: 05, Loss: 0.5729, Train: 8.48%, Valid: 7.40%, Test: 3.76%
Hits@20
Run: 08, Epoch: 05, Loss: 0.5729, Train: 12.38%, Valid: 11.00%, Test: 7.96%
Hits@30
Run: 08, Epoch: 05, Loss: 0.5729, Train: 15.09%, Valid: 13.54%, Test: 8.95%
---
Hits@10
Run: 08, Epoch: 10, Loss: 0.4171, Train: 19.79%, Valid: 17.24%, Test: 6.62%
Hits@20
Run: 08, Epoch: 10, Loss: 0.4171, Train: 27.04%, Valid: 23.94%, Test: 9.82%
Hits@30
Run: 08, Epoch: 10, Loss: 0.4171, Train: 30.04%, Valid: 26.96%, Test: 11.53%
---
Hits@10
Run: 08, Epoch: 15, Loss: 0.3328, Train: 26.79%, Valid: 23.20%, Test: 13.47%
Hits@20
Run: 08, Epoch: 15, Loss: 0.3328, Train: 36.82%, Valid: 32.15%, Test: 18.36%
Hits@30
Run: 08, Epoch: 15, Loss: 0.3328, Train: 43.41%, Valid: 38.43%, Test: 22.17%
---
Hits@10
Run: 08, Epoch: 20, Loss: 0.2920, Train: 36.83%, Valid: 31.76%, Test: 8.78%
Hits@20
Run: 08, Epoch: 20, Loss: 0.2920, Train: 42.42%, Valid: 36.91%, Test: 15.16%
Hits@30
Run: 08, Epoch: 20, Loss: 0.2920, Train: 47.16%, Vali

Hits@10
Run: 08, Epoch: 165, Loss: 0.1647, Train: 56.74%, Valid: 46.62%, Test: 34.66%
Hits@20
Run: 08, Epoch: 165, Loss: 0.1647, Train: 76.31%, Valid: 65.89%, Test: 58.99%
Hits@30
Run: 08, Epoch: 165, Loss: 0.1647, Train: 78.69%, Valid: 68.50%, Test: 69.46%
---
Hits@10
Run: 08, Epoch: 170, Loss: 0.1640, Train: 67.40%, Valid: 56.38%, Test: 36.44%
Hits@20
Run: 08, Epoch: 170, Loss: 0.1640, Train: 76.84%, Valid: 66.34%, Test: 59.41%
Hits@30
Run: 08, Epoch: 170, Loss: 0.1640, Train: 79.20%, Valid: 68.89%, Test: 75.44%
---
Hits@10
Run: 08, Epoch: 175, Loss: 0.1619, Train: 61.30%, Valid: 50.37%, Test: 35.21%
Hits@20
Run: 08, Epoch: 175, Loss: 0.1619, Train: 73.37%, Valid: 62.67%, Test: 56.97%
Hits@30
Run: 08, Epoch: 175, Loss: 0.1619, Train: 75.80%, Valid: 65.26%, Test: 66.51%
---
Hits@10
Run: 08, Epoch: 180, Loss: 0.1635, Train: 49.16%, Valid: 39.11%, Test: 24.43%
Hits@20
Run: 08, Epoch: 180, Loss: 0.1635, Train: 73.30%, Valid: 62.86%, Test: 43.83%
Hits@30
Run: 08, Epoch: 180, Loss: 0.1635,

Hits@10
Run: 09, Epoch: 120, Loss: 0.1727, Train: 67.05%, Valid: 56.85%, Test: 58.11%
Hits@20
Run: 09, Epoch: 120, Loss: 0.1727, Train: 74.64%, Valid: 64.64%, Test: 66.93%
Hits@30
Run: 09, Epoch: 120, Loss: 0.1727, Train: 76.90%, Valid: 66.93%, Test: 77.01%
---
Hits@10
Run: 09, Epoch: 125, Loss: 0.1721, Train: 57.61%, Valid: 47.98%, Test: 52.36%
Hits@20
Run: 09, Epoch: 125, Loss: 0.1721, Train: 75.40%, Valid: 65.34%, Test: 64.08%
Hits@30
Run: 09, Epoch: 125, Loss: 0.1721, Train: 76.44%, Valid: 66.42%, Test: 73.89%
---
Hits@10
Run: 09, Epoch: 130, Loss: 0.1702, Train: 71.01%, Valid: 61.04%, Test: 35.24%
Hits@20
Run: 09, Epoch: 130, Loss: 0.1702, Train: 75.12%, Valid: 65.14%, Test: 56.03%
Hits@30
Run: 09, Epoch: 130, Loss: 0.1702, Train: 78.41%, Valid: 68.46%, Test: 69.36%
---
Hits@10
Run: 09, Epoch: 135, Loss: 0.1683, Train: 66.80%, Valid: 56.53%, Test: 39.10%
Hits@20
Run: 09, Epoch: 135, Loss: 0.1683, Train: 76.64%, Valid: 66.50%, Test: 54.59%
Hits@30
Run: 09, Epoch: 135, Loss: 0.1683,

Hits@10
Run: 10, Epoch: 75, Loss: 0.1841, Train: 64.81%, Valid: 55.39%, Test: 31.10%
Hits@20
Run: 10, Epoch: 75, Loss: 0.1841, Train: 70.51%, Valid: 60.92%, Test: 51.19%
Hits@30
Run: 10, Epoch: 75, Loss: 0.1841, Train: 72.00%, Valid: 62.33%, Test: 56.20%
---
Hits@10
Run: 10, Epoch: 80, Loss: 0.1828, Train: 64.71%, Valid: 55.57%, Test: 26.72%
Hits@20
Run: 10, Epoch: 80, Loss: 0.1828, Train: 69.61%, Valid: 60.17%, Test: 43.14%
Hits@30
Run: 10, Epoch: 80, Loss: 0.1828, Train: 73.32%, Valid: 63.80%, Test: 54.10%
---
Hits@10
Run: 10, Epoch: 85, Loss: 0.1789, Train: 67.36%, Valid: 57.65%, Test: 27.26%
Hits@20
Run: 10, Epoch: 85, Loss: 0.1789, Train: 71.54%, Valid: 61.88%, Test: 52.40%
Hits@30
Run: 10, Epoch: 85, Loss: 0.1789, Train: 75.14%, Valid: 65.53%, Test: 68.83%
---
Hits@10
Run: 10, Epoch: 90, Loss: 0.1781, Train: 61.75%, Valid: 51.99%, Test: 24.33%
Hits@20
Run: 10, Epoch: 90, Loss: 0.1781, Train: 71.96%, Valid: 62.19%, Test: 41.83%
Hits@30
Run: 10, Epoch: 90, Loss: 0.1781, Train: 74.0

## GraphSAGE with Neural Link Predictor

In [13]:
class GCN(torch.nn.Module):
    """Node encoder made of stacked ``GCNConv`` layers.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Width of every intermediate representation.
        out_channels: Width of the node embeddings returned by ``forward``.
        num_layers: Requested depth. An input and an output layer are always
            created; ``num_layers - 2`` hidden-to-hidden layers sit between them.
        dropout: Dropout probability used after every layer but the last.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        # Neighbouring entries are the (input, output) widths of one layer.
        widths = ([in_channels, hidden_channels]
                  + [hidden_channels] * (num_layers - 2)
                  + [out_channels])
        # cached=True is forwarded to every GCNConv; per the PyG docs this
        # reuses the normalised adjacency and so presumes a fixed graph.
        self.convs = torch.nn.ModuleList([
            GCNConv(fan_in, fan_out, cached=True)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ])

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the weights of every convolution layer."""
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, adj_t):
        """Encode node features ``x`` over the graph ``adj_t``.

        Every layer except the last is followed by ReLU and dropout; the last
        layer's raw output is returned.
        """
        *hidden_convs, final_conv = self.convs
        for conv in hidden_convs:
            x = F.dropout(F.relu(conv(x, adj_t)), p=self.dropout,
                          training=self.training)
        return final_conv(x, adj_t)


class SAGE(torch.nn.Module):
    """Node encoder made of stacked ``SAGEConv`` (GraphSAGE) layers.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Width of every intermediate representation.
        out_channels: Width of the node embeddings returned by ``forward``.
        num_layers: Requested depth. An input and an output layer are always
            created; ``num_layers - 2`` hidden-to-hidden layers sit between them.
        dropout: Dropout probability used after every layer but the last.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        # Neighbouring entries are the (input, output) widths of one layer.
        widths = ([in_channels, hidden_channels]
                  + [hidden_channels] * (num_layers - 2)
                  + [out_channels])
        self.convs = torch.nn.ModuleList([
            SAGEConv(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ])

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the weights of every convolution layer."""
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, adj_t):
        """Encode node features ``x`` over the graph ``adj_t``.

        Every layer except the last is followed by ReLU and dropout; the last
        layer's raw output is returned.
        """
        *hidden_convs, final_conv = self.convs
        for conv in hidden_convs:
            x = F.dropout(F.relu(conv(x, adj_t)), p=self.dropout,
                          training=self.training)
        return final_conv(x, adj_t)


class NeuralLinkPredictor(torch.nn.Module):
    """MLP that scores a candidate edge from its two endpoint embeddings.

    The endpoint embeddings are combined with an element-wise product and fed
    through ``Linear`` layers (ReLU + dropout between them); the final output
    is squashed with a sigmoid.

    Args:
        in_channels: Width of each endpoint embedding.
        hidden_channels: Width of the intermediate layers.
        out_channels: Width of the output (1 for a single link score).
        num_layers: Requested depth. An input and an output layer are always
            created; ``num_layers - 2`` hidden layers sit between them.
        dropout: Dropout probability used after every layer but the last.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the weights of every linear layer."""
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x_i, x_j):
        """Return sigmoid scores for the pairs ``(x_i[k], x_j[k])``.

        With ``out_channels == 1`` the trailing singleton dimension is removed,
        so a batch of ``N`` pairs yields a tensor of shape ``(N,)``.
        """
        x = x_i * x_j
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        # Squeeze only the channel dimension. A bare squeeze() would also drop
        # the batch dimension for a single-edge batch, producing a 0-dim tensor
        # that cannot be concatenated or queried with .size(0) downstream.
        return torch.sigmoid(x).squeeze(-1)


def train(model, predictor, x, adj_t, split_edge, optimizer, batch_size):
    """Run one training epoch over the positive training edges.

    For every mini-batch the node embeddings are recomputed with ``model``;
    the positive edges and a requested equal number of sampled negative edges
    are scored by ``predictor``; the sum of both log-losses is back-propagated
    and the optimizer is stepped.

    Args:
        model: Encoder called as ``model(x, adj_t)``.
        predictor: Edge scorer called with the two endpoint embeddings.
        x: Input node features; its gradient is clipped as well.
        adj_t: Sparse adjacency object exposing ``coo()``.
        split_edge: Mapping whose ``['train']['edge']`` entry holds the
            positive training edges.
        optimizer: Optimizer stepped once per mini-batch.
        batch_size: Number of positive edges per mini-batch.

    Returns:
        The loss averaged over all positive examples seen in the epoch.
    """
    # Edge list of the existing graph, handed to the negative sampler.
    row, col, _ = adj_t.coo()
    known_edges = torch.stack([col, row], dim=0)

    model.train()
    predictor.train()

    train_edges = split_edge['train']['edge'].to(x.device)
    batches = DataLoader(range(train_edges.size(0)), batch_size, shuffle=True)

    loss_sum = 0
    seen = 0
    for batch_idx in batches:
        optimizer.zero_grad()

        node_emb = model(x, adj_t)

        # Positive edges of this batch, laid out as (2, batch).
        pos_pairs = train_edges[batch_idx].t()
        pos_out = predictor(node_emb[pos_pairs[0]], node_emb[pos_pairs[1]])
        pos_loss = -torch.log(pos_out + 1e-15).mean()

        neg_pairs = negative_sampling(known_edges, num_nodes=x.size(0),
                                      num_neg_samples=batch_idx.size(0),
                                      method='dense')
        neg_out = predictor(node_emb[neg_pairs[0]], node_emb[neg_pairs[1]])
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()

        loss = pos_loss + neg_loss
        loss.backward()

        # Clip embeddings, encoder and predictor separately, each to norm 1.0.
        for clipped in (x, model.parameters(), predictor.parameters()):
            torch.nn.utils.clip_grad_norm_(clipped, 1.0)

        optimizer.step()

        batch_examples = pos_out.size(0)
        loss_sum += loss.item() * batch_examples
        seen += batch_examples

    return loss_sum / seen


@torch.no_grad()
def test(model, predictor, x, adj_t, split_edge, evaluator, batch_size):
    model.eval()
    predictor.eval()

    h = model(x, adj_t)

    pos_train_edge = split_edge['eval_train']['edge'].to(x.device)
    pos_valid_edge = split_edge['valid']['edge'].to(x.device)
    neg_valid_edge = split_edge['valid']['edge_neg'].to(x.device)
    pos_test_edge = split_edge['test']['edge'].to(x.device)
    neg_test_edge = split_edge['test']['edge_neg'].to(x.device)

    pos_train_preds = []
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size):
        edge = pos_train_edge[perm].t()
        pos_train_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_train_pred = torch.cat(pos_train_preds, dim=0)

    pos_valid_preds = []
    for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size):
        edge = pos_valid_edge[perm].t()
        pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

    neg_valid_preds = []
    for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size):
        edge = neg_valid_edge[perm].t()
        neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

    pos_test_preds = []
    for perm in DataLoader(range(pos_test_edge.size(0)), batch_size):
        edge = pos_test_edge[perm].t()
        pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_test_pred = torch.cat(pos_test_preds, dim=0)

    neg_test_preds = []
    for perm in DataLoader(range(neg_test_edge.size(0)), batch_size):
        edge = neg_test_edge[perm].t()
        neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_test_pred = torch.cat(neg_test_preds, dim=0)

    results = {}
    for K in [10, 20, 30]:
        evaluator.K = K
        train_hits = evaluator.eval({
            'y_pred_pos': pos_train_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits)

    return results


def _str2bool(value):
    """Convert a command-line flag value such as ``"false"`` into a ``bool``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value, got {value!r}')


def main():
    """Train and evaluate a GNN encoder plus link predictor on ogbl-ddi.

    Parses the hyper-parameters, loads the dataset as a sparse adjacency,
    builds a SAGE or GCN encoder with learnable node embeddings, then for
    each run trains for ``--epochs`` epochs, evaluating every ``--eval_steps``
    epochs and printing Hits@10/20/30 for train, validation and test edges.
    """
    parser = argparse.ArgumentParser(description='OGBL-DDI (GNN)')
    # Jupyter launches the kernel with "-f <connection file>"; accept it so
    # parse_args() does not fail inside a notebook.
    parser.add_argument('-f')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    # Without an explicit type every command-line value (even "False") is a
    # non-empty, truthy string, so GCN could never be selected.
    parser.add_argument('--use_sage', type=_str2bool, default=True)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-ddi',
                                     transform=T.ToSparseTensor())
    data = dataset[0]
    adj_t = data.adj_t.to(device)

    split_edge = dataset.get_edge_split()

    # We randomly pick some training samples that we want to evaluate on
    # (as many as there are validation edges); seeded so the subset is fixed.
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['edge'].size(0))
    idx = idx[:split_edge['valid']['edge'].size(0)]
    split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]}

    if args.use_sage:
        model = SAGE(args.hidden_channels, args.hidden_channels,
                     args.hidden_channels, args.num_layers,
                     args.dropout).to(device)
    else:
        model = GCN(args.hidden_channels, args.hidden_channels,
                    args.hidden_channels, args.num_layers,
                    args.dropout).to(device)

    # The node features are a learnable embedding table, one row per node.
    emb = torch.nn.Embedding(data.adj_t.size(0),
                             args.hidden_channels).to(device)
    predictor = NeuralLinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-ddi')
    # One result logger per reported metric.
    loggers = {
        'Hits@10': Logger_Models_Model(args.runs, args),
        'Hits@20': Logger_Models_Model(args.runs, args),
        'Hits@30': Logger_Models_Model(args.runs, args),
    }

    for run in range(args.runs):
        # Start every run from freshly initialised weights and optimizer state.
        torch.nn.init.xavier_uniform_(emb.weight)
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(emb.parameters()) +
            list(predictor.parameters()), lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, emb.weight, adj_t, split_edge,
                         optimizer, args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(model, predictor, emb.weight, adj_t, split_edge,
                               evaluator, args.batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_hits:.2f}%, '
                              f'Valid: {100 * valid_hits:.2f}%, '
                              f'Test: {100 * test_hits:.2f}%')
                    print('---')

        # Per-run summary for each metric.
        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    # Summary across all runs.
    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()


# Entry point: start the experiment when this cell/module runs as __main__.
if __name__ == "__main__":
    main()


Namespace(f='C:\\Users\\b04753yr\\AppData\\Roaming\\jupyter\\runtime\\kernel-30e704de-a8cc-40a4-8675-f15ec8e16fa1.json', device=0, log_steps=1, use_sage=True, num_layers=2, hidden_channels=256, dropout=0.5, batch_size=65536, lr=0.005, epochs=200, eval_steps=5, runs=10)
Hits@10
Run: 01, Epoch: 05, Loss: 0.5677, Train: 12.61%, Valid: 11.24%, Test: 4.35%
Hits@20
Run: 01, Epoch: 05, Loss: 0.5677, Train: 19.14%, Valid: 17.31%, Test: 8.01%
Hits@30
Run: 01, Epoch: 05, Loss: 0.5677, Train: 21.72%, Valid: 19.72%, Test: 12.40%
---
Hits@10
Run: 01, Epoch: 10, Loss: 0.4075, Train: 25.33%, Valid: 22.36%, Test: 15.21%
Hits@20
Run: 01, Epoch: 10, Loss: 0.4075, Train: 29.39%, Valid: 26.00%, Test: 19.75%
Hits@30
Run: 01, Epoch: 10, Loss: 0.4075, Train: 34.37%, Valid: 30.75%, Test: 23.68%
---
Hits@10
Run: 01, Epoch: 15, Loss: 0.3307, Train: 23.80%, Valid: 20.49%, Test: 4.21%
Hits@20
Run: 01, Epoch: 15, Loss: 0.3307, Train: 31.31%, Valid: 27.14%, Test: 9.60%
Hits@30
Run: 01, Epoch: 15, Loss: 0.3307, Trai

Hits@10
Run: 01, Epoch: 160, Loss: 0.1656, Train: 66.97%, Valid: 56.37%, Test: 29.31%
Hits@20
Run: 01, Epoch: 160, Loss: 0.1656, Train: 72.55%, Valid: 62.01%, Test: 54.29%
Hits@30
Run: 01, Epoch: 160, Loss: 0.1656, Train: 76.08%, Valid: 65.60%, Test: 66.65%
---
Hits@10
Run: 01, Epoch: 165, Loss: 0.1643, Train: 70.27%, Valid: 59.67%, Test: 28.09%
Hits@20
Run: 01, Epoch: 165, Loss: 0.1643, Train: 74.45%, Valid: 63.88%, Test: 50.63%
Hits@30
Run: 01, Epoch: 165, Loss: 0.1643, Train: 77.84%, Valid: 67.47%, Test: 63.10%
---
Hits@10
Run: 01, Epoch: 170, Loss: 0.1643, Train: 67.85%, Valid: 57.23%, Test: 40.50%
Hits@20
Run: 01, Epoch: 170, Loss: 0.1643, Train: 74.45%, Valid: 63.93%, Test: 65.19%
Hits@30
Run: 01, Epoch: 170, Loss: 0.1643, Train: 77.64%, Valid: 67.30%, Test: 75.13%
---
Hits@10
Run: 01, Epoch: 175, Loss: 0.1639, Train: 59.60%, Valid: 48.86%, Test: 39.41%
Hits@20
Run: 01, Epoch: 175, Loss: 0.1639, Train: 68.46%, Valid: 57.47%, Test: 52.13%
Hits@30
Run: 01, Epoch: 175, Loss: 0.1639,

Hits@10
Run: 02, Epoch: 115, Loss: 0.1733, Train: 69.93%, Valid: 60.07%, Test: 49.35%
Hits@20
Run: 02, Epoch: 115, Loss: 0.1733, Train: 74.29%, Valid: 64.49%, Test: 64.70%
Hits@30
Run: 02, Epoch: 115, Loss: 0.1733, Train: 76.73%, Valid: 66.95%, Test: 72.69%
---
Hits@10
Run: 02, Epoch: 120, Loss: 0.1707, Train: 68.96%, Valid: 58.72%, Test: 45.80%
Hits@20
Run: 02, Epoch: 120, Loss: 0.1707, Train: 74.97%, Valid: 64.95%, Test: 57.34%
Hits@30
Run: 02, Epoch: 120, Loss: 0.1707, Train: 76.93%, Valid: 66.93%, Test: 69.29%
---
Hits@10
Run: 02, Epoch: 125, Loss: 0.1705, Train: 68.50%, Valid: 58.39%, Test: 31.91%
Hits@20
Run: 02, Epoch: 125, Loss: 0.1705, Train: 74.91%, Valid: 64.84%, Test: 60.31%
Hits@30
Run: 02, Epoch: 125, Loss: 0.1705, Train: 76.41%, Valid: 66.36%, Test: 65.03%
---
Hits@10
Run: 02, Epoch: 130, Loss: 0.1695, Train: 70.77%, Valid: 60.75%, Test: 47.35%
Hits@20
Run: 02, Epoch: 130, Loss: 0.1695, Train: 76.42%, Valid: 66.42%, Test: 62.23%
Hits@30
Run: 02, Epoch: 130, Loss: 0.1695,

Hits@10
Run: 03, Epoch: 70, Loss: 0.1884, Train: 57.43%, Valid: 48.34%, Test: 42.69%
Hits@20
Run: 03, Epoch: 70, Loss: 0.1884, Train: 68.50%, Valid: 59.11%, Test: 57.95%
Hits@30
Run: 03, Epoch: 70, Loss: 0.1884, Train: 71.65%, Valid: 62.18%, Test: 66.87%
---
Hits@10
Run: 03, Epoch: 75, Loss: 0.1849, Train: 61.30%, Valid: 52.09%, Test: 36.94%
Hits@20
Run: 03, Epoch: 75, Loss: 0.1849, Train: 70.77%, Valid: 61.30%, Test: 53.38%
Hits@30
Run: 03, Epoch: 75, Loss: 0.1849, Train: 72.48%, Valid: 62.93%, Test: 59.92%
---
Hits@10
Run: 03, Epoch: 80, Loss: 0.1835, Train: 58.58%, Valid: 48.88%, Test: 15.34%
Hits@20
Run: 03, Epoch: 80, Loss: 0.1835, Train: 69.06%, Valid: 59.14%, Test: 32.21%
Hits@30
Run: 03, Epoch: 80, Loss: 0.1835, Train: 71.61%, Valid: 61.79%, Test: 40.26%
---
Hits@10
Run: 03, Epoch: 85, Loss: 0.1826, Train: 63.93%, Valid: 54.16%, Test: 44.75%
Hits@20
Run: 03, Epoch: 85, Loss: 0.1826, Train: 72.55%, Valid: 63.02%, Test: 59.03%
Hits@30
Run: 03, Epoch: 85, Loss: 0.1826, Train: 74.5

Hits@10
Run: 04, Epoch: 25, Loss: 0.2561, Train: 45.34%, Valid: 39.01%, Test: 14.13%
Hits@20
Run: 04, Epoch: 25, Loss: 0.2561, Train: 51.36%, Valid: 44.53%, Test: 19.78%
Hits@30
Run: 04, Epoch: 25, Loss: 0.2561, Train: 56.07%, Valid: 48.95%, Test: 28.28%
---
Hits@10
Run: 04, Epoch: 30, Loss: 0.2389, Train: 51.04%, Valid: 44.03%, Test: 13.08%
Hits@20
Run: 04, Epoch: 30, Loss: 0.2389, Train: 57.86%, Valid: 50.34%, Test: 19.50%
Hits@30
Run: 04, Epoch: 30, Loss: 0.2389, Train: 61.43%, Valid: 53.81%, Test: 27.05%
---
Hits@10
Run: 04, Epoch: 35, Loss: 0.2263, Train: 59.19%, Valid: 51.30%, Test: 20.60%
Hits@20
Run: 04, Epoch: 35, Loss: 0.2263, Train: 62.78%, Valid: 54.75%, Test: 39.40%
Hits@30
Run: 04, Epoch: 35, Loss: 0.2263, Train: 65.04%, Valid: 56.83%, Test: 46.55%
---
Hits@10
Run: 04, Epoch: 40, Loss: 0.2170, Train: 53.06%, Valid: 45.61%, Test: 21.98%
Hits@20
Run: 04, Epoch: 40, Loss: 0.2170, Train: 62.88%, Valid: 54.55%, Test: 31.26%
Hits@30
Run: 04, Epoch: 40, Loss: 0.2170, Train: 66.2

Hits@10
Run: 04, Epoch: 185, Loss: 0.1615, Train: 64.55%, Valid: 53.89%, Test: 33.44%
Hits@20
Run: 04, Epoch: 185, Loss: 0.1615, Train: 75.86%, Valid: 65.45%, Test: 58.97%
Hits@30
Run: 04, Epoch: 185, Loss: 0.1615, Train: 78.44%, Valid: 68.09%, Test: 72.76%
---
Hits@10
Run: 04, Epoch: 190, Loss: 0.1608, Train: 70.98%, Valid: 60.34%, Test: 40.98%
Hits@20
Run: 04, Epoch: 190, Loss: 0.1608, Train: 75.51%, Valid: 65.03%, Test: 63.31%
Hits@30
Run: 04, Epoch: 190, Loss: 0.1608, Train: 78.55%, Valid: 68.18%, Test: 77.47%
---
Hits@10
Run: 04, Epoch: 195, Loss: 0.1592, Train: 71.47%, Valid: 60.71%, Test: 45.47%
Hits@20
Run: 04, Epoch: 195, Loss: 0.1592, Train: 76.90%, Valid: 66.43%, Test: 61.96%
Hits@30
Run: 04, Epoch: 195, Loss: 0.1592, Train: 79.45%, Valid: 69.16%, Test: 69.03%
---
Hits@10
Run: 04, Epoch: 200, Loss: 0.1604, Train: 62.57%, Valid: 51.90%, Test: 35.04%
Hits@20
Run: 04, Epoch: 200, Loss: 0.1604, Train: 76.28%, Valid: 65.63%, Test: 56.40%
Hits@30
Run: 04, Epoch: 200, Loss: 0.1604,

Hits@10
Run: 05, Epoch: 140, Loss: 0.1664, Train: 66.30%, Valid: 55.94%, Test: 7.40%
Hits@20
Run: 05, Epoch: 140, Loss: 0.1664, Train: 74.96%, Valid: 64.67%, Test: 30.13%
Hits@30
Run: 05, Epoch: 140, Loss: 0.1664, Train: 77.44%, Valid: 67.23%, Test: 55.53%
---
Hits@10
Run: 05, Epoch: 145, Loss: 0.1663, Train: 70.31%, Valid: 60.08%, Test: 42.48%
Hits@20
Run: 05, Epoch: 145, Loss: 0.1663, Train: 76.64%, Valid: 66.46%, Test: 71.54%
Hits@30
Run: 05, Epoch: 145, Loss: 0.1663, Train: 77.94%, Valid: 67.80%, Test: 76.32%
---
Hits@10
Run: 05, Epoch: 150, Loss: 0.1639, Train: 55.59%, Valid: 45.36%, Test: 35.87%
Hits@20
Run: 05, Epoch: 150, Loss: 0.1639, Train: 71.58%, Valid: 60.93%, Test: 53.15%
Hits@30
Run: 05, Epoch: 150, Loss: 0.1639, Train: 77.01%, Valid: 66.62%, Test: 61.61%
---
Hits@10
Run: 05, Epoch: 155, Loss: 0.1660, Train: 69.59%, Valid: 59.33%, Test: 41.43%
Hits@20
Run: 05, Epoch: 155, Loss: 0.1660, Train: 74.90%, Valid: 64.67%, Test: 66.35%
Hits@30
Run: 05, Epoch: 155, Loss: 0.1660, 

Hits@10
Run: 06, Epoch: 95, Loss: 0.1787, Train: 67.61%, Valid: 57.57%, Test: 35.31%
Hits@20
Run: 06, Epoch: 95, Loss: 0.1787, Train: 72.41%, Valid: 62.45%, Test: 51.61%
Hits@30
Run: 06, Epoch: 95, Loss: 0.1787, Train: 75.54%, Valid: 65.79%, Test: 60.52%
---
Hits@10
Run: 06, Epoch: 100, Loss: 0.1778, Train: 64.47%, Valid: 54.82%, Test: 30.68%
Hits@20
Run: 06, Epoch: 100, Loss: 0.1778, Train: 74.88%, Valid: 65.19%, Test: 45.57%
Hits@30
Run: 06, Epoch: 100, Loss: 0.1778, Train: 76.56%, Valid: 66.88%, Test: 60.59%
---
Hits@10
Run: 06, Epoch: 105, Loss: 0.1751, Train: 66.57%, Valid: 56.71%, Test: 23.83%
Hits@20
Run: 06, Epoch: 105, Loss: 0.1751, Train: 75.06%, Valid: 65.18%, Test: 55.74%
Hits@30
Run: 06, Epoch: 105, Loss: 0.1751, Train: 76.52%, Valid: 66.60%, Test: 66.95%
---
Hits@10
Run: 06, Epoch: 110, Loss: 0.1728, Train: 56.27%, Valid: 46.41%, Test: 31.58%
Hits@20
Run: 06, Epoch: 110, Loss: 0.1728, Train: 66.41%, Valid: 56.14%, Test: 42.41%
Hits@30
Run: 06, Epoch: 110, Loss: 0.1728, Tr

Hits@10
Run: 07, Epoch: 50, Loss: 0.2032, Train: 58.81%, Valid: 50.32%, Test: 28.70%
Hits@20
Run: 07, Epoch: 50, Loss: 0.2032, Train: 67.00%, Valid: 58.16%, Test: 52.21%
Hits@30
Run: 07, Epoch: 50, Loss: 0.2032, Train: 69.70%, Valid: 60.70%, Test: 58.19%
---
Hits@10
Run: 07, Epoch: 55, Loss: 0.1960, Train: 61.24%, Valid: 52.37%, Test: 35.98%
Hits@20
Run: 07, Epoch: 55, Loss: 0.1960, Train: 67.97%, Valid: 58.93%, Test: 50.43%
Hits@30
Run: 07, Epoch: 55, Loss: 0.1960, Train: 70.36%, Valid: 61.22%, Test: 62.71%
---
Hits@10
Run: 07, Epoch: 60, Loss: 0.1938, Train: 66.72%, Valid: 57.47%, Test: 27.87%
Hits@20
Run: 07, Epoch: 60, Loss: 0.1938, Train: 68.95%, Valid: 59.57%, Test: 43.71%
Hits@30
Run: 07, Epoch: 60, Loss: 0.1938, Train: 72.08%, Valid: 62.62%, Test: 59.08%
---
Hits@10
Run: 07, Epoch: 65, Loss: 0.1879, Train: 66.05%, Valid: 56.78%, Test: 17.87%
Hits@20
Run: 07, Epoch: 65, Loss: 0.1879, Train: 71.03%, Valid: 61.67%, Test: 37.11%
Hits@30
Run: 07, Epoch: 65, Loss: 0.1879, Train: 72.7

Hits@10
Run: 08, Epoch: 05, Loss: 0.5707, Train: 11.54%, Valid: 10.39%, Test: 4.85%
Hits@20
Run: 08, Epoch: 05, Loss: 0.5707, Train: 15.36%, Valid: 13.96%, Test: 8.35%
Hits@30
Run: 08, Epoch: 05, Loss: 0.5707, Train: 17.56%, Valid: 15.98%, Test: 11.31%
---
Hits@10
Run: 08, Epoch: 10, Loss: 0.4109, Train: 21.66%, Valid: 18.84%, Test: 8.12%
Hits@20
Run: 08, Epoch: 10, Loss: 0.4109, Train: 25.48%, Valid: 22.39%, Test: 11.31%
Hits@30
Run: 08, Epoch: 10, Loss: 0.4109, Train: 29.70%, Valid: 26.34%, Test: 16.53%
---
Hits@10
Run: 08, Epoch: 15, Loss: 0.3314, Train: 31.36%, Valid: 27.18%, Test: 11.48%
Hits@20
Run: 08, Epoch: 15, Loss: 0.3314, Train: 37.59%, Valid: 32.91%, Test: 18.98%
Hits@30
Run: 08, Epoch: 15, Loss: 0.3314, Train: 43.19%, Valid: 38.22%, Test: 23.51%
---
Hits@10
Run: 08, Epoch: 20, Loss: 0.2923, Train: 46.62%, Valid: 40.81%, Test: 10.44%
Hits@20
Run: 08, Epoch: 20, Loss: 0.2923, Train: 50.56%, Valid: 44.50%, Test: 18.86%
Hits@30
Run: 08, Epoch: 20, Loss: 0.2923, Train: 54.47%,

Hits@10
Run: 08, Epoch: 165, Loss: 0.1632, Train: 56.39%, Valid: 45.71%, Test: 23.38%
Hits@20
Run: 08, Epoch: 165, Loss: 0.1632, Train: 72.38%, Valid: 61.48%, Test: 37.51%
Hits@30
Run: 08, Epoch: 165, Loss: 0.1632, Train: 76.30%, Valid: 65.70%, Test: 48.97%
---
Hits@10
Run: 08, Epoch: 170, Loss: 0.1637, Train: 72.20%, Valid: 61.83%, Test: 24.63%
Hits@20
Run: 08, Epoch: 170, Loss: 0.1637, Train: 77.48%, Valid: 67.24%, Test: 55.54%
Hits@30
Run: 08, Epoch: 170, Loss: 0.1637, Train: 79.05%, Valid: 68.83%, Test: 70.61%
---
Hits@10
Run: 08, Epoch: 175, Loss: 0.1641, Train: 65.01%, Valid: 54.94%, Test: 36.80%
Hits@20
Run: 08, Epoch: 175, Loss: 0.1641, Train: 75.86%, Valid: 65.62%, Test: 53.22%
Hits@30
Run: 08, Epoch: 175, Loss: 0.1641, Train: 78.59%, Valid: 68.38%, Test: 65.19%
---
Hits@10
Run: 08, Epoch: 180, Loss: 0.1627, Train: 63.68%, Valid: 52.25%, Test: 31.46%
Hits@20
Run: 08, Epoch: 180, Loss: 0.1627, Train: 74.49%, Valid: 63.66%, Test: 48.35%
Hits@30
Run: 08, Epoch: 180, Loss: 0.1627,

Hits@10
Run: 09, Epoch: 120, Loss: 0.1713, Train: 66.18%, Valid: 56.16%, Test: 37.63%
Hits@20
Run: 09, Epoch: 120, Loss: 0.1713, Train: 73.33%, Valid: 63.30%, Test: 51.34%
Hits@30
Run: 09, Epoch: 120, Loss: 0.1713, Train: 75.91%, Valid: 65.86%, Test: 63.41%
---
Hits@10
Run: 09, Epoch: 125, Loss: 0.1707, Train: 55.29%, Valid: 45.30%, Test: 47.57%
Hits@20
Run: 09, Epoch: 125, Loss: 0.1707, Train: 73.62%, Valid: 63.30%, Test: 55.19%
Hits@30
Run: 09, Epoch: 125, Loss: 0.1707, Train: 74.89%, Valid: 64.58%, Test: 65.04%
---
Hits@10
Run: 09, Epoch: 130, Loss: 0.1694, Train: 52.97%, Valid: 43.38%, Test: 34.21%
Hits@20
Run: 09, Epoch: 130, Loss: 0.1694, Train: 71.66%, Valid: 61.45%, Test: 58.05%
Hits@30
Run: 09, Epoch: 130, Loss: 0.1694, Train: 75.47%, Valid: 65.30%, Test: 64.23%
---
Hits@10
Run: 09, Epoch: 135, Loss: 0.1698, Train: 50.76%, Valid: 41.32%, Test: 21.19%
Hits@20
Run: 09, Epoch: 135, Loss: 0.1698, Train: 71.81%, Valid: 61.49%, Test: 36.36%
Hits@30
Run: 09, Epoch: 135, Loss: 0.1698,

Hits@10
Run: 10, Epoch: 75, Loss: 0.1843, Train: 63.78%, Valid: 54.37%, Test: 40.94%
Hits@20
Run: 10, Epoch: 75, Loss: 0.1843, Train: 71.97%, Valid: 62.48%, Test: 52.96%
Hits@30
Run: 10, Epoch: 75, Loss: 0.1843, Train: 73.83%, Valid: 64.29%, Test: 63.30%
---
Hits@10
Run: 10, Epoch: 80, Loss: 0.1832, Train: 69.27%, Valid: 59.87%, Test: 27.52%
Hits@20
Run: 10, Epoch: 80, Loss: 0.1832, Train: 72.23%, Valid: 62.71%, Test: 52.17%
Hits@30
Run: 10, Epoch: 80, Loss: 0.1832, Train: 75.87%, Valid: 66.40%, Test: 65.36%
---
Hits@10
Run: 10, Epoch: 85, Loss: 0.1794, Train: 68.23%, Valid: 58.77%, Test: 31.73%
Hits@20
Run: 10, Epoch: 85, Loss: 0.1794, Train: 73.45%, Valid: 63.99%, Test: 47.94%
Hits@30
Run: 10, Epoch: 85, Loss: 0.1794, Train: 76.10%, Valid: 66.60%, Test: 66.14%
---
Hits@10
Run: 10, Epoch: 90, Loss: 0.1760, Train: 58.45%, Valid: 48.69%, Test: 45.39%
Hits@20
Run: 10, Epoch: 90, Loss: 0.1760, Train: 72.73%, Valid: 62.97%, Test: 65.09%
Hits@30
Run: 10, Epoch: 90, Loss: 0.1760, Train: 76.0

## EDA on Graph

In [None]:
import pandas as pd
import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F
import networkx as nx
import random
import torch_geometric.transforms as T

from torch import Tensor
from torch.utils.data import DataLoader
from torch_geometric.utils import negative_sampling, convert, to_dense_adj
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.nn.conv import MessagePassing
from ogb.linkproppred import PygLinkPropPredDataset, Evaluator
from sklearn.preprocessing import MinMaxScaler


In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import sweetviz as sv
import spacy
import warnings
warnings.filterwarnings('ignore')
from termcolor import colored
import re
import random
import numpy as np
import nltk
from nltk.stem import WordNetLemmatizer
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords
import en_core_web_sm
import seaborn as sns
import xlsxwriter
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import LatentDirichletAllocation
import en_core_web_sm
nlp = en_core_web_sm.load()
import gensim
from gensim import corpora
import gensim, spacy, logging, warnings
import gensim.corpora as corpora
#from gensim.utils import lemmatize, simple_preprocess
from gensim.models import CoherenceModel
import pyLDAvis

In [None]:
node_description = pd.read_csv('dataset/ogbl_ddi/mapping/nodeidx2drugid.csv')
node_info_df = node_description.copy()

In [None]:
drug_description = pd.read_csv('dataset/ogbl_ddi/mapping/ddi_description.csv')
drug_interaction_info_df = drug_description.copy()

In [None]:
node_list_data = node_description['node idx'].drop_duplicates().to_list()

In [None]:
from ogb.linkproppred import PygLinkPropPredDataset

# Load the ogbl-ddi link-prediction dataset and keep its first entry as `data`.
dataset = PygLinkPropPredDataset(name='ogbl-ddi')
data = dataset[0]

# Pull the train / valid / test entries out of the dataset's edge split.
split_edge = dataset.get_edge_split()
train_edge = split_edge["train"]
valid_edge = split_edge["valid"]
test_edge = split_edge["test"]

### Drug Names

In [None]:
# Build one de-duplicated (drug id, drug name) table from both sides of the
# described interactions: first-drug columns, second-drug columns, then the union.
drug_names_1 = (
    drug_description[['first drug id', 'first drug name']]
    .drop_duplicates()
    .rename(columns={'first drug id': 'drug id', 'first drug name': 'drug name'})
)
drug_names_2 = (
    drug_description[['second drug id', 'second drug name']]
    .drop_duplicates()
    .rename(columns={'second drug id': 'drug id', 'second drug name': 'drug name'})
)
total_drug_names = pd.concat([drug_names_1, drug_names_2], axis=0).drop_duplicates()

In [None]:
# Attach a drug name to every node (left join keeps nodes without a known name),
# then relabel the columns as the "first" drug of a pair.
node_info_df_with_names = (
    node_info_df
    .merge(total_drug_names, on='drug id', how='left')
    .rename(columns={'node idx': 'first node',
                     'drug id': 'first drug id',
                     'drug name': 'first drug name'})
)

# Convert the PyG graph into an undirected NetworkX graph.
# NOTE(review): this reloads the dataset although an earlier cell already
# binds `dataset` and `data`; kept so the cell runs on its own.
dataset = PygLinkPropPredDataset(name='ogbl-ddi')
data = dataset[0]
G = convert.to_networkx(data, to_undirected=True)

### Node Neighbours

In [None]:
# Adjacency lookup: node -> list of the nodes adjacent to it in G.
node_list = list(G.nodes)
nodes_neighbours = {node: list(G.neighbors(node)) for node in node_list}

In [None]:
# Degree lookup: node -> its degree in G.
node_degree = {node: G.degree[node] for node in node_list}

In [None]:
# Two-column table: node index and the degree looked up for it in `node_degree`.
node_degree_data = pd.DataFrame({'Nodes': node_list_data})
node_degree_data['Degree'] = node_degree_data['Nodes'].map(node_degree)

In [None]:
node_degree_data.describe().to_csv('Degree_Stats.txt', sep = '\t')

In [None]:
node_degree_data.to_csv('Node_Degree_Data.txt', index = False, sep = '\t')

In [None]:
node_info_df_with_names

### Mapping node to neighbours

In [None]:
# Expand every node into one row per neighbour: first attach the neighbour
# list, then let `explode` emit one row per list element (the other columns are
# repeated). This replaces the manual join/stack over a dense padded frame.
# NOTE: the cell rebinds `node_info_df_with_names`, so re-running it on the
# already-expanded frame would expand it again; re-run the merge cell first.
node_info_df_with_names['second node'] = node_info_df_with_names['first node'].map(nodes_neighbours)
node_info_df_with_names = (
    node_info_df_with_names
    .explode('second node')
    .reset_index(drop=True)
)
# `explode` leaves an object-dtype column; cast the neighbour ids to integers.
node_info_df_with_names['second node'] = node_info_df_with_names['second node'].astype(int)

In [None]:
# Resolve each neighbour ("second node") to its drug id via the node mapping ...
adding_second_drug_id = (
    node_info_df_with_names
    .merge(node_description, left_on='second node', right_on='node idx')
    .drop(['node idx'], axis=1)
)
# ... then to its drug name, labelling both columns as the "second" drug.
adding_second_drug_name = (
    adding_second_drug_id
    .merge(total_drug_names, on='drug id', how='left')
    .rename(columns={'drug id': 'second drug id', 'drug name': 'second drug name'})
)

In [None]:
# Left-join the interaction descriptions onto every (first drug, second drug)
# row; rows with no matching description keep a missing value in the
# renamed 'polypharmacy side effect' column.
complete_node_interaction_df = (
    adding_second_drug_name
    .merge(
        drug_interaction_info_df,
        on=['first drug id', 'first drug name', 'second drug id', 'second drug name'],
        how='left',
    )
    .rename(columns={'description': 'polypharmacy side effect'})
)

In [None]:
empty_data = complete_node_interaction_df[complete_node_interaction_df['polypharmacy side effect'].isna()]

In [None]:
def assign_reverse_combos(row):
    """Copy the side-effect text of the reversed drug pair onto this pair.

    For the pair (first drug id, second drug id) in ``row``, look up the rows
    of the module-level ``complete_node_interaction_df`` that describe the
    reversed pair and write their first non-missing 'polypharmacy side effect'
    onto every row of the forward pair. The frame is modified in place.

    Args:
        row: a row of the interaction table; must carry the
            'first drug id' and 'second drug id' labels.

    Returns:
        None. Nothing is written when the reversed pair is absent or has no
        description.
    """
    # Look the ids up by label: positional access (row[1], row[4]) on a
    # label-indexed Series is deprecated and silently depends on column order.
    first_drug_id = row['first drug id']
    second_drug_id = row['second drug id']

    interactions = complete_node_interaction_df
    first_ids = interactions['first drug id']
    second_ids = interactions['second drug id']

    reverse_mask = (first_ids == second_drug_id) & (second_ids == first_drug_id)
    reverse_effects = interactions.loc[reverse_mask, 'polypharmacy side effect'].dropna()
    if reverse_effects.empty:
        # No reversed pair (or no description on it): nothing to copy.
        return

    forward_mask = (second_ids == second_drug_id) & (first_ids == first_drug_id)
    interactions.loc[forward_mask, 'polypharmacy side effect'] = reverse_effects.iloc[0]
    
empty_data.apply(assign_reverse_combos,axis=1)

In [None]:
# Rows whose side-effect text differs from that of the reversed drug pair.
# They are collected in a list and turned into `mismatch_df` once at the end:
# `DataFrame.append` returned a new frame (so the old call dropped every row)
# and no longer exists in pandas >= 2.0.
mismatch_rows = []


def check_value(row):
    """Record ``row`` when its pair and the reversed pair disagree.

    Compares the first 'polypharmacy side effect' found for
    (first drug id, second drug id) with the first one found for the reversed
    pair in the module-level ``complete_node_interaction_df``. Rows that
    disagree are appended to the module-level ``mismatch_rows`` list.

    Args:
        row: a row of the interaction table; must carry the
            'first drug id' and 'second drug id' labels.

    Returns:
        None.
    """
    # Label-based lookup instead of the deprecated positional row[1] / row[4].
    first_drug_id = row['first drug id']
    second_drug_id = row['second drug id']

    effects = complete_node_interaction_df['polypharmacy side effect']
    first_ids = complete_node_interaction_df['first drug id']
    second_ids = complete_node_interaction_df['second drug id']

    reverse_effects = effects[(first_ids == second_drug_id) & (second_ids == first_drug_id)]
    forward_effects = effects[(second_ids == second_drug_id) & (first_ids == first_drug_id)]
    if reverse_effects.empty or forward_effects.empty:
        # Without both directions there is nothing to compare.
        return

    reverse_effect = reverse_effects.iloc[0]
    forward_effect = forward_effects.iloc[0]
    # NaN != NaN is True, so two missing descriptions must be treated as equal.
    if pd.isna(reverse_effect) and pd.isna(forward_effect):
        return
    if reverse_effect != forward_effect:
        mismatch_rows.append(row)


complete_node_interaction_df.apply(check_value, axis=1)
mismatch_df = pd.DataFrame(mismatch_rows)

### Embedding Visualizations

In [None]:
# Randomly sample 2k training edges 
tsne_edges_train = split_edge['train']['edge']
tsne_edges_train = tsne_edges_train.T
train_edge_tuples = list(zip(tsne_edges_train[0], tsne_edges_train[1]))
tsne_tuples_train = random.sample(train_edge_tuples, 3500)

# Randomly sample 2k test edges 
tsne_edges_test = split_edge['test']['edge']
tsne_edges_test = tsne_edges_test.T
test_edge_tuples = list(zip(tsne_edges_test[0], tsne_edges_test[1]))
tsne_tuples_test = random.sample(test_edge_tuples, 1000)

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import random

In [None]:

def plot_tsne(emb_filepath, emb_model_name, color):
    '''
    Plot a 2D t-SNE projection of saved node embeddings and save it as a PNG.

    Three panels are drawn side by side: the projected node embeddings alone,
    the embeddings overlaid with the sampled training edges, and the
    embeddings overlaid with the sampled test edges. The edges are read from
    the module-level `tsne_tuples_train` and `tsne_tuples_test` lists, which
    must be defined before this function is called.

    Args:
        emb_filepath: path of the embedding tensor to load with `torch.load`.
        emb_model_name: model label used in the figure title and in the
            output file name `tsne_{emb_model_name}.png`.
        color: accepted for backward compatibility; currently unused because
            every panel has its own fixed colour.

    Returns:
        None. The figure is written to the current working directory.
    '''
    # Load from the path that was passed in (not a same-named global) and stay
    # on the CPU: t-SNE consumes a NumPy array, so a GPU round trip is useless.
    node_emb = torch.load(emb_filepath, map_location='cpu')
    cpu_emb = node_emb.detach().cpu().numpy()

    # Apply t-SNE transformation on node embeddings
    tsne = TSNE(n_components=2)
    node_embeddings_2d = tsne.fit_transform(cpu_emb)

    # Define subplots
    f, axs = plt.subplots(1, 3, figsize=(18, 5), dpi=80)
    alpha = 0.2

    # (axis, point colour, panel title, edges to overlay)
    panels = [
        (axs[0], '#EF8A5A', 'Node Embeddings', ()),
        (axs[1], '#F6B53D', 'Node Embeddings: Train Edges', tsne_tuples_train),
        (axs[2], '#15CAB6', 'Node Embeddings: Test Edges', tsne_tuples_test),  # teal
    ]
    for ax, panel_color, panel_title, edges in panels:
        ax.scatter(
            node_embeddings_2d[:, 0],
            node_embeddings_2d[:, 1],
            s=100,
            c=panel_color,
            alpha=alpha,
        )
        # Each edge is a pair of 0-d tensors holding the two node indices.
        for x, y in edges:
            i, j = x.item(), y.item()
            x_i, x_j = node_embeddings_2d[i, 0], node_embeddings_2d[j, 0]
            y_i, y_j = node_embeddings_2d[i, 1], node_embeddings_2d[j, 1]
            ax.plot([x_i, x_j], [y_i, y_j], 'k-', linewidth=0.10)
        ax.set_title(panel_title)

    # Name the model actually plotted instead of a hard-coded "GCN".
    sup_title = f't-SNE Visualization of Train and Test Edge Embeddings using {emb_model_name}'
    f.suptitle(sup_title)
    figure_path = f'tsne_{emb_model_name}.png'
    f.savefig(figure_path)



In [None]:
# t-SNE figure for the embeddings saved by run 0 of the 'graph_sage' model.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
model_name, run = 'graph_sage', 0
filepath = f'C:/Users/yuvas/Documents/MSc Course/Dissertation/OGB - DDI/Output/training_outputs/{model_name}_final_emb_{run}.pt'
color = '#f794e9'
plot_tsne(filepath, model_name, color)

In [None]:
# t-SNE figure for the embeddings saved by run 0 of the 'mf' model.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
model_name, run = 'mf', 0
filepath = f'C:/Users/yuvas/Documents/MSc Course/Dissertation/OGB - DDI/Output/training_outputs/{model_name}_final_emb_{run}.pt'
color = '#f794e9'
plot_tsne(filepath, model_name, color)

In [None]:
# Plot tSNE for Node2Vec embeddings
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
filepath = 'C:/Users/yuvas/Documents/MSc Course/Dissertation/OGB - DDI/Output/training_outputs/embedding.pt'
color = '#f794e9'
plot_tsne(filepath, 'Node2Vec_256_dim', color)

In [None]:
# t-SNE figure for the embeddings saved by run 0 of the 'gnn' model
# (file looked up relative to the current working directory).
model_name, run = 'gnn', 0
filepath = '{}_final_emb_{}.pt'.format(model_name, run)
color = '#f794e9'
plot_tsne(filepath, model_name, color)