# Graph Neural Networks
## What are Graph Neural Networks (GNNs)?

In [1]:
#import the basics
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch_geometric
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_sparse import SparseTensor
%matplotlib inline



In [2]:
torch_geometric.seed_everything(1234)
torch_geometric.__version__

'2.6.1'

In [3]:
# Let's verify what device we are working with
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("You are using device: %s" % device)

You are using device: cpu


Graph Neural Networks are a type of "geometric deep learning" models that use pairwise message passing. They typically have an architecture consisting of 3 types of layers. From [wikipedia](https://en.wikipedia.org/wiki/Graph_neural_network):
1. Permutation equivariant: a permutation equivariant layer maps a representation of a graph into an updated representation of the same graph. In the literature, permutation equivariant layers are implemented via **pairwise message passing between graph nodes**. Intuitively, in a message passing layer, nodes update their representations by aggregating the messages received from their immediate neighbours. As such, each message passing layer increases the receptive field of the GNN by one hop.
2. Local pooling: a local pooling layer coarsens the graph via downsampling. Local pooling is used to increase the receptive field of a GNN, in a similar fashion to pooling layers in convolutional neural networks. Examples include k-nearest neighbours pooling, top-k pooling, and self-attention pooling.
3. Global pooling: a global pooling layer, also known as readout layer, provides fixed-size representation of the whole graph. The global pooling layer must be permutation invariant, such that permutations in the ordering of graph nodes and edges do not alter the final output. Examples include element-wise sum, mean or maximum.

## Attributes
- [T]he preprocessing step first
“squashes” the graph structured data into a vector of reals and
then deals with the preprocessed data using a list-based data
processing technique. However, important information, e.g., the
topological dependency of information on each node may be
lost during the preprocessing stage and the final result may depend, in an unpredictable manner, on the details of the preprocessing algorith [1] **GNNS preserve the structure of the graph it is based on.**
- It will be shown that the GNN
is an extension of both recursive neural networks and random
walk models and that it retains their characteristics. The model
extends recursive neural networks since it can process a more
general class of graphs including cyclic, directed, and undirected graphs, and it can deal with node-focused applications
without any preprocessing steps. The approach extends random
walk theory by the introduction of a learning algorithm and by
enlarging the class of processes that can be modeled. [1]
- Weights are shared across layer structures

### What is message passing?
From [wikipedia](https://en.wikipedia.org/wiki/Graph_neural_network#Message_passing_layers):
<br>
![img](./img/notebook/messagePassing.png)

## Computation Graph
"The neighbour of a node defines its computation graph" - @12:34 https://www.youtube.com/watch?v=JtDgmmQ60x8&ab_channel=AntonioLonga



## Data

In [4]:
movie_path = './data/MovieLens/raw/ml-latest-small/movies.csv'
rating_path = './data/MovieLens/raw/ml-latest-small/ratings.csv'

In [5]:
def load_node_csv(path, index_col):
    """Loads csv containing node information
    Args:
        path (str): path to csv file
        index_col (str): column name of index column
    Returns:
        dict: mapping of csv row to node id
    """
    df = pd.read_csv(path, index_col=index_col)
    mapping = {index: i for i, index in enumerate(df.index.unique())}
    return mapping
user_mapping = load_node_csv(rating_path, index_col='userId')
movie_mapping = load_node_csv(movie_path, index_col='movieId')

In [6]:
print(f"user_mapping size: {len(user_mapping)}")
print(f"movie_mapping size: {len(movie_mapping)}")

user_mapping size: 610
movie_mapping size: 9742


In [7]:
def load_edge_csv(path, src_index_col, src_mapping, dst_index_col, dst_mapping, link_index_col, rating_threshold=4):
    """Loads csv containing edges between users and items

    Args:
        path (str): path to csv file
        src_index_col (str): column name of users
        src_mapping (dict): mapping between row number and user id
        dst_index_col (str): column name of items
        dst_mapping (dict): mapping between row number and item id
        link_index_col (str): column name of user item interaction
        rating_threshold (int, optional): Threshold to determine positivity of edge. Defaults to 4.

    Returns:
        torch.Tensor: 2 by N matrix containing the node ids of N user-item edges
    """
    df = pd.read_csv(path)
    edge_index = None
    src = [src_mapping[index] for index in df[src_index_col]]
    dst = [dst_mapping[index] for index in df[dst_index_col]]
    edge_attr = torch.from_numpy(df[link_index_col].values).view(-1, 1).to(torch.long) >= rating_threshold


    edge_index = [[], []]
    for i in range(edge_attr.shape[0]):
        if edge_attr[i]:
            edge_index[0].append(src[i])
            edge_index[1].append(dst[i])

    return torch.tensor(edge_index)


edge_index = load_edge_csv(
    rating_path,
    src_index_col='userId',
    src_mapping=user_mapping,
    dst_index_col='movieId',
    dst_mapping=movie_mapping,
    link_index_col='rating',
    rating_threshold=4,
)

In [8]:
edge_index[:,:5]

tensor([[ 0,  0,  0,  0,  0],
        [ 0,  2,  5, 43, 46]])

In [9]:
edge_index.size()

torch.Size([2, 48580])

In [10]:
from sklearn.model_selection import train_test_split

TOTAL_NUM_USERS, TOTAL_NUM_MOVIES = len(user_mapping), len(movie_mapping)
num_interactions = edge_index.shape[1]
all_indices = [i for i in range(num_interactions)]

train_indices, test_indices = train_test_split(
    all_indices, test_size=0.2, random_state=1)
val_indices, test_indices = train_test_split(
    test_indices, test_size=0.5, random_state=1)

train_edge_index = edge_index[:, train_indices]
val_edge_index = edge_index[:, val_indices]
test_edge_index = edge_index[:, test_indices]

In [11]:
train_edge_index.size()[1]/edge_index.size()[1]

0.8

In [12]:
val_edge_index.size()[1]/edge_index.size()[1]

0.1

In [13]:
test_edge_index.size()[1]/edge_index.size()[1]

0.1

In [14]:
train_sparse_edge_index = SparseTensor(row=train_edge_index[0], col=train_edge_index[1], sparse_sizes=(
    TOTAL_NUM_USERS + TOTAL_NUM_MOVIES, TOTAL_NUM_USERS + TOTAL_NUM_MOVIES))
val_sparse_edge_index = SparseTensor(row=val_edge_index[0], col=val_edge_index[1], sparse_sizes=(
    TOTAL_NUM_USERS + TOTAL_NUM_MOVIES, TOTAL_NUM_USERS + TOTAL_NUM_MOVIES))
test_sparse_edge_index = SparseTensor(row=test_edge_index[0], col=test_edge_index[1], sparse_sizes=(
    TOTAL_NUM_USERS + TOTAL_NUM_MOVIES, TOTAL_NUM_USERS + TOTAL_NUM_MOVIES))

In [15]:
torch.stack([train_sparse_edge_index.coo()[0], train_sparse_edge_index.coo()[1]], dim=0)

tensor([[   0,    0,    0,  ...,  609,  609,  609],
        [   0,    2,   43,  ..., 9461, 9462, 9463]])

In [16]:
from torch_geometric.data import HeteroData

train_data = HeteroData()
train_data['user'].num_nodes = TOTAL_NUM_USERS
train_data['movie'].num_nodes = TOTAL_NUM_MOVIES
train_data['user','rates','movie'].edge_index = torch.stack([train_sparse_edge_index.coo()[0], train_sparse_edge_index.coo()[1]], dim=0)

In [17]:
train_data

HeteroData(
  user={ num_nodes=610 },
  movie={ num_nodes=9742 },
  (user, rates, movie)={ edge_index=[2, 38864] }
)

In [18]:
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.sampler import NegativeSampling

train_loader = LinkNeighborLoader(
    train_data,
    batch_size=1000,
    num_neighbors=[30,20,10],
    edge_label_index=('user','rates','movie'),
    is_sorted=True,
    shuffle=True,
)

In [19]:
sampled_data = next(iter(train_loader))
sampled_data

HeteroData(
  user={
    num_nodes=374,
    n_id=[374],
  },
  movie={
    num_nodes=707,
    n_id=[707],
  },
  (user, rates, movie)={
    edge_index=[2, 57],
    e_id=[57],
    input_id=[1000],
    edge_label_index=[2, 1000],
  }
)

In [21]:
sampled_data['user','rates','movie']['edge_index']

tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0, 270,  34, 198, 368, 202,   5,
         305, 162, 318, 152, 305,  82, 301, 369,  42, 305, 194, 290, 370, 116,
          48, 193,  61, 278, 155, 224, 257, 290, 181, 371, 372, 373, 373, 373,
         373],
        [  0,   2,  16,  17,  24,  27,  32,  45,  63,  67,  70,  72,  75,  82,
          90, 100, 101, 102, 105, 107, 111, 113, 117, 117, 117, 117, 117, 117,
         117, 117, 117, 117, 117, 117, 117, 117, 117, 117, 117, 117, 117, 117,
         117, 117, 117, 117, 117, 117, 117, 117, 117, 117, 253, 615, 623, 633,
         680]])

