In [1]:
from rrgcn import RRGCNEmbedder
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from ogb.nodeproppred import Evaluator
from torch import nn
from tqdm import tqdm
import copy
from sklearn.preprocessing import StandardScaler

In [2]:
# Requires the `ogb` package (install it once with `%pip install ogb`).
dataset = PygNodePropPredDataset(name="ogbn-mag")
graph = dataset[0]

# Official OGB split; each entry maps a node type to that type's local node ids.
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]
valid_idx = split_idx["valid"]
test_idx = split_idx["test"]

In [3]:
# Display the heterogeneous graph summary: node counts, edge indices, features and labels per type.
graph

Data(
  num_nodes_dict={
    author=1134649,
    field_of_study=59965,
    institution=8740,
    paper=736389
  },
  edge_index_dict={
    (author, affiliated_with, institution)=[2, 1043998],
    (author, writes, paper)=[2, 7145660],
    (paper, cites, paper)=[2, 5416271],
    (paper, has_topic, field_of_study)=[2, 7505078]
  },
  x_dict={ paper=[736389, 128] },
  node_year={ paper=[736389, 1] },
  edge_reltype={
    (author, affiliated_with, institution)=[1043998, 1],
    (author, writes, paper)=[7145660, 1],
    (paper, cites, paper)=[5416271, 1],
    (paper, has_topic, field_of_study)=[7505078, 1]
  },
  y_dict={ paper=[736389, 1] }
)

In [4]:
# Flatten the heterogeneous ogbn-mag graph into a single relation-typed graph
# for RRGCNEmbedder.
#
# Each (node type, local id) pair is first encoded as
#     local_id * num_node_types + node_type_to_add[node_type]
# which interleaves node types so ids of different types cannot collide; the
# encoded ids are then compacted to consecutive integers 0..num_nodes-1.
#
# Splits are read from `split_idx`, not from `train_idx`/`valid_idx`/`test_idx`:
# this cell rebinds those names to tensors at the end, and reading `split_idx`
# keeps the cell safe to re-run.

# Node types: every type with a node count or appearing as an edge endpoint.
node_types = set(graph.num_nodes_dict)
for src_type, _, dst_type in graph.edge_index_dict:
    node_types |= {src_type, dst_type}
node_types = sorted(node_types)
num_node_types = len(node_types)
node_type_to_add = {n: i for i, n in enumerate(node_types)}


def interleave(local_ids, node_type):
    """Encode local ids of `node_type` into the interleaved (pre-compaction) id space."""
    return local_ids * num_node_types + node_type_to_add[node_type]


# Compaction table built from *all* nodes, not only those touched by an edge,
# so an isolated node gets an id instead of breaking the assignment below.
num_nodes = sum(graph.num_nodes_dict.values())
all_ids = torch.cat(
    [interleave(torch.arange(n), t) for t, n in graph.num_nodes_dict.items()]
).sort().values
node_idx = torch.full((int(all_ids[-1]) + 1,), -1, dtype=torch.long)
node_idx[all_ids] = torch.arange(num_nodes)

edge_index = node_idx[
    torch.hstack(
        [
            torch.vstack((interleave(v[0], src), interleave(v[1], dst)))
            for (src, _, dst), v in graph.edge_index_dict.items()
        ]
    )
]
edge_type = torch.cat([graph.edge_reltype[k].reshape(-1) for k in graph.edge_index_dict])
assert (edge_index >= 0).all(), "edge endpoint without a compacted node id"
num_rels = edge_type.unique().numel()

# Add inverse edges: relation r becomes 2r (forward) and 2r + 1 (inverse).
# NOTE(review): `num_rels` is counted above, before this re-encoding, so
# edge_type values now go up to 2 * max(r) + 1; confirm that this is what
# RRGCNEmbedder(num_relations=...) expects.
edge_type = torch.hstack((2 * edge_type, (2 * edge_type) + 1))
edge_index = torch.hstack((edge_index, edge_index[[1, 0]]))

# Node features keyed by an integer feature-type index: [compacted node ids, features].
node_features = {
    i: [node_idx[interleave(torch.arange(feats.shape[0]), k)], feats]
    for i, (k, feats) in enumerate(graph.x_dict.items())
}

# Compacted node ids and labels per split. Labels are concatenated along dim 0
# so rows stay aligned with the ids even when several node types are labelled.
split_nodes, split_labels = {}, {}
for split in ("train", "valid", "test"):
    nodes, labels = [], []
    for k, local_ids in split_idx[split].items():
        nodes.append(node_idx[interleave(local_ids, k)])
        labels.append(graph.y_dict[k][local_ids])
    split_nodes[split] = torch.cat(nodes)
    split_labels[split] = torch.cat(labels, dim=0)

train_idx, valid_idx, test_idx = (split_nodes[s] for s in ("train", "valid", "test"))
train_y, valid_y, test_y = (split_labels[s] for s in ("train", "valid", "test"))



