# How to Trace/Convert Almost Any PyTorch Geometric Model into a Triton-Compatible Model

In [None]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from torch_geometric.data import NeighborSampler
from torch_geometric.nn import SAGEConv
import os
from torch_geometric.utils import to_networkx
import networkx as nx
# importing matplotlib.pyplot
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Path to the dataset root. Set the OGB_DATA_ROOT environment variable to
# point somewhere else; the original author's local directory stays the default
# so existing runs behave exactly as before.
root = os.environ.get("OGB_DATA_ROOT", "/home/sachin/Desktop/arangoml/datasets")


In [3]:
# Load the ogbn-products node-property-prediction dataset stored under `root`.
dataset = PygNodePropPredDataset(name='ogbn-products', root=root)

In [4]:
# The dataset holds a single graph; take it together with the official
# train/val/test index split (based on sales ranking) and the OGB evaluator.
data = dataset[0]
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name='ogbn-products')

In [5]:
# node indices of the test split; passed to the sampler below as seed nodes
test_idx = split_idx['test']


Neighborhood Sampling

In [6]:
# Neighbour sampler over the test nodes: one seed node per batch, three
# sampling sizes (15, 10, 5), fixed order so batch positions are repeatable.
test_loader = NeighborSampler(
    data.edge_index,
    node_idx=test_idx,
    sizes=[15, 10, 5],
    batch_size=1,
    shuffle=False,
    num_workers=12,
)

In [7]:
# Take the mini-batch at position 550 of the loader (an arbitrary test node)
# and keep its node ids and per-hop adjacency objects as dummy trace inputs.
dummy_n_ids, dummy_adjs = [], []
for idx, (batch_size, n_id, adjs) in enumerate(test_loader):
    if idx != 550:
        continue
    dummy_n_ids.append(n_id)
    dummy_adjs.append(adjs)
    break

In [9]:
# ids of the node involved in computation
dummy_n_ids[0]

tensor([ 236488, 2383556, 1616861, 1667901, 1785374,   37632,  757171,  258762])

In [10]:
# adjacency list (edge index) sampled for each hop
dummy_adjs

[[EdgeIndex(edge_index=tensor([[1, 2, 0, 2, 0, 1, 3, 2, 4, 5, 6, 7],
          [0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]]), e_id=tensor([25936595, 92402649, 25936594, 92402650, 92402648, 92402651, 92402647,
          92402646, 94509855,  4828386, 53135952, 27156330]), size=(8, 4)),
  EdgeIndex(edge_index=tensor([[1, 2, 0, 2, 0, 1, 3],
          [0, 0, 1, 1, 2, 2, 2]]), e_id=tensor([25936595, 92402649, 25936594, 92402650, 92402648, 92402651, 92402647]), size=(4, 3)),
  EdgeIndex(edge_index=tensor([[1, 2],
          [0, 0]]), e_id=tensor([25936595, 92402649]), size=(3, 1))]]

The length of `dummy_adjs[0]` (the list of `EdgeIndex` objects) is equal to the number of hops from which we want to extract neighborhood information.

In [11]:
len(dummy_adjs[0])

3

In [12]:
# Build the per-hop inputs used to trace the GraphSage model.
# Each list ends up as [edge_index, size-as-tensor]; the e_id field (index 1)
# of every adjacency object is intentionally left out.
edge_list_0, edge_list_1, edge_list_2 = [], [], []
edge_adjs = []
_hop_buckets = (edge_list_0, edge_list_1, edge_list_2)
for idx, e_idx in enumerate(dummy_adjs[0]):
    # hop 0 -> edge_list_0, hop 1 -> edge_list_1, anything later -> edge_list_2
    _bucket = _hop_buckets[min(idx, 2)]
    _bucket.append(e_idx[0])
    _bucket.append(torch.tensor(np.asarray(e_idx[2])))

In [13]:
# Use the GPU when one is available, otherwise stay on the CPU, and move each
# hop's edge index and size tensor onto that device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

edge_index_0 = edge_list_0[0].to(device)
edge_size_0 = edge_list_0[1].to(device)

edge_index_1 = edge_list_1[0].to(device)
edge_size_1 = edge_list_1[1].to(device)

edge_index_2 = edge_list_2[0].to(device)
edge_size_2 = edge_list_2[1].to(device)

In [14]:
# node feature matrix of the whole graph (one row per node)
x = data.x

In [15]:
# total number of nodes involved in the computation graph
x[dummy_n_ids[0]].shape

torch.Size([8, 100])

In [16]:
# Dummy node-feature input for the trace: the feature rows of the nodes that
# take part in the sampled computation graph.
dummy_x = x[dummy_n_ids[0]]
print(dummy_x.shape)

torch.Size([8, 100])


In [17]:
# Zero-pad the node dimension so the traced feature input always has
# `max_nodes` rows, then move it to the target device.
max_nodes = 1000
total_nodes = dummy_x.shape[0]
nodes_padded = max_nodes - total_nodes
# pad = (left, right) for the feature dim, then (top, bottom) for the node dim:
# only rows are appended at the bottom.
dummy_x_pad = F.pad(dummy_x, (0, 0, 0, nodes_padded), mode='constant', value=0).to(device)
print(dummy_x_pad.shape)