In [23]:
sampled_data['user','rates','movie']['e_id']

tensor([    0,     1,     2,     3,     4,     5,     8,    10,    11,    12,
           13,    14,    15,    17,    19,    21,    22,    23,    24,    25,
           26,    27, 27041,  2576, 19317, 18873, 20033,   433, 31243, 15137,
        32472, 13850, 31378,  7780, 30913,  7720,  3345, 31244, 19060, 28956,
        22332, 10432,  3973, 18772,  5414, 27739, 14100, 21850, 24945, 29452,
        17250,  9855, 32788, 32834, 32835, 32836, 32839])

In [20]:
val_data = HeteroData()
val_data['user'].num_nodes = TOTAL_NUM_USERS
val_data['movie'].num_nodes = TOTAL_NUM_MOVIES
val_data['user','rates','movie'].edge_index = torch.stack([val_sparse_edge_index.coo()[0], val_sparse_edge_index.coo()[1]], dim=0)

val_loader = LinkNeighborLoader(
    val_data,
    batch_size=1000,
    num_neighbors=[30,20,10],
    edge_label_index=('user','rates','movie'),
    is_sorted=True,
    shuffle=True,
)

sampled_data = next(iter(train_loader))
sampled_data

HeteroData(
  user={
    num_nodes=363,
    n_id=[363],
  },
  movie={
    num_nodes=731,
    n_id=[731],
  },
  (user, rates, movie)={
    edge_index=[2, 86],
    e_id=[86],
    input_id=[1000],
    edge_label_index=[2, 1000],
  }
)

