In [4]:
import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import ReLU
from tqdm import tqdm

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import HGTLoader, NeighborLoader
from torch_geometric.nn import Linear, SAGEConv, Sequential, to_hetero

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = T.ToUndirected(merge=True)
dataset = OGB_MAG("/tmp/", preprocess='metapath2vec', transform=transform)

# Already send node features/labels to GPU for faster access during sampling:
data = dataset[0].to(device, 'x', 'y')

train_input_nodes = ('paper', data['paper'].train_mask)
val_input_nodes = ('paper', data['paper'].val_mask)
kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}

train_loader = NeighborLoader(data, num_neighbors=[10] * 2, shuffle=True,
                              input_nodes=train_input_nodes, **kwargs)
val_loader = NeighborLoader(data, num_neighbors=[10] * 2,
                            input_nodes=val_input_nodes, **kwargs)

model = Sequential('x, edge_index', [
    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (Linear(-1, dataset.num_classes), 'x -> x'),
])
model = to_hetero(model, data.metadata(), aggr='sum').to(device)


@torch.no_grad()
def init_params():
    # Initialize lazy parameters via forwarding a single batch to the model:
    batch = next(iter(train_loader))
    batch = batch.to(device, 'edge_index')
    model(batch.x_dict, batch.edge_index_dict)


def train():
    model.train()

    total_examples = total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        batch = batch.to(device, 'edge_index')
        batch_size = batch['paper'].batch_size

        for i in batch.edge_index_dict.items(): print(i)

        stop

        out = model(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]
        loss = F.cross_entropy(out, batch['paper'].y[:batch_size])
        loss.backward()
        optimizer.step()

        total_examples += batch_size
        total_loss += float(loss) * batch_size

    return total_loss / total_examples


@torch.no_grad()
def test(loader):
    model.eval()

    total_examples = total_correct = 0
    for batch in tqdm(loader):
        batch = batch.to(device, 'edge_index')
        batch_size = batch['paper'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]
        pred = out.argmax(dim=-1)

        total_examples += batch_size
        total_correct += int((pred == batch['paper'].y[:batch_size]).sum())

    return total_correct / total_examples


init_params()  # Initialize parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1, 21):
    loss = train()
    val_acc = test(val_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')

  0%|          | 0/615 [00:00<?, ?it/s]

(('author', 'affiliated_with', 'institution'), tensor([], device='cuda:0', size=(2, 0), dtype=torch.int64))
(('author', 'writes', 'paper'), tensor([[    0,     1,     2,  ..., 27543,  4487,  4489],
        [    0,     0,     0,  ...,  7859,  7859,  7859]], device='cuda:0'))
(('paper', 'cites', 'paper'), tensor([[ 1024,  1025,  1026,  ..., 50302, 50303, 50304],
        [    0,     0,     0,  ...,  7859,  7859,  7859]], device='cuda:0'))
(('paper', 'has_topic', 'field_of_study'), tensor([[50305, 50306, 50307,  ..., 80434, 80435, 80436],
        [    0,     0,     0,  ...,  3462,  3462,  3462]], device='cuda:0'))
(('institution', 'rev_affiliated_with', 'author'), tensor([[   0,    1,    0,  ...,  133,  133,   25],
        [   0,    1,    1,  ..., 4488, 4489, 4489]], device='cuda:0'))
(('paper', 'rev_writes', 'author'), tensor([[80437, 80438, 80439,  ..., 95202, 50273,  7859],
        [    0,     0,     0,  ...,  4489,  4489,  4489]], device='cuda:0'))
(('field_of_study', 'rev_has_topic', 




NameError: name 'stop' is not defined