<a href="https://colab.research.google.com/github/scigeek72/GNN_Repo/blob/main/practice_3_GNN_graphsage_with_edge_attr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This is the 2nd notebook in this series, following practice_2_GNN_graphsage_no_edge_attr.ipynb (see my GNN_repo). Like that one, it is a practice notebook run on Colab.

In this practice notebook I implement (again copying from another notebook) a modified GraphSAGE that uses `edge_attr` as edge weights, i.e. 1-dimensional edge attributes. So the `edge_attr` tensor has shape $[|E|, 1]$, where $|E|$ is the number of edges of the graph $G = (V,E)$.

In the next practice notebook, I will extend this idea into multi-dimensional `edge_attr` such that the size of the `edge_attr` tensor will be $[|E|,D]$ where $D > 1$.

Below, I will not repeat comments that I already made in the practice_2_GNN_graphsage_no_edge_attr.ipynb notebook; refer to that notebook for them.

In [None]:
# Install PyTorch Geometric and its dependencies (slow; must be repeated in every new Colab session)
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.11.0+cu113.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-1.11.0+cu113.html
!pip install torch_geometric
!pip install -q git+https://github.com/snap-stanford/deepsnap.git
!pip install ogb # provides the ogbl-ddi dataset and its evaluator

In [None]:
import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import negative_sampling
from tqdm import trange 

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Get Drug-Drug Interaction (DDI) dataset (graph) from OGB

In [None]:
from ogb.linkproppred import PygLinkPropPredDataset

dataset_name = 'ogbl-ddi'

dataset = PygLinkPropPredDataset(name = dataset_name)

print(f'{dataset_name} has length {len(dataset)}')

In [None]:
ddi_graph = dataset[0] # the dataset contains only 1 graph (see the output of the previous cell)

print(f'DDI graph Object: {ddi_graph}')
print(f'Number of nodes: {ddi_graph.num_nodes}')
print(f'Number of edges: {ddi_graph.num_edges}')
print(f'Is Undirected? {ddi_graph.is_undirected()}')

print(f'Average node degree: {ddi_graph.num_edges / ddi_graph.num_nodes:.2f}')
print(f'Number of node features: {ddi_graph.num_node_features}') # no node features (the drugs have no features)
print(f'Number of edge features: {ddi_graph.num_edge_features}') # this data has no edge features, but we will build one
print(f'Has self_loops: {ddi_graph.has_self_loops()}')
print(f'Has isolated nodes: {ddi_graph.has_isolated_nodes()}')

### Data Split

In [None]:
split_edges = dataset.get_edge_split()
split_edges.keys()

In [None]:
train_edges, valid_edges, test_edges = split_edges['train'], split_edges['valid'], split_edges['test']

In [None]:
print(f'{train_edges.keys()}')
print(f'{valid_edges.keys()}')
print(f'{test_edges.keys()}')


In [None]:
print(f'Number of training pos edges: {train_edges["edge"].shape[0]}')
print(f'Number of validation pos edges: {valid_edges["edge"].shape[0]}')
print(f'Number of test pos edges: {test_edges["edge"].shape[0]}')

print(f'Number of validation negative edges: {valid_edges["edge_neg"].shape[0]}')
print(f'Number of test negative edges: {test_edges["edge_neg"].shape[0]}')



In [None]:
print(f'Size of the edge_index: {ddi_graph.edge_index.shape}')

A clarification is in order here. The number of training positive edges printed in the second-to-last cell is exactly half the number of columns of `edge_index` printed in the last cell. This is because the graph is undirected, and `edge_index` stores each edge twice, as both `(u,v)` and `(v,u)`.

Note: PyG keeps the graph structure in `edge_index` in `COO` (coordinate) format, a tensor of shape $[2, |E|]$ whose columns are (source, target) node pairs, as opposed to an adjacency-matrix format.
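
A minimal sketch of COO storage on a hypothetical 3-node toy graph (plain PyTorch, not the DDI data): each undirected edge is written once as `(u, v)` and once as `(v, u)`, which is why `edge_index` has twice as many columns as there are training edges.

```python
import torch

# Hypothetical toy graph: 3 undirected edges {0-1, 1-2, 2-0}, stored once each
undirected_edges = torch.tensor([[0, 1],
                                 [1, 2],
                                 [2, 0]])   # shape [num_edges, 2], like train_edges['edge']

# COO edge_index: row 0 = source nodes, row 1 = target nodes.
# Both (u, v) and (v, u) are stored so that messages can flow in both directions.
edge_index = torch.cat([undirected_edges.t(),
                        undirected_edges.flip(1).t()], dim=1)

print(edge_index)
print(edge_index.shape)   # torch.Size([2, 6]): 2 x (2 * number of undirected edges)
```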

### Define edge attributes (1-dim) for each edge. 

In the next notebook, we will build on this to implement multi-dimensional edge attributes.

Note: I don't explain here why this attribute is constructed the way it is, since the focus of this notebook is on using an edge attribute, not on designing one, so all the comments and theory behind it are omitted for now. If curious, please refer to the notebook **Predicting Drug-Drug Interactions using Graph Neural Networks** from the Stanford course CS224W (the link is in the GitHub repo).

In [None]:
from torch_geometric.utils import to_networkx

nx_ddi_graph = to_networkx(ddi_graph, to_undirected=True)

In [None]:
#@title Anchor nodes parameters
num_anchor_nodes = 200 #@param {type:'number'}

In [None]:
import numpy as np
import networkx as nx

In [None]:
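# Randomly pick num_anchor_nodes distinct anchor nodes (no fixed seed, so the choice differs between runs)
anchor_nodes = np.random.choice(nx_ddi_graph.number_of_nodes(), size=num_anchor_nodes, replace=False)

# shortest_paths_to_anchor_nodes[v, a] = hop distance from node v to the a-th anchor node;
# nodes more than `cutoff` hops away from an anchor (or unreachable) keep the value 0
shortest_paths_to_anchor_nodes = torch.zeros(nx_ddi_graph.number_of_nodes(), num_anchor_nodes)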

for anchor_index, anchor_node in enumerate(anchor_nodes):
  for dst, path_length in nx.single_source_shortest_path_length(nx_ddi_graph, source=anchor_node, cutoff=5).items():
    shortest_paths_to_anchor_nodes[dst,anchor_index] = path_length

print(f'Shortest paths for node 0 to every anchor node: {shortest_paths_to_anchor_nodes[0]}')

In [None]:
shortest_paths_to_anchor_nodes.shape

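The edge attribute of edge $(u,v)$ is defined as $$\text{edge\_attr}_{uv} = \text{mean}\left(\{\text{SPD}(u,v_a) + \text{SPD}(v,v_a), \forall v_a \in V_{\text{anchor}}\}\right)$$ where SPD stands for shortest path distance (in hops) and $V_{\text{anchor}}$ is the set of randomly chosen anchor nodes. The resulting values are then min-max normalized to $[0, 1]$.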
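
To make the formula concrete, here is a small sketch on a hypothetical 5-node path graph `0-1-2-3-4` with two fixed anchors (nodes 0 and 1) instead of random ones. A plain breadth-first search stands in for `nx.single_source_shortest_path_length`; the indexing, the mean over anchors and the min-max normalization follow the code cell below.

```python
from collections import deque

import torch


def bfs_distances(adj, source, cutoff):
    """Hop distances from `source`, only for nodes within `cutoff` hops."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        if dist[u] == cutoff:
            continue  # do not expand beyond the cutoff
        for w in adj[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    return dist


# Hypothetical toy path graph 0-1-2-3-4, with both edge directions in edge_index
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4],
                           [1, 0, 2, 1, 3, 2, 4, 3]])
anchor_nodes = [0, 1]  # fixed here for illustration; the notebook samples them randomly

# spd[v, a] = SPD(v, a-th anchor); nodes beyond the cutoff keep 0
spd = torch.zeros(len(adj), len(anchor_nodes))
for a, anchor in enumerate(anchor_nodes):
    for v, d in bfs_distances(adj, anchor, cutoff=5).items():
        spd[v, a] = d

# spd[edge_index, :] has shape [2, |E|, num_anchors]; sum(dim=0) gives SPD(u, a) + SPD(v, a),
# and the mean over the anchors gives one value per edge, shape [|E|, 1]
edge_attr = spd[edge_index, :].sum(dim=0).mean(dim=1, keepdim=True)

# min-max normalization to [0, 1]
edge_attr = (edge_attr - edge_attr.min()) / (edge_attr.max() - edge_attr.min() + 1e-15)
print(edge_attr.squeeze(1))
```

Here the raw values are 1, 2, 4 and 6 for the edges 0-1, 1-2, 2-3 and 3-4, so after normalization the edges farther from both anchors get larger attributes.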

In [None]:
edge_attr = shortest_paths_to_anchor_nodes[ddi_graph.edge_index, :].sum(dim=0).mean(dim=1, keepdim=True).to(device)

#normalize the edge_attrs 
max_attr = torch.max(edge_attr)
min_attr = torch.min(edge_attr)

edge_attr = (edge_attr - min_attr)/(max_attr - min_attr + 1e-15) 

In [None]:
edge_attr.shape

## Define a custom GNN layer by incorporating edge_attr for the GraphSAGE model

Recap of the GraphSAGE model equation:

$$h_v^{l+1} = W_1 \cdot h_v^l + W_2 \cdot \text{mean}(\{h_u^l, \forall u \in N_v\})$$ where $N_v$ is the set of neighbors of $v$.

Definition of the custom GraphSAGE model:

$$h_v^{l+1} = W_1 \cdot h_v^l + W_2 \cdot \text{mean}\{\text{ReLU}(m_{vu}^{l+1}), \forall u \in N_v\}$$ where $m_{vu}^{l+1} = h_u^l + W_3 \cdot \text{edge\_attr}_{vu}$

#### Implementation Note
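As in the previous practice notebook, $W_1$, $W_2$ and $W_3$ are each implemented as a `torch.nn.Linear` layer whose `(in_features, out_features)` are chosen so that every sum in the equations above is defined. In particular, $W_3$ has `in_features` $= 1$, because `edge_attr` has shape $[|E|,1]$, and its `out_features` must equal the dimension of $h_u^l$ (the layer's `in_channels`) so that $h_u^l + W_3 \cdot \text{edge\_attr}_{vu}$ is defined. When `normalize=True`, the layer's output is additionally L2-normalized row-wise.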
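
A minimal plain-PyTorch sketch of one layer of the equation above, written without PyG's `MessagePassing` so that each step is visible. The toy sizes, random inputs and the `index_add_`-based mean are assumptions for illustration, not the notebook's implementation. As in PyG's default `source_to_target` flow, messages travel from `edge_index[0]` ($u$) to `edge_index[1]` ($v$).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_nodes, in_channels, out_channels = 4, 8, 8

# Hypothetical toy graph in COO format: row 0 = source u, row 1 = target v
edge_index = torch.tensor([[0, 1, 2, 3, 0],
                           [1, 0, 3, 2, 2]])
edge_attr = torch.rand(edge_index.size(1), 1)      # [|E|, 1], one scalar per edge
h = torch.randn(num_nodes, in_channels)            # h^l

W1 = torch.nn.Linear(in_channels, out_channels)    # self term
W2 = torch.nn.Linear(in_channels, out_channels)    # aggregated-neighbor term
W3 = torch.nn.Linear(1, in_channels)               # lifts edge_attr to the size of h_u

src, dst = edge_index
# One message per edge: ReLU(h_u + W3 * edge_attr_vu)
messages = F.relu(h[src] + W3(edge_attr))          # [|E|, in_channels]

# Mean of the incoming messages of every target node v (nodes without in-edges would get 0)
summed = torch.zeros(num_nodes, in_channels).index_add_(0, dst, messages)
in_degree = torch.zeros(num_nodes).index_add_(0, dst, torch.ones(dst.size(0)))
mean_messages = summed / in_degree.clamp(min=1).unsqueeze(-1)

h_next = W1(h) + W2(mean_messages)                 # h^{l+1}
h_next = F.normalize(h_next)                       # row-wise L2 normalization, as with normalize=True
print(h_next.shape)
```

Node 2 receives two messages here (from nodes 3 and 0), so its aggregated vector is the average of those two rows of `messages`.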


In [None]:
# Custom GNN Layer
from torch_geometric.nn.conv import MessagePassing

class SAGEConvWithEdgeAttr(MessagePassing):
  def __init__(self, in_channels, out_channels, aggr='mean', comb=lambda x,y:x+y, normalize=True,**kwargs):
    super(SAGEConvWithEdgeAttr, self).__init__(aggr=aggr)

    # W_1 
    self.w_self_embedding = torch.nn.Linear(in_channels, out_channels)

    # W_2
    self.w_aggregation = torch.nn.Linear(in_channels, out_channels)

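    # W_3: lifts the 1-dim edge attribute to the size of x_j (in_channels),
    # so that the sum x_j + W_3(edge_attr) in message() is defined
    self.w_edge_attr = torch.nn.Linear(1, in_channels)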

    self.comb = comb
    self.normalize = normalize

  def forward(self, x, edge_index, edge_attr, size=None):
    # message propagation
    aggregation = self.propagate(edge_index, x=(x,x), edge_attr=edge_attr, size=size)
    out = self.w_self_embedding(x) + self.w_aggregation(aggregation)

    if self.normalize:
      out = F.normalize(out)

    return out 

  def message(self, x_j, edge_attr):
    """ Cursome Message """
    return F.relu(self.comb(x_j, self.w_edge_attr(edge_attr)))


In [None]:
# Following is the GNN with custom layer defined above

class GraphSAGE(torch.nn.Module):

  def __init__(self, conv, in_channels, hidden_channels, out_channels, num_layers, dropout):
    # conv is the custom GNN layer class built in the preceding cell
    super(GraphSAGE, self).__init__()

    self.convs = torch.nn.ModuleList()
    # the model needs at least 2 GNN layers (the first and last layers are built separately)
    assert (num_layers >= 2), 'Have at least 2 layers'

    self.convs.append(conv(in_channels, hidden_channels, normalize=True))
    for _ in range(num_layers - 2):
      self.convs.append(conv(hidden_channels, hidden_channels, normalize=True))
    self.convs.append(conv(hidden_channels, out_channels, normalize=True))

    self.num_layers = num_layers
    self.dropout = dropout

  def forward(self, x, edge_index, edge_attr):
    for i in range(self.num_layers-1):
      # apply the custom layer to the layer i of gnn
      x = self.convs[i](x, edge_index, edge_attr)
      # pass through non-linearity
      x = F.relu(x)
      x = F.dropout(x,p=self.dropout, training=self.training)

    x = self.convs[self.num_layers-1](x, edge_index, edge_attr) # x.shape = [N,out_channels]

    return x 




In [None]:
# Link-prediction head specific to this task: scores a node pair (u, v) from the element-wise product of their embeddings

class LinkPredictor(torch.nn.Module):
  def __init__(self, in_channels,hidden_channels,dropout, out_channels=1,el_prod=lambda x,y: x*y):
    super(LinkPredictor,self).__init__()
    self.model = nn.Sequential(nn.Linear(in_channels,hidden_channels),
                               nn.ReLU(),
                               nn.Dropout(p=dropout),
                               nn.Linear(hidden_channels,out_channels),
                               nn.Sigmoid())
    
    self.el_prod = el_prod

  def forward(self,u,v):
    x = self.el_prod(u,v)
    return self.model(x)
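
A small sketch, with toy sizes and random embeddings as assumptions, of how a head shaped like `LinkPredictor` scores candidate edges $(u, v)$: the two node embeddings are combined with the element-wise (Hadamard) product, and the MLP ending in a sigmoid turns that into a probability. The binary cross-entropy with label 1 for positive edges and 0 for negative samples mirrors the loss used in `train` below (there, the negative samples come from `negative_sampling` with shape `[2, B]` rather than `[B, 2]`).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_nodes, dim = 6, 16
node_embeddings = torch.randn(num_nodes, dim)  # stand-in for the GraphSAGE output

# Same structure as LinkPredictor: Linear -> ReLU -> Dropout -> Linear -> Sigmoid
mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Dropout(p=0.5),
                    nn.Linear(dim, 1), nn.Sigmoid())
mlp.eval()  # disable dropout so the scores are deterministic

pos_edges = torch.tensor([[0, 1], [2, 3]])  # hypothetical observed edges, shape [B, 2]
neg_edges = torch.tensor([[0, 5], [4, 1]])  # hypothetical negative samples, shape [B, 2]


def score(edges):
    u, v = node_embeddings[edges[:, 0]], node_embeddings[edges[:, 1]]
    return mlp(u * v)  # Hadamard product, then MLP -> [B, 1] probabilities


preds = torch.cat([score(pos_edges), score(neg_edges)])
labels = torch.cat([torch.ones(len(pos_edges), 1), torch.zeros(len(neg_edges), 1)])
loss = F.binary_cross_entropy(preds, labels)
print(preds.squeeze(1), loss.item())
```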


In [None]:
graphsage_in_channels = 128
graphsage_hidden_channels = graphsage_in_channels
graphsage_out_channels = graphsage_hidden_channels
graphsage_num_layers = 2
link_predictor_in_channels = graphsage_out_channels
link_predictor_hidden_channels = link_predictor_in_channels
edge_attr_out_channels = graphsage_hidden_channels
dropout = 0.5

In [None]:
initial_node_embeddings = torch.nn.Embedding(ddi_graph.num_nodes, graphsage_in_channels).to(device)

graphsage_model = GraphSAGE(SAGEConvWithEdgeAttr, graphsage_in_channels,
                            graphsage_hidden_channels, graphsage_out_channels,
                            graphsage_num_layers,
                            dropout).to(device)

In [None]:
link_predictor = LinkPredictor(in_channels=link_predictor_in_channels,
                               hidden_channels=link_predictor_hidden_channels,
                               dropout=dropout).to(device)

In [None]:
def train(graphsage_model, link_predictor, initial_node_embeddings,edge_index,
          pos_train_edges, optimizer, batch_size, edge_attr):
  
  total_loss, total_examples = 0,0

  # put both modules in training mode (enables dropout)
  graphsage_model.train()
  link_predictor.train()

  for pos_samples in DataLoader(pos_train_edges, batch_size, shuffle=True):
    optimizer.zero_grad()

    # Recompute the node embeddings for every batch: optimizer.step() updates the
    # parameters in place, so a graph built once before the loop cannot be
    # back-propagated through again after the first step.
    node_embeddings = graphsage_model(initial_node_embeddings, edge_index, edge_attr)

    neg_samples = negative_sampling(edge_index,
                                    num_nodes=initial_node_embeddings.shape[0],
                                    num_neg_samples=len(pos_samples),
                                    method='dense')

    pos_preds = link_predictor(node_embeddings[pos_samples[:,0]],
                               node_embeddings[pos_samples[:,1]])

    neg_preds = link_predictor(node_embeddings[neg_samples[0]],
                               node_embeddings[neg_samples[1]])

    preds = torch.cat((pos_preds, neg_preds))
    labels = torch.cat((torch.ones_like(pos_preds),
                        torch.zeros_like(neg_preds)))

    loss = F.binary_cross_entropy(preds, labels)

    loss.backward()
    optimizer.step()

    num_examples = len(pos_preds)
    total_loss += loss.item() * num_examples
    total_examples += num_examples

  return total_loss/total_examples, node_embeddings 


In [None]:
#@title Training Parameters
lr = 0.001 #@param {type: 'number'}
batch_size= 65536 #@param {type:'number'}
epochs = 20 #@param {type: 'number'}
eval_steps = 5 #@param {type: 'number'}

In [None]:
optimizer = torch.optim.Adam(list(initial_node_embeddings.parameters()) + list(graphsage_model.parameters()) + list(link_predictor.parameters()), lr = lr)

In [None]:
pos_valid_edges = valid_edges['edge'].to(device)
neg_valid_edges = valid_edges['edge_neg'].to(device)

pos_test_edges = test_edges['edge'].to(device)
neg_test_edges = test_edges['edge_neg'].to(device)

In [None]:
from ogb.linkproppred import Evaluator

evaluator = Evaluator(name=dataset_name)

In [None]:
@torch.no_grad()
def test(graphsage_model, link_predictor,initial_node_embeddings,edge_index,
         pos_valid_edges, pos_test_edges,
         neg_valid_edges, neg_test_edges,
         batch_size, evaluator, edge_attr):
  graphsage_model.eval()
  link_predictor.eval()

  final_node_embeddings = graphsage_model(initial_node_embeddings, edge_index, edge_attr)

  pos_valid_preds = []

  for pos_samples in DataLoader(pos_valid_edges, batch_size):
    
    pos_preds = link_predictor(final_node_embeddings[pos_samples[:,0]],
                               final_node_embeddings[pos_samples[:,1]])
    
    pos_valid_preds.append(pos_preds.squeeze())

  pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

  neg_valid_preds = []
  for neg_samples in DataLoader(neg_valid_edges,batch_size):
    neg_preds = link_predictor(final_node_embeddings[neg_samples[:,0]],
                               final_node_embeddings[neg_samples[:,1]])
    
    neg_valid_preds.append(neg_preds.squeeze())
  
  neg_valid_pred = torch.cat(neg_valid_preds, dim=0)


  pos_test_preds = []

  for pos_samples in DataLoader(pos_test_edges, batch_size):
    pos_preds = link_predictor(final_node_embeddings[pos_samples[:,0]],
                               final_node_embeddings[pos_samples[:,1]])
    pos_test_preds.append(pos_preds.squeeze())

  pos_test_pred = torch.cat(pos_test_preds, dim=0)

  neg_test_preds = []

  for neg_samples in DataLoader(neg_test_edges, batch_size):
    neg_preds = link_predictor(final_node_embeddings[neg_samples[:,0]],
                               final_node_embeddings[neg_samples[:,1]])
    neg_test_preds.append(neg_preds.squeeze())

  neg_test_pred = torch.cat(neg_test_preds, dim = 0)

  # Calculate Hits@20 (problem specific)
  evaluator.K = 20
  valid_hits = evaluator.eval({'y_pred_pos':pos_valid_pred, 'y_pred_neg': neg_valid_pred})
  test_hits = evaluator.eval({'y_pred_pos': pos_test_pred, 'y_pred_neg': neg_test_pred})

  return valid_hits, test_hits
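
The OGB evaluator for `ogbl-ddi` reports Hits@20, which, as I understand OGB's evaluator, is the fraction of positive edges whose score is strictly higher than the 20th-highest score among the negative edges. Below is a minimal sketch of that computation with made-up scores and a small K; `hits_at_k` is a hypothetical helper for illustration only, and the real `Evaluator` should be used for reported numbers.

```python
import torch


def hits_at_k(y_pred_pos, y_pred_neg, k):
    """Fraction of positive scores strictly above the k-th highest negative score."""
    if len(y_pred_neg) < k:
        return 1.0
    kth_neg_score = torch.topk(y_pred_neg, k).values[-1]
    return (y_pred_pos > kth_neg_score).float().mean().item()


# Hypothetical scores for illustration
y_pred_pos = torch.tensor([0.9, 0.7, 0.4, 0.2])
y_pred_neg = torch.tensor([0.8, 0.5, 0.3, 0.1])

print(hits_at_k(y_pred_pos, y_pred_neg, k=2))  # 2nd-highest negative is 0.5; 2 of 4 positives beat it -> 0.5
```

Because only the K-th best negative matters, a single very high-scoring negative edge does not change the result, but raising the scores of many negatives does.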



In [None]:
import matplotlib.pyplot as plt


In [None]:
# run training and evaluation

epochs_bar = trange(1, epochs + 1, desc = 'Loss n/a')

edge_index = ddi_graph.edge_index.to(device)
pos_train_edges = train_edges['edge'].to(device)

losses = []
valid_hits_list = []
test_hits_list = []

for epoch in epochs_bar:
  loss, h = train(graphsage_model, link_predictor, initial_node_embeddings.weight,
               edge_index, pos_train_edges, optimizer, batch_size, edge_attr)
  
  losses.append(loss)
  epochs_bar.set_description(f"Loss {loss:0.4f}")

  if epoch % eval_steps == 0:
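    # pass the edge sets by keyword: test() expects pos_valid, pos_test, neg_valid, neg_test in that order
    valid_hits, test_hits = test(graphsage_model, link_predictor, initial_node_embeddings.weight, edge_index,
                                 pos_valid_edges=pos_valid_edges, pos_test_edges=pos_test_edges,
                                 neg_valid_edges=neg_valid_edges, neg_test_edges=neg_test_edges,
                                 batch_size=batch_size, evaluator=evaluator, edge_attr=edge_attr)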
    print()
    print(f'Epoch: {epoch}, Validation Hits@20: {valid_hits["hits@20"]:0.4f}, Test Hits@20: {test_hits["hits@20"]:0.4f}')
    # record the newly computed scores
    valid_hits_list.append(valid_hits["hits@20"])
    test_hits_list.append(test_hits["hits@20"])

  else:
    valid_hits_list.append(valid_hits_list[-1] if valid_hits_list else 0)
    test_hits_list.append(test_hits_list[-1] if test_hits_list else 0)

plt.title(dataset.name + ": GraphSAGE with edge attributes")
plt.xlabel("Epoch")
plt.plot(losses, label="Training loss")
plt.plot(valid_hits_list, label="Validation Hits@20")
plt.plot(test_hits_list, label="Test Hits@20")
plt.legend()
plt.show()
  

In [None]:
epochs_bar = trange(1, epochs+1, desc= 'Loss n/a')
edge_index = ddi_graph.edge_index.to(device)
pos_train_edges = train_edges['edge'].to(device)
for epoch in epochs_bar:
  loss, h = train(graphsage_model, link_predictor, initial_node_embeddings.weight,
               edge_index, pos_train_edges, optimizer, batch_size, edge_attr)