## Neural Graph Collaborative Filtering



In [21]:
from torch_geometric.nn import Linear
from torch.nn import Embedding, Parameter
from torch.nn import functional as F

class EmbeddingPropLayer(torch.nn.Module):
    def __init__(self, hidden_channels=128):
        super(EmbeddingPropLayer, self).__init__()

        self.W1 = Linear(hidden_channels, hidden_channels, bias=False)
        self.W2 = Linear(hidden_channels, hidden_channels, bias=False)

    def reset_parameters(self):
        self.W1.reset_parameters()
        self.W2.reset_parameters()

    def forward(self, E, E_final):
        message = self.message_aggregation(E)
        return message, torch.concat([E_final, message], dim=1)

    '''

    '''
    def message_construction(self, E, p_ui):
        return torch.mul(p_ui*E, p_ui*self.W2(E))

    '''

    '''
    def message_aggregation(self, E):
        p_ui = self._com
        m_ui = self.message_construction(E, p_ui)
        m_uu = p_ui*self.W1(E)
        return F.leaky_relu(m_uu + m_ui)

    def _compute_normalized_adjacency(self, ei):
        """Compute normalized adjacency matrix A' = D^(-1/2) * A * D^(-1/2)"""
        edge_index_norm = gcn_norm(
            ei,
            add_self_loops=self.add_self_loops
        )
        return edge_index_norm

