In [None]:
# default_exp models.layers.message_passing

# Message Passing
> Implementations of message-passing graph network layers, such as LightGCN and LR-GCCF.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import torch
from torch import Tensor
from torch import nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [None]:
#export
class LightGConv(MessagePassing):
    r"""Parameter-free, LightGCN-style graph convolution.

    For every edge ``(j -> i)`` in ``edge_index`` the message is the source
    features ``x_j`` scaled by ``deg(j)^{-1/2} * deg(i)^{-1/2}``, and messages
    are summed per target node. Degrees are counted on the target side
    (``edge_index[1]``). A node with degree 0 gets a weight of 0 instead of
    ``inf``. The layer has no weight matrix, bias, or non-linearity.
    """

    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x, edge_index):
        """Propagate ``x`` over ``edge_index`` with symmetric degree normalisation.

        Args:
            x: node feature matrix; row ``k`` holds the features of node ``k``.
            edge_index: two-row tensor of (source, target) node indices.

        Returns:
            Tensor with one row per node (same number of rows as ``x``).
        """
        src, dst = edge_index
        # 0 ** -0.5 evaluates to inf for zero-degree nodes; zero those weights out.
        inv_sqrt_deg = degree(dst, x.size(0), dtype=x.dtype).pow(-0.5)
        inv_sqrt_deg = inv_sqrt_deg.masked_fill(inv_sqrt_deg == float('inf'), 0)
        edge_weight = inv_sqrt_deg[src] * inv_sqrt_deg[dst]
        return self.propagate(edge_index, x=x, norm=edge_weight)

    def message(self, x_j, norm):
        # PyG passes `x_j` and `norm` by argument name, so these names must not change.
        # Scale each edge's source-node features by that edge's weight.
        return norm.unsqueeze(-1) * x_j

    def update(self, inputs: Tensor) -> Tensor:
        # Identity update: the aggregated messages are the layer output.
        return inputs

In [None]:
#export
class LRGCCF(MessagePassing):
    """Mean-aggregation graph convolution followed by a linear projection.

    Every node gets a self-loop. The node's embedding and its neighbours'
    embeddings are averaged, and the result is passed through
    ``torch.nn.Linear(in_channels, out_channels)``.

    NOTE(review): mean aggregation row-normalises the adjacency; confirm this
    matches the intended LR-GCCF propagation rule.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Average over each node's self-looped neighbourhood, then project.

        Args:
            x: node feature matrix with ``in_channels`` columns; row ``k`` is node ``k``.
            edge_index: two-row tensor of (source, target) node indices.

        Returns:
            Tensor of shape ``(x.size(0), out_channels)``.
        """
        looped_edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0]
        aggregated = self.propagate(looped_edge_index, x=x)
        return self.lin(aggregated)

    def message(self, x_j):
        # PyG passes `x_j` by argument name; each message is the raw source features.
        return x_j

    def update(self, inputs: Tensor) -> Tensor:
        # Identity update: the projection is applied in `forward`, after propagation.
        return inputs

In [None]:
import pandas as pd

# Toy explicit-feedback data: one (userId, itemId, rating) record per row.
train = pd.DataFrame.from_records(
    [
        (1, 1, 4),
        (1, 2, 5),
        (2, 1, 2),
        (2, 3, 5),
        (3, 2, 3),
        (4, 4, 2),
        (5, 5, 4),
    ],
    columns=['userId', 'itemId', 'rating'],
)

train

Unnamed: 0,userId,itemId,rating
0,1,1,4
1,1,2,5
2,2,1,2
3,2,3,5
4,3,2,3
5,4,4,2
6,5,5,4


In [None]:
from torch_geometric.data import Data

# Interactions rated strictly above this value become (positive) graph edges.
POSITIVE_RATING_THRESHOLD = 3
EMBEDDING_DIM = 5

positives = train[train['rating'] > POSITIVE_RATING_THRESHOLD]

# Users and items share one node-index space. Ids in `train` are 1-based:
# users map to 0..n_users-1 and items are shifted to n_users..n_users+n_items-1.
# Without the shift, user k and item k would collapse into the same node.
n_users = int(train['userId'].max())
n_items = int(train['itemId'].max())
edge_user = torch.tensor(positives['userId'].to_numpy() - 1)
edge_item = torch.tensor(positives['itemId'].to_numpy() - 1 + n_users)

# Store every interaction in both directions so the graph is undirected.
edge_ = torch.stack((
    torch.cat((edge_user, edge_item), 0),
    torch.cat((edge_item, edge_user), 0),
), 0)
# Some users/items have no positive edge; num_nodes keeps them in the graph.
data_p = Data(edge_index=edge_, num_nodes=n_users + n_items)

# One learnable embedding row per node. `torch.empty` alone returns
# uninitialised memory, so seed and draw the initial values explicitly.
torch.manual_seed(0)
E = nn.Parameter(torch.empty(n_users + n_items, EMBEDDING_DIM))
nn.init.normal_(E, std=0.1)

In [None]:
# Fix the global RNG state before building and running the layer.
torch.manual_seed(0)
lightgconv = LightGConv()
lightgconv(E, edge_index=data_p.edge_index)

tensor([[-1.1105e+34,  2.0593e-41,  7.2868e-44,  7.2868e-44,  7.4269e-44],
        [-6.8002e+33,  1.2643e-41,  8.2677e-44,  8.5479e-44,  7.7071e-44],
        [ 4.9045e-44,  4.9045e-44,  4.4842e-44,  4.9045e-44,  5.6052e-44],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 7.8473e-44,  6.7262e-44,  7.8473e-44,  7.8473e-44,  7.2868e-44]],
       grad_fn=<ScatterAddBackward0>)

In [None]:
# Re-seed so the Linear layer inside LRGCCF starts from the same weights each run.
torch.manual_seed(0)
lrgccf = LRGCCF(in_channels=5, out_channels=5)
lrgccf(E, edge_index=data_p.edge_index)

tensor([[ 4.1829e+31, -1.4982e+33,  1.6885e+33, -2.0696e+32, -2.0293e+33],
        [ 1.8590e+31, -6.6586e+32,  7.5042e+32, -9.1983e+31, -9.0190e+32],
        [ 4.7322e-02,  4.0494e-01, -4.1487e-01, -2.8154e-01, -1.1322e-01],
        [ 4.7322e-02,  4.0494e-01, -4.1487e-01, -2.8154e-01, -1.1322e-01],
        [ 4.7322e-02,  4.0494e-01, -4.1487e-01, -2.8154e-01, -1.1322e-01]],
       grad_fn=<AddmmBackward0>)

In [None]:
#hide
!pip install -q watermark
%reload_ext watermark
%watermark -a "Sparsh A." -m -iv -u -t -d -p torch_geometric

Author: Sparsh A.

Last updated: 2021-12-19 17:51:53

torch_geometric: 2.0.2

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.4.104+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit

numpy  : 1.19.5
IPython: 5.5.0
pandas : 1.1.5
torch  : 1.10.0+cu111

