In [3]:
from torch_geometric.data import HeteroData
import torch

# Hand-written sketch of the OGB_MAG heterogeneous schema. The empty lists are
# placeholders only; `data` is reassigned from the real dataset when OGB_MAG is
# loaded, so nothing assigned here is used for training.
# PyG convention: `.x` holds node features ([num_nodes, num_features]),
# `.edge_index` holds COO connectivity ([2, num_edges]) and `.edge_attr` holds
# per-edge features ([num_edges, num_edge_features]).
data = HeteroData()

# Node types.
data['paper'].x = []
data['author'].x = []
data['institution'].x = []
data['field_of_study'].x = []

# Edge types are keyed as (source_type, relation, destination_type). Relation
# names follow OGB_MAG ('writes', 'affiliated_with') so this sketch matches the
# edge types the models in this notebook refer to.
data['paper', 'cites', 'paper'].edge_index = []
data['author', 'writes', 'paper'].edge_index = []
data['author', 'affiliated_with', 'institution'].edge_index = []
data['paper', 'has_topic', 'field_of_study'].edge_index = []

data['paper', 'cites', 'paper'].edge_attr = []
data['author', 'writes', 'paper'].edge_attr = []
data['author', 'affiliated_with', 'institution'].edge_attr = []
data['paper', 'has_topic', 'field_of_study'].edge_attr = []



In [None]:
from torch_geometric.datasets import OGB_MAG
# Load the OGB_MAG heterogeneous graph. `preprocess='metapath2vec'` asks PyG to
# supply precomputed embeddings for node types that ship without raw features.
dataset = OGB_MAG(
    root='../../data/ogb_mag',
    preprocess='metapath2vec',
)
data = dataset[0]  # OGB_MAG holds a single HeteroData graph

In [4]:
from torch_geometric.nn import SAGEConv, to_hetero
class GNN(torch.nn.Module):
    """Two-layer GraphSAGE network written for homogeneous inputs.

    The ``(-1, -1)`` input sizes let each SAGEConv infer its source/target
    feature sizes lazily on the first forward pass, so the model can be handed
    to ``to_hetero`` without knowing per-node-type feature dimensions.

    Args:
        hidden_channels: output width of the first convolution.
        out_channels: output width of the second (final) convolution.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2 and return the final node embeddings."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

In [None]:
from torch_geometric.nn import GATConv, Linear, to_hetero
# `dataset` already holds the OGB_MAG instance loaded earlier and is used as-is
# below; dataset objects are not callable, so it must not be invoked here.
class GAT(torch.nn.Module):
    """Two-layer graph attention network with a linear skip path per layer.

    Each layer's output is ``GATConv(h) + Linear(h)``; the first layer is
    followed by a ReLU. All input sizes are lazy (``-1``) so the model can be
    converted with ``to_hetero``.

    Args:
        hidden_channels: width of the first layer's output.
        out_channels: width of the final output.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # Self-loops are disabled: once converted to a heterogeneous model many
        # edge types connect different node types, where a self-loop is meaningless.
        self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)
        self.lin1 = Linear(-1, hidden_channels)
        self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False)
        self.lin2 = Linear(-1, out_channels)

    def forward(self, x, edge_index):
        """Return attention output plus linear skip for both layers."""
        h = (self.conv1(x, edge_index) + self.lin1(x)).relu()
        return self.conv2(h, edge_index) + self.lin2(h)
    
# Build the homogeneous GAT and convert it into a heterogeneous model over the
# node/edge types reported by `data.metadata()`; messages arriving at the same
# node type from different edge types are summed.
model = to_hetero(
    GAT(hidden_channels=64, out_channels=dataset.num_classes),
    data.metadata(),
    aggr='sum',
)



In [None]:
import torch_geometric.transforms as T 
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear

# Reload OGB_MAG with ToUndirected, which adds a reverse edge type for every
# relation so node types that only appear as sources (e.g. 'author') also
# receive messages.
dataset = OGB_MAG(
    root='../../data/ogb_mag/',
    preprocess='metapath2vec',
    transform=T.ToUndirected(),
)
data = dataset[0]

class HeteroGNN(torch.nn.Module):
    """Stack of HeteroConv layers over selected OGB_MAG edge types plus a linear head.

    Each layer runs one convolution per edge type, sums the results per
    destination node type, and applies ReLU. The head is applied to the
    embeddings of ``target_node_type``.

    Args:
        hidden_channels: width of every message-passing layer (and head input).
        out_channels: output size of the final Linear head.
        num_layers: number of stacked HeteroConv layers.
        target_node_type: node type whose embeddings are classified. Defaults
            to 'author' (the original behaviour). NOTE(review): OGB_MAG's labels
            / ``num_classes`` appear to belong to 'paper' nodes — confirm which
            node type should be fed to the head.
    """

    def __init__(self, hidden_channels, out_channels, num_layers, target_node_type='author'):
        super().__init__()
        self.target_node_type = target_node_type

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels),
                ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
                # ToUndirected names reverse edge types 'rev_<relation>'. The key
                # must match exactly; otherwise this conv has no edge_index and
                # 'author' never receives an embedding. Self-loops are disabled
                # because source and destination are different node types.
                ('paper', 'rev_writes', 'author'): GATConv(
                    (-1, -1), hidden_channels, add_self_loops=False),
            }, aggr='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Run all layers and return the head output for ``target_node_type``.

        Args:
            x_dict: mapping node type -> node feature tensor.
            edge_index_dict: mapping (src, rel, dst) edge type -> edge_index.
        """
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        return self.lin(x_dict[self.target_node_type])
    
# Two HeteroConv layers of width 64; the head's output size is the dataset's
# class count.
model = HeteroGNN(
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=2,
)