class NGCF(torch.nn.Module):
    def __init__(self, num_of_users, num_of_movies, hidden_channels=128):
        super(NGCF, self).__init__()
        self.num_of_users = num_of_users
        self.num_of_movies = num_of_movies
        self.hidden_channels = hidden_channels

        self.user_emb_layer = Embedding(self.num_of_users, hidden_channels)
        self.movie_emb_layer = Embedding(self.num_of_movies, hidden_channels)

        self.embedding_prop_layer_1 = EmbeddingPropLayer(hidden_channels)
        self.embedding_prop_layer_2 = EmbeddingPropLayer(hidden_channels)
        self.embedding_prop_layer_3 = EmbeddingPropLayer(hidden_channels)

    def forward(self, fdata, debug=False):
        e_u_0 = self.user_emb_layer(fdata['user','rates','movie']['edge_label_index'][0])
        e_i_0 = self.movie_emb_layer(fdata['user','rates','movie']['edge_label_index'][1])
        E = torch.concat([e_u_0, e_i_0], dim=0)

        if debug: print(f"E size: {E.size()}")

        #assert E.size()[0] == self.num_of_edges and E.size()[1] == self.hidden_channels

        E_1, E_star = self.embedding_prop_layer_1(E, torch.empty_like(E)) #E_l -> [num_users+num_movies,
        E_2, E_star = self.embedding_prop_layer_2(E_1, E_star)
        E_3, E_star = self.embedding_prop_layer_2(E_2, E_star)

        #assert E_star.size()[0] == self.num_of_edges and E_star.size()[1] == self.hidden_channels*4

        split_point = len(fdata['user','rates','movie']['edge_label_index'][0])

        if debug:
            print(f"E_star size: {E_star.size()}")
            print(f"split point: {split_point}")

        e_u_star = E_star[:split_point]
        e_i_star = E_star[split_point:]

        if debug:
            print(f"E size {E.size()}")
            print(f"E star size {E_star.size()}")
            print(f"e_u star size {e_u_star.size()}")
            print(f"e_i star size {e_i_star.size()}")
            print(f"e_u 0 size {e_u_0.size()}")
            print(f"e_i 0 size {e_i_0.size()}")

        # users_emb_final, users_emb_0, items_emb_final, items_emb_0
        return e_u_star, e_u_0, e_i_star, e_i_0


In [22]:
def bpr_loss(users_emb_final, users_emb_0, pos_items_emb_final, pos_items_emb_0, neg_items_emb_final, neg_items_emb_0, lambda_val):
    """Bayesian Personalized Ranking Loss as described in https://arxiv.org/abs/1205.2618
    Args:
        users_emb_final (torch.Tensor): e_u_k
        users_emb_0 (torch.Tensor): e_u_0
        pos_items_emb_final (torch.Tensor): positive e_i_k
        pos_items_emb_0 (torch.Tensor): positive e_i_0
        neg_items_emb_final (torch.Tensor): negative e_i_k
        neg_items_emb_0 (torch.Tensor): negative e_i_0
        lambda_val (float): lambda value for regularization loss term

    Returns:
        torch.Tensor: scalar bpr loss value
    """
    reg_loss = lambda_val * (users_emb_0.norm(2).pow(2) +
                             pos_items_emb_0.norm(2).pow(2) +
                             neg_items_emb_0.norm(2).pow(2))

    # print(f"user emb final: {users_emb_final}")
    # print(f"pos_items_emb_final: {pos_items_emb_final}")

    pos_scores = torch.mul(users_emb_final, pos_items_emb_final)
    pos_scores = torch.sum(pos_scores, dim=-1)
    neg_scores = torch.mul(users_emb_final, neg_items_emb_final)
    neg_scores = torch.sum(neg_scores, dim=-1)
    loss = -torch.mean(torch.nn.functional.softplus(pos_scores - neg_scores)) + reg_loss

    # print(f"pos scores: {pos_scores}")
    # print(f"neg scores: {neg_scores}")
    # print(f"pos - neg: {pos_scores - neg_scores}")
    # print(f"reg_loss: {reg_loss}")
    # print(f"Loss: {loss}")

    return loss