In [5]:
# Use the GPU when one is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# NOTE(review): `num_rels` is counted before inverse edges are added, while
# `edge_type` is re-encoded as 2r / 2r + 1 afterwards; confirm that
# RRGCNEmbedder's `num_relations` expects the pre-inverse count.
embedder_config = dict(
    num_nodes=num_nodes,
    num_relations=num_rels,
    num_layers=2,
    emb_size=750,
    device=device,
)
embedder = RRGCNEmbedder(**embedder_config)

In [6]:
# Node features need to be normalized for RR-GCN to use them well.
# `node_features_scalers` takes either a shorthand string:
#   "standard" -> StandardScaler, "robust" -> RobustScaler,
#   "quantile" -> QuantileTransformer, "power" -> PowerTransformer
# or a dict of sklearn-compatible scalers keyed like `node_features`
# (integer feature-type index), e.g. {0: StandardScaler(), 1: RobustScaler()}.
#
# This training-split call is the one whose fitted scalers are reused below
# via embedder.get_last_fit_scalers().
train_embs = embedder.embeddings(
    edge_index,
    edge_type,
    idx=train_idx,
    node_features=node_features,
    node_features_scalers="standard",
)


100%|██████████| 1/1 [01:23<00:00, 83.28s/it]


In [7]:
# Node-feature scalers are fit only on nodes reachable from training nodes;
# validation and test embeddings reuse those fitted scalers instead of refitting.
# The scalers are fetched right before each call, validation first, then test.
val_embs, test_embs = [
    embedder.embeddings(
        edge_index,
        edge_type,
        node_features=node_features,
        node_features_scalers=embedder.get_last_fit_scalers(),
        idx=eval_idx,
    )
    for eval_idx in (valid_idx, test_idx)
]

100%|██████████| 1/1 [00:56<00:00, 56.14s/it]
100%|██████████| 1/1 [00:52<00:00, 52.81s/it]


In [20]:
# Standardize embeddings with statistics from the training split only, then
# apply the same transform to the validation and test embeddings.
scaler = StandardScaler()
scaler.fit(train_embs)
train_embs_scaled, val_embs_scaled, test_embs_scaled = (
    torch.tensor(scaler.transform(embs), dtype=torch.float32)
    for embs in (train_embs, val_embs, test_embs)
)

In [30]:
# Batched training is awkward in CatBoost, so fit an MLP classifier on the
# scaled RR-GCN embeddings instead.
seed = 0
torch.manual_seed(seed)  # fixes weight init, dropout masks and batch shuffling

# CrossEntropyLoss needs targets in [0, num_classes); sizing the output layer
# from the largest label (instead of the number of distinct training labels)
# stays valid even if some class has no training example.
num_classes = int(train_y.max()) + 1
emb_dim = train_embs_scaled.shape[1]

train_set = torch.utils.data.TensorDataset(train_embs_scaled, train_y)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=10_000, shuffle=True)