torch.Size([1000, 100])


In [18]:
# rest of the rows in dummy_x are filled with 0
dummy_x_pad 

tensor([[-0.3376, -0.3109,  0.2868,  ..., -0.7435,  0.1572, -0.1681],
        [ 0.1358,  0.0866, -0.5094,  ...,  0.8339,  0.3380,  0.6472],
        [-1.4801,  0.6196, -0.1442,  ..., 12.1236, -2.6969, -2.7455],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')

In [19]:
# graph sage
class SAGE(torch.nn.Module):
    """GraphSAGE encoder laid out so that it can be exported with ``torch.jit.trace``.

    ``forward`` receives the three sampled hops as separate tensor arguments
    and returns the output of each of the three layers. Before every
    convolution the target-node block is zero-padded to ``max_target_nodes``
    rows, so each returned tensor has ``max_target_nodes`` rows.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of every layer except the last one.
        out_channels: Output size of the last layer.
        num_layers: Number of ``SAGEConv`` layers to build. ``forward`` always
            consumes exactly three hops, so only the default of 3 matches its
            signature.
        max_target_nodes: Number of rows the target-node block is padded to in
            every layer. Defaults to 500, the value that used to be hard-coded
            inside ``forward``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3,
                 max_target_nodes=500):
        super(SAGE, self).__init__()

        self.num_layers = num_layers
        self.max_target_nodes = max_target_nodes

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, edge_index_0, edge_size_0, edge_index_1, edge_size_1, edge_index_2, edge_size_2):
        """Run the three layers and return every layer's embeddings.

        Args:
            x: Node feature matrix; target nodes are expected in the first rows.
            edge_index_0, edge_index_1, edge_index_2: Bipartite edge index of
                hop 0, 1 and 2.
            edge_size_0, edge_size_1, edge_size_2: Size of each bipartite
                graph as a tensor; element ``[1]`` is read as the number of
                target nodes of that hop.

        Returns:
            Tuple ``(layer_1_embeddings, layer_2_embeddings, layer_3_embeddings)``.
            ReLU is applied to every layer except the one at index
            ``num_layers - 1``. Rows beyond a hop's real target count
            correspond to the zero padding.
        """
        # The neighbour sampler yields, for each layer, a bipartite graph with
        # its edges and its size. Target nodes are also included in the source
        # nodes so that one can easily apply skip-connections or add self-loops.
        hops = (
            (edge_index_0, edge_size_0),
            (edge_index_1, edge_size_1),
            (edge_index_2, edge_size_2),
        )

        layer_embeddings = []
        for i, (edge_index, size) in enumerate(hops):
            # NOTE(review): size[1] is kept as a tensor (no int()/.item()),
            # presumably so the traced graph does not freeze it - confirm.
            num_targets = size[1]
            x_target = x[:num_targets]  # Target nodes are always placed first.
            tar_nodes_padded = self.max_target_nodes - num_targets
            x_target = F.pad(input=x_target, pad=(0, 0, 0, tar_nodes_padded), mode='constant', value=0)

            x = self.convs[i]((x, x_target), edge_index)

            if i != self.num_layers - 1:
                x = F.relu(x)

            # No list + torch.cat round-trip: each layer yields one tensor,
            # so concatenating it with nothing only produced a copy.
            layer_embeddings.append(x)

        layer_1_embeddings, layer_2_embeddings, layer_3_embeddings = layer_embeddings
        return layer_1_embeddings, layer_2_embeddings, layer_3_embeddings

In [20]:
# Instantiate the GraphSAGE model (256 hidden channels) directly on `device`;
# the checkpoint weights are loaded in the next cell.
model = SAGE(dataset.num_features, 256, dataset.num_classes).to(device)

In [21]:
# Load the trained weights from the checkpoint.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
model_checkpoint = '/home/sachin/Desktop/arangoml/obgn_wts/weight_checkpoint.pth.tar'
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads when the notebook has fallen back to the CPU.
checkpoint = torch.load(model_checkpoint, map_location=device)
model_w = checkpoint["state_dict"]
model.load_state_dict(model_w)

<All keys matched successfully>

# Tracing PyTorch Geometric GraphSage Model

The model is converted using its JIT-traced version. According to PyTorch’s documentation, “TorchScript is a way to create serializable and optimizable models from PyTorch code”.

It allows developers to export a model so that it can be reused in other programs, such as efficiency-oriented C++ programs. Exporting a model by tracing requires dummy inputs of a fixed, standard size with which to execute the model’s forward pass.

During the model’s forward pass with the dummy inputs, PyTorch keeps track of the operations applied to each tensor and records them to create the “trace” of the model.

Because the recorded trace is tied to the dimensions of the dummy inputs, future model inputs are constrained to those dimensions; the traced model will not work for other sequence lengths or batch sizes.

It is therefore recommended to trace the model with the largest input dimensions that you expect to feed to it in the future.

In [22]:
class PyTorch_to_TorchScript(torch.nn.Module):
    """Thin wrapper module that is handed to ``torch.jit.trace``.

    It forwards all seven inputs unchanged to the wrapped model and returns
    whatever that model returns.

    Args:
        wrapped_model: Module to wrap. When omitted, the notebook-level
            ``model`` variable is used, which keeps the original no-argument
            construction working.
    """

    def __init__(self, wrapped_model=None):
        super(PyTorch_to_TorchScript, self).__init__()
        # Only fall back to the global when no model was passed explicitly, so
        # the wrapper no longer has to depend on hidden notebook state.
        self.model = model if wrapped_model is None else wrapped_model

    def forward(self, data, edge_index_0, edge_size_0, edge_index_1, edge_size_1, edge_index_2, edge_size_2):
        """Delegate to the wrapped model with the inputs in the same order."""
        return self.model(data, edge_index_0, edge_size_0, edge_index_1, edge_size_1, edge_index_2, edge_size_2)

In [23]:
# Instantiate the wrapper around `model` and switch it to eval mode before
# tracing. Nothing is saved here; the traced module is written to disk in a
# later cell.
pt_model = PyTorch_to_TorchScript().eval()
# The wrapped model was already moved to `device` when it was created, so no
# extra transfer is done here:
#pt_model = pt_model.to(device)

In [24]:
# Trace the wrapper with the padded dummy features plus the edge index and
# size tensor of each of the three hops.
example_inputs = (
    dummy_x_pad,
    edge_index_0, edge_size_0,
    edge_index_1, edge_size_1,
    edge_index_2, edge_size_2,
)
traced_script_module = torch.jit.trace(pt_model, example_inputs, strict=False)

In [25]:
# Serialize the traced model to model.pt in the current working directory.
traced_script_module.save("./model.pt")

## Loading traced model

In [26]:
# Reload the serialized trace from disk to check that it works on its own.
tr_model = torch.jit.load("./model.pt")

In [27]:
# Run the reloaded traced model on the same dummy inputs used for tracing.
tr_out = tr_model(dummy_x_pad, edge_index_0, edge_size_0, edge_index_1, edge_size_1, edge_index_2, edge_size_2)

In [28]:
# layer-1, layer-2, layer-3 embeddings
tr_out[0].shape, tr_out[1].shape, tr_out[2].shape

(torch.Size([500, 256]), torch.Size([500, 256]), torch.Size([500, 47]))

In [29]:
tr_out[2].shape

torch.Size([500, 47])

In [30]:
tr_out[2][0]

tensor([-8.6776e+00, -9.1641e+00, -1.4251e+01, -4.8633e+00,  2.7073e+01,
        -9.1761e+00,  4.3235e-02,  1.4600e+01, -2.1659e+00, -8.8615e+00,
        -1.1793e+01, -1.3701e+01, -4.8658e+00,  5.9286e+00, -6.9220e+00,
        -9.5293e+00, -1.7088e+01, -1.1222e+01,  8.1884e-01, -1.5866e+01,
         5.5329e-01, -1.4108e+01, -1.2657e+01, -1.9597e+01,  8.1923e+00,
        -1.5794e+01, -2.2992e+01, -2.2723e+01, -4.6216e+01, -5.1885e+01,
        -2.5836e+01, -1.6740e+01, -1.7584e+01, -1.2375e+01, -3.3004e+01,
        -1.5356e+01, -3.4252e+01, -2.4260e+01, -1.7691e+01, -2.4234e+01,
        -3.3041e+01, -1.1956e+01, -2.2771e+01, -2.2491e+01, -2.1991e+01,
        -2.2300e+01, -2.2732e+01], device='cuda:0', grad_fn=<SelectBackward>)

# Writing the Model Configuration File

The configuration file, `config.pbtxt`, specifies the permissible input/output types and shapes, the preferred batch sizes, the versioning policy, and the platform. The server has no other way of knowing these details, so we write them into a separate configuration file.

Configuration for the above GraphSage Model

name: "graph_embeddings"

platform: "pytorch_libtorch"

input [
 {
    name: "input__0"
    data_type: TYPE_FP32
    dims: [1000, 100]
  } ,
  
{
    name: "input__1"
    data_type: TYPE_INT64
    dims: [2, -1]
  },
  
{
    name: "input__2"
    data_type: TYPE_INT64
    dims: [2]
  },
  
{
    name: "input__3"
    data_type: TYPE_INT64
    dims: [2, -1]
  },
  
{
    name: "input__4"
    data_type: TYPE_INT64
    dims: [2]
  },
  
{
    name: "input__5"
    data_type: TYPE_INT64
    dims: [2, -1]
  },
  
{
    name: "input__6"
    data_type: TYPE_INT64
    dims: [2]
  }
  
]

output [
{
    name: "output__0"
    data_type: TYPE_FP32
    dims: [500, 256]
  },
  
{
    name: "output__1"
    data_type: TYPE_FP32
    dims: [500, 256]
  },
  
{
    name: "output__2"
    data_type: TYPE_FP32
    dims: [500, 47]
  }
  
]