In [23]:
def get_user_positive_items(edge_index):
    """Generates dictionary of positive items for each user

    Args:
        edge_index (torch.Tensor): 2 by N list of edges

    Returns:
        dict: dictionary of positive items for each user
    """
    user_pos_items = {}
    for i in range(edge_index.shape[1]):
        user = edge_index[0][i].item()
        item = edge_index[1][i].item()
        if user not in user_pos_items:
            user_pos_items[user] = []
        user_pos_items[user].append(item)
    return user_pos_items

In [24]:
# computes recall@K and precision@K
def RecallPrecision_ATk(groundTruth, r, k):
    """Computers recall @ k and precision @ k

    Args:
        groundTruth (list): list of lists containing highly rated items of each user
        r (list): list of lists indicating whether each top k item recommended to each user
            is a top k ground truth item or not
        k (intg): determines the top k items to compute precision and recall on

    Returns:
        tuple: recall @ k, precision @ k
    """
    num_correct_pred = torch.sum(r, dim=-1)  # number of correctly predicted items per user
    # number of items liked by each user in the test set
    user_num_liked = torch.Tensor([len(groundTruth[i])
                                  for i in range(len(groundTruth))])
    recall = torch.mean(num_correct_pred / user_num_liked)
    precision = torch.mean(num_correct_pred) / k
    return recall.item(), precision.item()

In [25]:
# computes NDCG@K
def NDCGatK_r(groundTruth, r, k):
    """Computes Normalized Discounted Cumulative Gain (NDCG) @ k

    Args:
        groundTruth (list): list of lists containing highly rated items of each user
        r (list): list of lists indicating whether each top k item recommended to each user
            is a top k ground truth item or not
        k (int): determines the top k items to compute ndcg on

    Returns:
        float: ndcg @ k
    """
    assert len(r) == len(groundTruth)

    test_matrix = torch.zeros((len(r), k))

    for i, items in enumerate(groundTruth):
        length = min(len(items), k)
        test_matrix[i, :length] = 1
    max_r = test_matrix
    idcg = torch.sum(max_r * 1. / torch.log2(torch.arange(2, k + 2)), axis=1)
    dcg = r * (1. / torch.log2(torch.arange(2, k + 2)))
    dcg = torch.sum(dcg, axis=1)
    idcg[idcg == 0.] = 1.
    ndcg = dcg / idcg
    ndcg[torch.isnan(ndcg)] = 0.
    return torch.mean(ndcg).item()

In [26]:
def get_metrics(model, edge_index, exclude_edge_indices, k):
    """Computes the evaluation metrics: recall, precision, and ndcg @ k

    Args:
        model (LighGCN): lightgcn model
        edge_index (torch.Tensor): 2 by N list of edges for split to evaluate
        exclude_edge_indices ([type]): 2 by N list of edges for split to discount from evaluation
        k (int): determines the top k items to compute metrics on

    Returns:
        tuple: recall @ k, precision @ k, ndcg @ k
    """
    user_embedding = model.user_emb_layer.weight
    item_embedding = model.movie_emb_layer.weight

    # get ratings between every user and item - shape is num users x num movies
    rating = torch.matmul(user_embedding, item_embedding.T)

    for exclude_edge_index in exclude_edge_indices:
        user_pos_items = get_user_positive_items(exclude_edge_index)
        exclude_users = []
        exclude_items = []
        for user, items in user_pos_items.items():
            exclude_users.extend([user] * len(items))
            exclude_items.extend(items)

        rating[exclude_users, exclude_items] = -(1 << 10)

    _, top_K_items = torch.topk(rating, k=k)

    users = edge_index[0].unique()

    test_user_pos_items = get_user_positive_items(edge_index)

    test_user_pos_items_list = [
        test_user_pos_items[user.item()] for user in users]

    r = []
    for user in users:
        ground_truth_items = test_user_pos_items[user.item()]
        label = list(map(lambda x: x in ground_truth_items, top_K_items[user]))
        r.append(label)
    r = torch.Tensor(np.array(r).astype('float'))

    recall, precision = RecallPrecision_ATk(test_user_pos_items_list, r, k)
    ndcg = NDCGatK_r(test_user_pos_items_list, r, k)

    return recall, precision, ndcg

