In [4]:
# Install the OGB dataset package, then print the torch and CUDA versions of this runtime.
!pip install ogb
!python -c "import torch; print(torch.__version__)"
!python -c "import torch; print(torch.version.cuda)"

# The wheel index URL below is hard-coded for torch 1.13.0 + cu116; it is not
# derived from the versions printed above.
# NOTE(review): presumably the index has to match the printed versions — confirm before re-running.
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.13.0+cu116.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.13.0+cu116.html
!pip install torch-geometric

# DeepSNAP is installed straight from its GitHub repository (-q: quiet output).
!pip install -q git+https://github.com/snap-stanford/deepsnap.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ogb
  Downloading ogb-1.3.5-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.6/78.6 KB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
Collecting outdated>=0.2.0
  Downloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)
Collecting littleutils
  Downloading littleutils-0.2.2.tar.gz (6.6 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: littleutils
  Building wheel for littleutils (setup.py) ... [?25l[?25hdone
  Created wheel for littleutils: filename=littleutils-0.2.2-py3-none-any.whl size=7047 sha256=3e75e954fa4cc453c9f0cd5281fd6c88ca3c56e098d51714117ba58bdee349c4
  Stored in directory: /root/.cache/pip/wheels/6a/33/c4/0ef84d7f5568c2823e3d63a6e08988852fb9e4bc822034870a
Successfully built littleutils
Installing collected packages: littleutils, outdated, ogb
Successfully installed littleutils-

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for deepsnap (setup.py) ... [?25l[?25hdone


In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset root used later lives under this path.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
from ogb.linkproppred import LinkPropPredDataset
from deepsnap.dataset import GraphDataset
from deepsnap.hetero_graph import HeteroGraph
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
from deepsnap.hetero_gnn import forward_op, HeteroConv
import numpy as np
import copy

In [10]:
# Load ogbl-biokg, using a Google Drive folder as the dataset root directory.
dataset = LinkPropPredDataset(name = "ogbl-biokg", root = '/content/drive/MyDrive/biodataset/')
# The dataset's first (index 0) entry is the graph; later cells read graph['edge_reltype'] from it.
graph = dataset[0]
# Per-split triples; later cells read the 'head', 'tail', 'head_type', 'tail_type'
# and 'relation' entries of each split.
split_edge = dataset.get_edge_split()

# Displayed as the cell output: the available split names.
split_edge.keys()

dict_keys(['train', 'valid', 'test'])

In [11]:
# Build the lookup tables used when constructing the graph:
#   relation_mapping      relation id -> (head type, relation name, tail type)
#   relation_inv_mapping  (head type, relation name, tail type) -> relation id
#   node_types            node type name -> index, in order of first appearance
relation_mapping, relation_inv_mapping, node_types = {}, {}, {}

for rel, reltype_ids in graph['edge_reltype'].items():
  # The relation id is read from the first entry stored for this triple.
  k = reltype_ids[0][0]

  # Register the head type first, then the tail type, if not seen before.
  for endpoint_type in (rel[0], rel[2]):
    node_types.setdefault(endpoint_type, len(node_types))

  # Keep only the first triple encountered for a given relation id.
  if k in relation_mapping:
    continue
  relation_mapping[k] = rel
  relation_inv_mapping[rel] = k

print("number of relations : ", len(relation_mapping))
print("number of node types : ", len(node_types))

number of relations :  51
number of node types :  5


In [12]:
# Directed multigraph that process_graph() fills with typed nodes and edges.
G = nx.MultiDiGraph()
# Maps "<node type>_<original id>" keys to the integer node id used inside G.
node_ids = {}
# Next unused integer node id; incremented whenever a new key is added to node_ids.
node_id_cnt = 0

def _one_hot(index, size):
    """Return a float tensor of length `size` that is 1.0 at `index` and 0.0 elsewhere."""
    vec = torch.zeros(size)
    vec[index] = 1.0
    return vec


def process_graph(p_node_ids, p_node_types, p_relation):
    """Add one typed edge, plus any endpoint not yet present, to the global graph G.

    Args:
        p_node_ids: (source id, destination id) pair of graph node ids.
        p_node_types: (source type, destination type) pair; each must be a key
            of the global `node_types`.
        p_relation: relation triple; must be a key of the global
            `relation_inv_mapping`. Its element at index 1 is stored as the
            edge's `edge_type`.

    Nodes get a one-hot `node_feature` over the node types, and the edge gets a
    one-hot `edge_feature` over the relations. Node features are only built for
    nodes that are actually inserted: this function runs once per triple, and
    most calls hit endpoints that already exist, so building them
    unconditionally allocates tensors that are thrown away immediately.
    """
    # Resolve the relation first so an unknown relation fails before G is modified.
    rel_feats = _one_hot(relation_inv_mapping[p_relation], len(relation_mapping))

    n_node_types = len(node_types)
    for node_id, node_type in zip(p_node_ids[:2], p_node_types[:2]):
        if node_id not in G:
            G.add_node(node_id, node_type=node_type,
                       node_feature=_one_hot(node_types[node_type], n_node_types))

    G.add_edge(p_node_ids[0], p_node_ids[1], edge_type=p_relation[1], edge_feature=rel_feats)


def process_dataset_splits(split_type):
  '''
  Insert every triple of one dataset split into the global graph.

  split_type is the key of the split to read from the global split_edge
  (the notebook calls this with 'train', 'valid' and 'test').

  Each endpoint is identified by the string "<node type>_<original id>". The
  first time such a key is seen it receives the next free integer from the
  global counter node_id_cnt, recorded in node_ids. The triple is then handed
  to process_graph() using those integer ids.
  '''
  global node_id_cnt
  triples = split_edge[split_type]
  heads, tails = triples['head'], triples['tail']
  head_types, tail_types = triples['head_type'], triples['tail_type']
  relations = triples['relation']

  for i in range(len(heads)):
    h_type, t_type = head_types[i], tail_types[i]
    relation = relation_mapping[relations[i]]

    endpoint_ids = []
    # Source key is handled before the destination key, so a new source always
    # receives the smaller id.
    for key in (h_type + "_" + str(heads[i]), t_type + "_" + str(tails[i])):
      if key not in node_ids:
        node_ids[key] = node_id_cnt
        node_id_cnt += 1
      endpoint_ids.append(node_ids[key])

    process_graph(tuple(endpoint_ids), (h_type, t_type), relation)

# Add the train, valid and test triples (in that order) to the same graph G,
# reporting the node and edge counts after each split has been inserted.
for split_name in ('train', 'valid', 'test'):
  process_dataset_splits(split_name)
  print("Graph node size: ", len(G))
  print("num of edges: ", G.number_of_edges())

Graph node size:  93773
num of edges:  4762678
Graph node size:  93773
num of edges:  4925564
Graph node size:  93773
num of edges:  5088434


In [13]:
# Tally how many nodes of each node_type are stored in G.
temp = {}
for n, attrs in G.nodes(data=True):
  node_type_name = attrs['node_type']
  temp[node_type_name] = temp.get(node_type_name, 0) + 1

# Displayed as the cell output.
temp

{'disease': 10687,
 'protein': 17499,
 'drug': 10533,
 'sideeffect': 9969,
 'function': 45085}

In [7]:
#nx.write_gpickle(G, "/content/drive/MyDrive/biokg_graph.gpickle")

Read Graph object stored in pickle format.

In [9]:
# G = nx.read_gpickle("/content/drive/MyDrive/biokg_graph.gpickle")
# print("graph type: ", type(G))
# print("Graph node size: ", len(G))
# print("num of edges: ", G.number_of_edges())

In [14]:
# Wrap the networkx graph in a DeepSNAP HeteroGraph.
print("input graph type: ", type(G))
hete = HeteroGraph(G)
# Bare expression: shows the HeteroGraph repr as the cell output.
hete

input graph type:  <class 'networkx.classes.multidigraph.MultiDiGraph'>


HeteroGraph(G=[], edge_feature=[], edge_index=[], edge_label_index=[], edge_to_graph_mapping=[], edge_to_tensor_mapping=[5088434], edge_type=[], node_feature=[], node_label_index=[], node_to_graph_mapping=[], node_to_tensor_mapping=[93773], node_type=[])

In [15]:
# deleting to reduce the memory
# NOTE(review): this only removes the notebook-level name G. A later cell still
# iterates hete.G, so the graph presumably stays alive through `hete` — confirm
# whether any memory is actually released here.
del G

In [16]:
# Build a link-prediction dataset that contains the single heterogeneous graph.
# NOTE(review): with edge_train_mode="disjoint", edge_message_ratio=0.85 presumably
# sets the share of training edges used for message passing (the rest being
# supervision edges) — confirm against the DeepSNAP documentation.
graph_dataset = GraphDataset([hete], task='link_pred', \
                             edge_train_mode="disjoint", edge_message_ratio=0.85,)
# Splitting the dataset
# Transductive split with ratios 0.8 / 0.1 / 0.1, unpacked as train / val / test.
dataset_train, dataset_val, dataset_test = graph_dataset.split(transductive=True, split_ratio=[0.8, 0.1, 0.1])

In [17]:
# Print only the first node and the first edge (with their attribute dicts)
# of hete.G; each loop stops after one iteration.
for node in hete.G.nodes(data=True):
    print(node)
    break
for edge in hete.G.edges(data=True):
    print(edge)
    break

(0, {'node_type': 'disease', 'node_feature': tensor([1., 0., 0., 0., 0.])})
(0, 1, {'edge_type': 'disease-protein', 'edge_feature': tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])})


In [18]:
# Pick one message type and print, for the first training graph, the shapes of
# edge_index (fed to the conv layers), edge_label_index (node pairs that get
# scored) and edge_label (their targets).
rel = ('drug', 'drug-sideeffect', 'sideeffect')
print(dataset_train[0].edge_index[rel].shape)
print(dataset_train[0].edge_label_index[rel].shape)
print(dataset_train[0].edge_label[rel].shape)

torch.Size([2, 118828])
torch.Size([2, 41940])
torch.Size([41940])


Graph initialization and training 

In [19]:
# Functions to generate two internal GNN layers for link prediction task
def generate_2convs_link_pred_layers(hete, conv, hidden_size):
    """Create the per-message-type modules for a two-layer heterogeneous GNN.

    Args:
        hete: object exposing `message_types` (an iterable of indexable message
            types) and `num_node_features(node_type)`.
        conv: callable invoked as `conv(in_first, out, in_second)`.
        hidden_size: output width of both layers and input width of the second.

    Returns:
        A pair of dicts keyed by message type. The first holds
        `conv(dim of type [0], hidden_size, dim of type [2])`, the second
        `conv(hidden_size, hidden_size, hidden_size)`.
    """
    first_layer, second_layer = {}, {}
    for msg_type in hete.message_types:
        # Feature widths of the node types at positions 0 and 2 of the message type.
        first_dim, second_dim = (hete.num_node_features(msg_type[pos]) for pos in (0, 2))
        # Both modules for a message type are created back to back, layer one first.
        first_layer[msg_type] = conv(first_dim, hidden_size, second_dim)
        second_layer[msg_type] = conv(hidden_size, hidden_size, hidden_size)
    return first_layer, second_layer

In [20]:
from pprint import pprint
from deepsnap.hetero_graph import HeteroGraph
from deepsnap.dataset import GraphDataset
from deepsnap.batch import Batch
from deepsnap.hetero_gnn import HeteroSAGEConv
from torch.utils.data import DataLoader

hidden_size = 128

# Generate two heterogeneous GNN layers for link prediction
conv1, conv2 = generate_2convs_link_pred_layers(hete, HeteroSAGEConv, hidden_size)
for layer_convs in (conv1, conv2):
    pprint(layer_convs)

# One DataLoader per split, each yielding batches of a single graph collated
# with DeepSNAP's Batch.collate().
dataloaders = {
    split_name: DataLoader(split_dataset, collate_fn=Batch.collate(), batch_size=1)
    for split_name, split_dataset in (
        ('train', dataset_train),
        ('val', dataset_val),
        ('test', dataset_test),
    )
}
train_loader = dataloaders['train']
val_loader = dataloaders['val']
test_loader = dataloaders['test']

{('disease', 'disease-protein', 'protein'): HeteroSAGEConv(neigh: 5, self: 5, out: 256),
 ('drug', 'drug-disease', 'disease'): HeteroSAGEConv(neigh: 5, self: 5, out: 256),
 ('drug', 'drug-drug_acquired_metabolic_disease', 'drug'): HeteroSAGEConv(neigh: 5, self: 5, out: 256),
 ('drug', 'drug-drug_bacterial_infectious_disease', 'drug'): HeteroSAGEConv(neigh: 5, self: 5, out: 256),
 ('drug', 'drug-drug_benign_neoplasm', 'drug'): HeteroSAGEConv(neigh: 5, self: 5, out: 256),
 ('drug', 'drug-drug_cancer', 'drug'): HeteroSAGEConv(neigh: 5, self: 5, out: 256),
 ('drug', 'drug-drug_cardiovascular_system_disease', 'drug'): HeteroSAGEConv(neigh: 5, self: 5, out: 256),
 ('drug', 'drug-drug_chromosomal_disease', 'drug'): HeteroSAGEConv(neigh: 5, self: 5, out: 256),
 ('drug', 'drug-drug_cognitive_disorder', 'drug'): HeteroSAGEConv(neigh: 5, self: 5, out: 256),
 ('drug', 'drug-drug_cryptorchidism', 'drug'): HeteroSAGEConv(neigh: 5, self: 5, out: 256),
 ('drug', 'drug-drug_developmental_disorder_of_me

In [21]:
# Define the heterogeneous GNN for the link prediction task
class HeteroGNN(torch.nn.Module):
    """Two-layer heterogeneous GNN that scores candidate edges by a dot product.

    Args:
        conv1: dict of first-layer conv modules keyed by message type.
        conv2: dict of second-layer conv modules keyed by message type.
        hetero: object exposing `node_types`; one BatchNorm1d / LeakyReLU pair
            per layer is created for every node type.
        hidden_size: feature width produced by both conv layers.
    """

    def __init__(self, conv1, conv2, hetero, hidden_size):
        super(HeteroGNN, self).__init__()

        self.convs1 = HeteroConv(conv1, aggr='add') # Wrap the heterogeneous GNN layers
        self.convs2 = HeteroConv(conv2, aggr='add')
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.bns1 = nn.ModuleDict()
        self.bns2 = nn.ModuleDict()
        self.relus1 = nn.ModuleDict()
        # relus2 and post_mps are created but never used by forward().
        self.relus2 = nn.ModuleDict()
        self.post_mps = nn.ModuleDict()

        for node_type in hetero.node_types:
            self.bns1[node_type] = torch.nn.BatchNorm1d(hidden_size)
            self.bns2[node_type] = torch.nn.BatchNorm1d(hidden_size)
            self.relus1[node_type] = nn.LeakyReLU()
            self.relus2[node_type] = nn.LeakyReLU()

    def forward(self, data):
        """Return a dict of raw edge scores (logits) keyed by message type.

        Reads `data.node_feature`, `data.edge_index` and `data.edge_label_index`.
        Node embeddings come from conv -> batch norm -> LeakyReLU -> conv ->
        batch norm. The score of each candidate edge is the dot product of its
        two endpoint embeddings; no sigmoid is applied here.
        """
        x = data.node_feature
        edge_index = data.edge_index
        x = self.convs1(x, edge_index)
        x = forward_op(x, self.bns1)
        x = forward_op(x, self.relus1)
        x = self.convs2(x, edge_index)
        x = forward_op(x, self.bns2)

        pred = {}
        for message_type in data.edge_label_index:
            src_type, dst_type = message_type[0], message_type[2]
            # Row 0 indexes source-type embeddings, row 1 destination-type embeddings.
            label_index = data.edge_label_index[message_type]
            nodes_first = torch.index_select(x[src_type], 0, label_index[0, :].long())
            nodes_second = torch.index_select(x[dst_type], 0, label_index[1, :].long())
            pred[message_type] = torch.sum(nodes_first * nodes_second, dim=-1)
        return pred

    def loss(self, pred, y):
        """Sum the binary cross-entropy over all message types in `pred`.

        Args:
            pred: dict of raw logits, as returned by forward().
            y: dict of labels with the same keys; cast to the logits' dtype.

        Returns:
            The summed loss, or the int 0 when `pred` is empty.
        """
        loss = 0
        for key in pred:
            # BCEWithLogitsLoss applies the sigmoid itself, so it must be given
            # the raw logits. Applying torch.sigmoid first squashes them twice
            # and distorts both the loss value and its gradients.
            loss += self.loss_fn(pred[key], y[key].type(pred[key].dtype))
        return loss

In [22]:
# Test function
def test(model, dataloaders, args):
    model.eval()
    accs = {}
    for mode, dataloader in dataloaders.items():
        acc = 0
        for i, batch in enumerate(dataloader):
            num = 0
            batch.to(args["device"])
            pred = model(batch)
            for key in pred:
                p = torch.sigmoid(pred[key]).cpu().detach().numpy()
                pred_label = np.zeros_like(p, dtype=np.int64)
                pred_label[np.where(p > 0.5)[0]] = 1
                pred_label[np.where(p <= 0.5)[0]] = 0
                acc += np.sum(pred_label == batch.edge_label[key].cpu().numpy())
                num += len(pred_label)
        accs[mode] = acc / num
    return accs


# Train function
def train(model, dataloaders, optimizer, args):
    """Optimise `model` and record accuracy after every training step.

    Args:
        model: module exposing `loss(pred, labels)`; called on each batch.
        dataloaders: dict with at least 'train', 'val' and 'test' entries.
        optimizer: optimizer bound to the model's parameters.
        args: dict; reads `args["epochs"]` and `args["device"]`.

    After each optimizer step the model is evaluated on all dataloaders with
    test(), a progress line is printed, and a deep copy of the model is kept
    whenever the validation accuracy exceeds the best seen so far. When
    training ends, the accuracies of that best copy are printed.

    Returns:
        Three lists (train, val, test accuracy), one entry per training step.
    """
    step_log = 'Epoch: {:03d}, Train loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    best_log = 'Best: Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    history = {'train': [], 'val': [], 'test': []}
    best_val = 0
    best_model = model

    for epoch in range(1, args["epochs"] + 1):
        for batch in dataloaders['train']:
            batch.to(args["device"])
            # test() leaves the model in eval mode, so re-enable training mode each step.
            model.train()
            optimizer.zero_grad()
            pred = model(batch)
            loss = model.loss(pred, batch.edge_label)
            loss.backward()
            optimizer.step()

            accs = test(model, dataloaders, args)
            for split_name, series in history.items():
                series.append(accs[split_name])

            print(step_log.format(epoch, loss.item(), accs['train'], accs['val'], accs['test']))
            if accs['val'] > best_val:
                best_val = accs['val']
                best_model = copy.deepcopy(model)

    accs = test(best_model, dataloaders, args)
    print(best_log.format(accs['train'], accs['val'], accs['test']))

    return history['train'], history['val'], history['test']

In [None]:
# Hyperparameters for the training run.
args = {
    # Use the GPU when torch can see one; otherwise fall back to the CPU so the
    # cell does not fail on a runtime without CUDA.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "epochs": 120,
    "lr": 0.01,
    "weight_decay": 1e-4
}

# Build the model and start training
model = HeteroGNN(conv1, conv2, hete, hidden_size).to(args["device"])
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
# Per-step train / val / test accuracy histories returned by train().
t_accu, v_accu, e_accu = train(model, dataloaders, optimizer, args)