In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

In [2]:
class LinkPredictor(torch.nn.Module):
    """MLP link predictor with a learnable embedding table.

    A candidate edge is scored by multiplying its two endpoint feature vectors
    element-wise and feeding the product through a stack of linear layers
    (ReLU + dropout between them). A final sigmoid maps the output into (0, 1).

    Args:
        num_embedding: number of rows in the embedding table ``self.emb``
            (presumably the number of nodes in the graph -- TODO confirm
            against the training loop).
        hidden_channels: width of every hidden linear layer. Also the default
            embedding dimension / MLP input width.
        out_channels: width of the final linear layer.
        num_layers: total number of linear layers. Must be >= 1.
        dropout: dropout probability applied after every non-final layer.
        in_channels: optional embedding dimension and MLP input width. Defaults
            to ``hidden_channels``.

    Raises:
        ValueError: if ``num_layers`` < 1.
    """

    def __init__(self, num_embedding, hidden_channels, out_channels, num_layers,
                 dropout, in_channels=None):
        super(LinkPredictor, self).__init__()

        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        if in_channels is None:
            in_channels = hidden_channels

        # Embedding rows have the same width the MLP consumes, so looked-up
        # vectors can be passed straight to forward().
        self.emb = nn.Embedding(num_embedding, in_channels)

        # Layer widths: in -> hidden (num_layers - 1 times) -> out.
        # For num_layers == 1 this yields a single in -> out layer.
        widths = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.lins = torch.nn.ModuleList(
            torch.nn.Linear(width_in, width_out)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        )

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the embedding table and every linear layer."""
        self.emb.reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x_i, x_j):
        """Score node pairs.

        Args:
            x_i, x_j: endpoint feature tensors. Assumed to share a shape whose
                last dimension is ``in_channels`` -- TODO confirm with callers.

        Returns:
            Sigmoid scores with the last dimension equal to ``out_channels``.
        """
        x = x_i * x_j  # Hadamard product of the two endpoint representations
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.sigmoid(x)

In [4]:
# Prefer the first CUDA device when one is available; otherwise run on CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load the ogbl-ddi link-prediction dataset stored under ../dataset.
dataset = PygLinkPropPredDataset(name='ogbl-ddi', root='../dataset')
# The dataset's first (and presumably only) graph object.
data = dataset[0]
# Dict of train/valid/test edge splits.
split_edge = dataset.get_edge_split()

In [17]:
# Inspect the splits: 'train' holds only 'edge'; 'valid' and 'test' also hold
# 'edge_neg' (presumably negative edges). Each value is a tensor of node-index pairs.
split_edge

{'train': {'edge': tensor([[4039, 2424],
          [4039,  225],
          [4039, 3901],
          ...,
          [ 647,  708],
          [ 708,  338],
          [ 835, 3554]])},
 'valid': {'edge': tensor([[ 722,  548],
          [ 874, 3436],
          [ 838, 1587],
          ...,
          [3661, 3125],
          [3272, 3330],
          [1330,  776]]),
  'edge_neg': tensor([[   0,   58],
          [   0,   84],
          [   0,   90],
          ...,
          [4162, 4180],
          [4168, 4260],
          [4180, 4221]])},
 'test': {'edge': tensor([[2198, 1172],
          [1205,  719],
          [1818, 2866],
          ...,
          [ 326, 1109],
          [ 911, 1250],
          [4127, 2480]]),
  'edge_neg': tensor([[   0,    2],
          [   0,   16],
          [   0,   42],
          ...,
          [4168, 4259],
          [4208, 4245],
          [4245, 4259]])}}