# Demo: GRPE Advanced

This demo notebook showcases the easily configurable advanced GRPE model, which allows customizable attention biases to be enforced and also includes a node degree centrality encoding. An option for perturbing node features using a uniform distribution is also provided.

In [None]:
import os
import sys
from tqdm import tqdm
import numpy
import torch
import torch.nn
from torch_geometric.data import Data, Batch
from torchvision.transforms import Compose
from torch.utils.data.dataloader import DataLoader
from typing import Dict, List, Any
sys.path.insert(0, '/home/shayan/phoenix/graphite/')
from graphite.data.pcqm4mv2.pyg import PCQM4Mv2Dataset
from graphite.data.pcqm4mv2.pyg.collator import collate_fn, default_collate_fn
from graphite.data.pcqm4mv2.pyg.transforms import AddTaskNode, EncodeNode2NodeConnectionType,  EncodeNode2NodeShortestPathLengthType, \
EncodeNode2NodeShortestPathFeatureTrajectory, EncodeNodeType, EncodeEdgeType, EncodeNodeDegreeCentrality, ComenetEdgeFeatures
from graphite.cortex.model.grpe import GraphRelativePositionalEncodingNetwork
from graphite.cortex.model.grpe_advanced import GraphRelativePositionalEncodingNetworkAdvanced
from graphite.utilities.device import move_batch_to_device
from graphite.utilities.miscellaneous import count_parameters

In [2]:
# Pick the compute device. The second GPU ('cuda:1') is preferred; fall back to
# the first GPU, and then to the CPU, so the notebook still executes on machines
# that expose fewer (or no) CUDA devices.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# L1 (mean absolute error) loss between predictions and targets.
criterion = torch.nn.L1Loss()

In [3]:
# Location of the PCQM4Mv2 data. Defaults to the path used for the recorded run;
# set the PCQM4MV2_ROOT environment variable to point at a different copy.
dataset_root = os.environ.get('PCQM4MV2_ROOT', '/home/shayan/from_source/GRPE/data')

# Every sample is passed through the transforms below, in the listed order.
dataset = PCQM4Mv2Dataset(root=dataset_root, transform=Compose([
    EncodeNode2NodeShortestPathFeatureTrajectory(max_length_considered=4, feature_position_offset=4),
    EncodeNodeType(),
    EncodeNodeDegreeCentrality(),
    AddTaskNode(),
    EncodeEdgeType(),
    EncodeNode2NodeConnectionType(),
    EncodeNode2NodeShortestPathLengthType(max_length_considered=5)
]))

# Index splits provided by the dataset; the 'train', 'valid' and 'test-dev'
# subsets are the ones used in this notebook.
split_idx = dataset.get_idx_split()
train_dataset = dataset[split_idx["train"]]
valid_dataset = dataset[split_idx["valid"]]

test_dataset = dataset[split_idx["test-dev"]]

# Training samples are drawn in random order; validation and test samples are
# visited sequentially.
train_sampler = torch.utils.data.RandomSampler(train_dataset)
valid_sampler = torch.utils.data.SequentialSampler(valid_dataset)
test_sampler = torch.utils.data.SequentialSampler(test_dataset)

# Loader settings shared by all three splits.
dataloader_args=dict(batch_size=128, collate_fn=collate_fn, num_workers=4, pin_memory=True)

train_dataloader = DataLoader(train_dataset, sampler=train_sampler, **dataloader_args)
val_dataloader = DataLoader(valid_dataset, sampler=valid_sampler, **dataloader_args)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, **dataloader_args)

In [4]:
# Fetch one collated batch from the training loader and list the fields it carries.
# The iterator is intentionally not bound to a name, so it is released right away.
batch = next(iter(train_dataloader))
[field_name for field_name in batch.keys()]

['node_type',
 'node_features',
 'edge_type',
 'graphs',
 'y',
 'node2node_shortest_path_length_type',
 'node2node_connection_type',
 'shortest_path_feature_trajectory',
 'node_degree_centrality']

In [41]:
# Width shared by the encoder, the final LayerNorm and the input of the linear head.
model_dimension = 768

# Keyword arguments for the advanced GRPE encoder, gathered in one place so the
# configuration can be read (and tweaked) at a glance.
encoder_config = {
    'model_dimension': model_dimension,
    'number_of_heads': 32,
    'number_of_layers': 12,
    'feedforward_dimension': 768,
    'dropout': 0.1,
    'attention_dropout': 0.1,
    'shortest_path_length_upperbound': 5,
    'perturbation': 0.0,
    'independent_layer_embeddings': False,
    'attention_biases': ['edge', 'shortest_path_length', 'shortest_path'],
    'path_encoding_length_upperbound': 4,
    'path_encoding_code_dim': 4,
    'encode_node_degree_centrality': True,
}

# Encoder -> LayerNorm -> single-output linear layer, moved to the selected device.
model = torch.nn.Sequential(
    GraphRelativePositionalEncodingNetworkAdvanced(**encoder_config),
    torch.nn.LayerNorm(model_dimension),
    torch.nn.Linear(model_dimension, 1),
).to(device)

In [42]:
# Display the value returned by count_parameters for the assembled model
# (presumably its total parameter count — the helper is imported from
# graphite.utilities.miscellaneous; confirm whether it counts trainable parameters only).
count_parameters(model)

42954105

In [31]:
# NOTE(review): redundant re-import — count_parameters is already imported in the
# imports cell at the top of the notebook, so this cell can be deleted.
from graphite.utilities.miscellaneous import count_parameters

In [9]:
%%time # on GeForce RTX 2080 - 11019Mib memory
max_iters = 1000
for i, batch in tqdm(enumerate(train_dataloader)):
    if i > max_iters:
        break
    batch = move_batch_to_device(batch, device)
    graph_reps = model(batch)
    loss = criterion(graph_reps.squeeze(), batch['y'])
    loss.backward()
    model.zero_grad()

1001it [04:48,  3.47it/s]


CPU times: user 1h 11min 33s, sys: 1min 19s, total: 1h 12min 53s
Wall time: 4min 49s
