In [1]:
import torch_geometric
import torch
from torch_geometric.datasets import OGB_MAG
import torch_geometric.transforms as T
from torch_geometric.loader import NeighborLoader

In [2]:
print(torch.__version__)
print(torch_geometric.__version__)

1.8.2
2.0.1


In [3]:
dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=T.ToUndirected(merge=True))
data = dataset[0]
print(data)

HeteroData(
  [1mpaper[0m={
    x=[736389, 128],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  [1mauthor[0m={ x=[1134649, 128] },
  [1minstitution[0m={ x=[8740, 128] },
  [1mfield_of_study[0m={ x=[59965, 128] },
  [1m(author, affiliated_with, institution)[0m={ edge_index=[2, 1043998] },
  [1m(author, writes, paper)[0m={ edge_index=[2, 7145660] },
  [1m(paper, cites, paper)[0m={ edge_index=[2, 10792672] },
  [1m(paper, has_topic, field_of_study)[0m={ edge_index=[2, 7505078] },
  [1m(institution, rev_affiliated_with, author)[0m={ edge_index=[2, 1043998] },
  [1m(paper, rev_writes, author)[0m={ edge_index=[2, 7145660] },
  [1m(field_of_study, rev_has_topic, paper)[0m={ edge_index=[2, 7505078] }
)


In [4]:
data.metadata()

(['paper', 'author', 'institution', 'field_of_study'],
 [('author', 'affiliated_with', 'institution'),
  ('author', 'writes', 'paper'),
  ('paper', 'cites', 'paper'),
  ('paper', 'has_topic', 'field_of_study'),
  ('institution', 'rev_affiliated_with', 'author'),
  ('paper', 'rev_writes', 'author'),
  ('field_of_study', 'rev_has_topic', 'paper')])

In [5]:
train_loader = NeighborLoader(data, num_neighbors=[15, 15], batch_size=128, input_nodes=('paper', data['paper'].train_mask), shuffle=True)

In [6]:
print(train_loader)
batch = next(iter(train_loader))
print(batch)

NeighborLoader()
HeteroData(
  [1mpaper[0m={
    x=[23023, 128],
    y=[23023],
    train_mask=[23023],
    val_mask=[23023],
    test_mask=[23023],
    batch_size=128
  },
  [1mauthor[0m={ x=[4438, 128] },
  [1minstitution[0m={ x=[322, 128] },
  [1mfield_of_study[0m={ x=[2977, 128] },
  [1m(author, affiliated_with, institution)[0m={ edge_index=[2, 0] },
  [1m(author, writes, paper)[0m={ edge_index=[2, 5669] },
  [1m(paper, cites, paper)[0m={ edge_index=[2, 13465] },
  [1m(paper, has_topic, field_of_study)[0m={ edge_index=[2, 11363] },
  [1m(institution, rev_affiliated_with, author)[0m={ edge_index=[2, 735] },
  [1m(paper, rev_writes, author)[0m={ edge_index=[2, 4610] },
  [1m(field_of_study, rev_has_topic, paper)[0m={ edge_index=[2, 11943] }
)


In [21]:
from torch_geometric.nn import TransformerConv, GCNConv, GATConv, SAGEConv, to_hetero, Linear, HeteroConv
import torch.nn as nn
import torch.nn.functional as F 

In [19]:
class Net1(torch.nn.Module):
    def __init__(self, hidden_dim, num_classes, num_layers=2) -> None:
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_layers = num_layers

        self.convs = nn.ModuleList()
        self.lins = nn.ModuleList()
        self.bns = nn.ModuleList()

        for i in range(self.num_layers):
            self.convs.append(GATConv((-1, -1), hidden_dim))
            self.lins.append(Linear(-1, hidden_dim))
            self.bns.append(nn.BatchNorm1d(hidden_dim))

        # self.dropout = torch.nn.Dropout()
        self.fc_out = Linear(-1, num_classes)

    def forward(self, x, edge_index):
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index) + self.lins[i](x)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc_out(x)
        return x

In [31]:
class Net2(nn.Module):
    def __init__(self, hidden_dim, num_classes, num_layers=2):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_layers = num_layers

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        for i in range(self.num_layers):
            self.convs.append(HeteroConv({
                ('author', 'affiliated_with', 'institution'): SAGEConv((-1, -1), hidden_dim),
                ('author', 'writes', 'paper'): GATConv((-1, -1), hidden_dim),
                ('paper', 'cites', 'paper'): TransformerConv((-1, -1), hidden_dim),
                ('paper', 'has_topic', 'field_of_study'): GATConv((-1, -1), hidden_dim),
                ('institution', 'rev_affiliated_with', 'author'): SAGEConv((-1, -1), hidden_dim),
                ('paper', 'rev_writes', 'author'): SAGEConv((-1, -1), hidden_dim),
                ('field_of_study', 'rev_has_topic', 'paper'): SAGEConv((-1, -1), hidden_dim)
            }, aggr='sum'))
            # self.bns.append(nn.)

        self.fc_out = Linear(-1, num_classes)

    def forward(self, x_dict, edge_index_dict):
        for i in range(self.num_layers):
            x_dict = self.convs[i](x_dict, edge_index_dict)
            x_dict = {key: nn.BatchNorm1d(x.shape[1])(x) for key, x in x_dict.items()}    
            x_dict = {key: F.relu(x) for key, x in x_dict.items()}
            x_dict = {key: F.dropout(x, p=0.5, training=self.training) for key, x in x_dict.items()}    
        
        return self.fc_out(x_dict['paper'])

In [33]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = Net1(hidden_dim=64, num_classes=dataset.num_classes, num_layers=2)
model = to_hetero(model, data.metadata(), aggr='sum').to(device)
print(model)

cuda
GraphModule(
  (convs): ModuleList(
    (0): ModuleDict(
      (author__affiliated_with__institution): GATConv((-1, -1), 64, heads=1)
      (author__writes__paper): GATConv((-1, -1), 64, heads=1)
      (paper__cites__paper): GATConv((-1, -1), 64, heads=1)
      (paper__has_topic__field_of_study): GATConv((-1, -1), 64, heads=1)
      (institution__rev_affiliated_with__author): GATConv((-1, -1), 64, heads=1)
      (paper__rev_writes__author): GATConv((-1, -1), 64, heads=1)
      (field_of_study__rev_has_topic__paper): GATConv((-1, -1), 64, heads=1)
    )
    (1): ModuleDict(
      (author__affiliated_with__institution): GATConv((-1, -1), 64, heads=1)
      (author__writes__paper): GATConv((-1, -1), 64, heads=1)
      (paper__cites__paper): GATConv((-1, -1), 64, heads=1)
      (paper__has_topic__field_of_study): GATConv((-1, -1), 64, heads=1)
      (institution__rev_affiliated_with__author): GATConv((-1, -1), 64, heads=1)
      (paper__rev_writes__author): GATConv((-1, -1), 64, heads

In [36]:
model2 = Net2(hidden_dim=64, num_classes=dataset.num_classes, num_layers=2)
model2.to(device)
print(model2)

Net2(
  (convs): ModuleList(
    (0): HeteroConv(num_relations=7)
    (1): HeteroConv(num_relations=7)
  )
  (bns): ModuleList()
  (fc_out): Linear(-1, 349, bias=True)
)


In [38]:
model2.convs[0]

HeteroConv(num_relations=7)

In [17]:
@torch.no_grad()
def init_params():
    # Initialize lazy parameters via forwarding a single batch to the model:
    batch = next(iter(train_loader))
    batch = batch.to(device)
    model(batch.x_dict, batch.edge_index_dict)

In [None]:
def train():
    model.train()
    tot_loss = 0
    tot_correct