In [27]:
def evaluation(model, edge_index, sparse_edge_index, exclude_edge_indices, k, lambda_val):
    """Evaluates model loss and metrics including recall, precision, ndcg @ k

    Args:
        model (LighGCN): lightgcn model
        edge_index (torch.Tensor): 2 by N list of edges for split to evaluate
        sparse_edge_index (sparseTensor): sparse adjacency matrix for split to evaluate
        exclude_edge_indices ([type]): 2 by N list of edges for split to discount from evaluation
        k (int): determines the top k items to compute metrics on
        lambda_val (float): determines lambda for bpr loss

    Returns:
        tuple: bpr loss, recall @ k, precision @ k, ndcg @ k
    """

    users_emb_final, users_emb_0, items_emb_final, items_emb_0 = model.forward(sparse_edge_index, debug=False)
    edges = structured_negative_sampling(edge_index, contains_neg_self_loops=False)

    user_indices, pos_item_indices, neg_item_indices = edges[0], edges[1], edges[2]
    users_emb_final, users_emb_0 = users_emb_final[user_indices], users_emb_0[user_indices]

    pos_items_emb_final, pos_items_emb_0 = items_emb_final[
        pos_item_indices], items_emb_0[pos_item_indices]
    neg_items_emb_final, neg_items_emb_0 = items_emb_final[
        neg_item_indices], items_emb_0[neg_item_indices]

    loss = bpr_loss(users_emb_final, users_emb_0, pos_items_emb_final, pos_items_emb_0,
                    neg_items_emb_final, neg_items_emb_0, lambda_val).item()

    recall, precision, ndcg = get_metrics(
        model, edge_index, exclude_edge_indices, k)

    return loss, recall, precision, ndcg

In [28]:
ITERATIONS = 100
BATCH_SIZE = 200
LR = 1e-3
ITERS_PER_EVAL = 200
ITERS_PER_LR_DECAY = 200
K = 20
LAMBDA = 1e-6

In [29]:
model = NGCF(TOTAL_NUM_USERS, TOTAL_NUM_MOVIES, hidden_channels=128)
model.train()
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

print(model)

NGCF(
  (user_emb_layer): Embedding(610, 128)
  (movie_emb_layer): Embedding(9742, 128)
  (embedding_prop_layer_1): EmbeddingPropLayer(
    (W1): Linear(128, 128, bias=False)
    (W2): Linear(128, 128, bias=False)
  )
  (embedding_prop_layer_2): EmbeddingPropLayer(
    (W1): Linear(128, 128, bias=False)
    (W2): Linear(128, 128, bias=False)
  )
  (embedding_prop_layer_3): EmbeddingPropLayer(
    (W1): Linear(128, 128, bias=False)
    (W2): Linear(128, 128, bias=False)
  )
)


In [30]:
next(iter(train_loader))

HeteroData(
  user={
    num_nodes=353,
    n_id=[353],
  },
  movie={
    num_nodes=725,
    n_id=[725],
  },
  (user, rates, movie)={
    edge_index=[2, 85],
    e_id=[85],
    input_id=[1000],
    edge_label_index=[2, 1000],
  }
)

In [31]:
print(next(iter(train_loader))['user','rates','movie']['edge_label_index'])

tensor([[ 54, 248,  98,  ..., 257,  36, 248],
        [629, 159, 483,  ..., 429,  48, 501]])


In [32]:
train_losses = []
val_losses = []

from tqdm import tqdm
from torch_geometric.utils import structured_negative_sampling

torch.autograd.set_detect_anomaly(True)