mlp = nn.Sequential(
    nn.Linear(emb_dim, emb_dim // 2),
    nn.Dropout(0.3),
    nn.ReLU(),
    nn.Linear(emb_dim // 2, emb_dim // 2),
    nn.Dropout(0.1),
    nn.ReLU(),
    nn.Linear(emb_dim // 2, num_classes * 2),
    nn.ReLU(),
    nn.Linear(num_classes * 2, num_classes),
)
mlp = mlp.to(device)

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)
num_epochs = 100
early_stopping_epochs = 20
# LR patience is half the early-stopping patience, so the learning rate is
# lowered before training gives up.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=early_stopping_epochs // 2, verbose=True
)

# The validation set is fixed, so move it to the device once.
val_inputs = val_embs_scaled.to(device)
val_targets = valid_y.reshape(-1).to(device)

best_val_loss = float("inf")
best_model = None
epochs_since_best = 0
for epoch in range(num_epochs):
    mlp.train()
    total_loss = 0.0
    total_items = 0
    with tqdm(
        train_loader, total=len(train_loader), desc=f"Epochs {epoch + 1}/{num_epochs}"
    ) as bar:
        for i, (x, y) in enumerate(bar):
            optimizer.zero_grad()

            out = mlp(x.to(device))
            # reshape(-1) rather than squeeze(): a final batch with a single
            # sample must stay 1-D for CrossEntropyLoss.
            l = loss(out, y.reshape(-1).to(device))
            total_loss += x.shape[0] * l.item()
            total_items += x.shape[0]

            l.backward()
            optimizer.step()
            bar.set_postfix(loss=l.item())

            # Validate on the last batch so the epoch summary appears in this
            # bar's postfix.
            if i == len(train_loader) - 1:
                mlp.eval()
                with torch.no_grad():  # no autograd graph over the full val set
                    val_loss = loss(mlp(val_inputs), val_targets).item()

                if val_loss < best_val_loss:
                    best_model = copy.deepcopy(mlp)
                    best_val_loss = val_loss
                    epochs_since_best = 0
                else:
                    epochs_since_best += 1

                bar.set_postfix(total_loss=total_loss / total_items, val_loss=val_loss)
                scheduler.step(val_loss)

    if epochs_since_best == early_stopping_epochs:
        print(
            "Early stopping, resetting weights to best model "
            + f"with val_loss {best_val_loss}"
        )
        break

# Continue with the best validation checkpoint, whether or not early stopping
# triggered.
if best_model is not None:
    mlp = best_model


Epochs 1/100: 100%|██████████| 63/63 [00:14<00:00,  4.28it/s, total_loss=3.32, val_loss=2.65]
Epochs 2/100: 100%|██████████| 63/63 [00:14<00:00,  4.40it/s, total_loss=2.47, val_loss=2.45]
Epochs 3/100: 100%|██████████| 63/63 [00:14<00:00,  4.38it/s, total_loss=2.29, val_loss=2.38]
Epochs 4/100: 100%|██████████| 63/63 [00:14<00:00,  4.34it/s, total_loss=2.19, val_loss=2.29]
Epochs 5/100: 100%|██████████| 63/63 [00:14<00:00,  4.25it/s, total_loss=2.12, val_loss=2.25]
Epochs 6/100: 100%|██████████| 63/63 [00:14<00:00,  4.31it/s, total_loss=2.05, val_loss=2.23]
Epochs 7/100: 100%|██████████| 63/63 [00:14<00:00,  4.36it/s, total_loss=2, val_loss=2.2]
Epochs 8/100: 100%|██████████| 63/63 [00:14<00:00,  4.36it/s, total_loss=1.96, val_loss=2.19]
Epochs 9/100: 100%|██████████| 63/63 [00:14<00:00,  4.22it/s, total_loss=1.92, val_loss=2.18]
Epochs 10/100: 100%|██████████| 63/63 [00:14<00:00,  4.23it/s, total_loss=1.88, val_loss=2.16]
Epochs 11/100: 100%|██████████| 63/63 [00:14<00:00,  4.35it/s, 

Epoch 00029: reducing learning rate of group 0 to 1.0000e-04.


Epochs 30/100: 100%|██████████| 63/63 [00:13<00:00,  4.51it/s, total_loss=1.38, val_loss=2.14]
Epochs 31/100: 100%|██████████| 63/63 [00:13<00:00,  4.69it/s, total_loss=1.35, val_loss=2.14]
Epochs 32/100: 100%|██████████| 63/63 [00:14<00:00,  4.27it/s, total_loss=1.34, val_loss=2.14]
Epochs 33/100: 100%|██████████| 63/63 [00:13<00:00,  4.52it/s, total_loss=1.34, val_loss=2.14]
Epochs 34/100: 100%|██████████| 63/63 [00:15<00:00,  4.19it/s, total_loss=1.33, val_loss=2.14]
Epochs 35/100: 100%|██████████| 63/63 [00:14<00:00,  4.22it/s, total_loss=1.32, val_loss=2.15]
Epochs 36/100: 100%|██████████| 63/63 [00:14<00:00,  4.31it/s, total_loss=1.32, val_loss=2.15]
Epochs 37/100: 100%|██████████| 63/63 [00:13<00:00,  4.61it/s, total_loss=1.32, val_loss=2.14]
Epochs 38/100: 100%|██████████| 63/63 [00:14<00:00,  4.26it/s, total_loss=1.31, val_loss=2.15]

Early stopping, resetting weights to best modelwith val_loss 2.0962438583374023





In [33]:
# Score the trained MLP on the test split with the official OGB evaluator.
mlp.eval()
evaluator = Evaluator(name="ogbn-mag")

# Inference only: no_grad avoids building an autograd graph over the whole split.
with torch.no_grad():
    test_pred = mlp(test_embs_scaled.to(device)).argmax(-1).reshape(-1, 1).cpu().numpy()

print("Test:")
evaluator.eval({"y_true": test_y.cpu().numpy(), "y_pred": test_pred})

Test:


{'acc': 0.396027563842724}

In [34]:
# Score the trained MLP on the validation split with the same OGB evaluator.
# Set eval mode here too, so the result does not depend on another cell having done it.
mlp.eval()

# Inference only: no_grad avoids building an autograd graph over the whole split.
with torch.no_grad():
    val_pred = mlp(val_embs_scaled.to(device)).argmax(-1).reshape(-1, 1).cpu().numpy()

print("Validation:")
evaluator.eval({"y_true": valid_y.cpu().numpy(), "y_pred": val_pred})

Validation:


{'acc': 0.4113811865164383}

In [None]:
# Reference: a fully trained, full-batch R-GCN using the input node features on ogbn-mag
# https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/mag/rgcn.py
# https://ogb.stanford.edu/docs/leader_nodeprop/#ogbn-mag
# -------------------------------------
#   Test accuracy     Valid accuracy 
# -------------------------------------
#   0.3977 ± 0.0046   0.4084 ± 0.0041
# -------------------------------------