for step in tqdm(range(ITERATIONS)):
    train_data = next(iter(train_loader))
    users_emb_final, users_emb_0, items_emb_final, items_emb_0 = model.forward(
       train_data, debug=False)

    print(f"user emb final: {users_emb_final}")

    # print(f"Sampling started ...")
    sampling = structured_negative_sampling(train_data['user','rates','movie']['edge_label_index'])
    # print(f"Sampling completed...")

    user_indices = sampling[0]
    pos_item_indices = sampling[1]
    neg_item_indices = sampling[2]

    # print(f"user_indices size: {user_indices.size()}")
    # print(f"items_emb_final size: {items_emb_final.size()}")
    # print(f"pos_item_indices size: {pos_item_indices.size()}")
    # print(f"neg_item_indices size: {neg_item_indices.size()}")

    users_emb_final, users_emb_0 = users_emb_final[user_indices], users_emb_0[user_indices]
    pos_items_emb_final, pos_items_emb_0 = items_emb_final[pos_item_indices], items_emb_0[pos_item_indices]
    neg_items_emb_final, neg_items_emb_0 = items_emb_final[neg_item_indices], items_emb_0[neg_item_indices]

    train_loss = bpr_loss(users_emb_final, users_emb_0, pos_items_emb_final,
                          pos_items_emb_0, neg_items_emb_final, neg_items_emb_0, LAMBDA)

    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    if step % ITERS_PER_EVAL == 0:
        model.eval()
        val_data = next(iter(val_loader))
        val_loss, recall, precision, ndcg = evaluation(
            # model, edge_index, sparse_edge_index, exclude_edge_indices, k, lambda_val
            model, val_data['user','rates','movie']['edge_label_index'], val_data, [train_edge_index], K, LAMBDA)
        print(f"[Iteration {step}/{ITERATIONS}] train_loss: {round(train_loss.item(), 5)}, val_loss: {round(val_loss, 5)}, val_recall@{K}: {round(recall, 5)}, val_precision@{K}: {round(precision, 5)}, val_ndcg@{K}: {round(ndcg, 5)}")
        train_losses.append(train_loss.item())
        val_losses.append(val_loss)
        model.train()

    # if step % ITERS_PER_LR_DECAY == 0 and step != 0:
    #    scheduler.step()

  0%|          | 0/100 [00:00<?, ?it/s]

user emb final: tensor([[ 1.1710e-19,  1.3563e-19,  4.6135e+24,  ..., -5.5334e-04,
         -1.7759e-03,  1.2467e-01],
        [ 4.6135e+24,  4.1723e-08,  1.7117e-10,  ...,  3.0934e-02,
         -3.1482e-04,  6.3413e-02],
        [ 1.7118e-10,  9.1678e-33,  1.3563e-19,  ..., -2.7094e-03,
          2.5165e-01, -4.8617e-04],
        ...,
        [ 1.7117e-10,  9.1671e-33,  1.3563e-19,  ..., -8.5050e-05,
         -6.3595e-03,  1.1774e-01],
        [ 1.1060e-02,  2.2310e-01,  1.3563e-19,  ...,  8.0087e-02,
         -1.1108e-04,  9.6172e-02],
        [ 1.0741e-05,  1.8018e+22,  1.3556e-19,  ...,  1.3924e-02,
         -1.7881e-03,  1.1394e-01]], grad_fn=<SliceBackward0>)


  File "/Users/robmayo/miniconda3/envs/gnn-ass-env-cpu/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/robmayo/miniconda3/envs/gnn-ass-env-cpu/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/robmayo/miniconda3/envs/gnn-ass-env-cpu/lib/python3.8/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/robmayo/miniconda3/envs/gnn-ass-env-cpu/lib/python3.8/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/robmayo/miniconda3/envs/gnn-ass-env-cpu/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/Users/robmayo/miniconda3/envs/gnn-ass-env-cpu/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/Users/robmayo/miniconda3/envs/gnn-ass-env-cpu/lib/python3.8/

RuntimeError: Function 'SoftplusBackward0' returned nan values in its 0th output.