In [107]:
!pip install torch_geometric



In [108]:
from torch_geometric.nn.models.lightgcn import LightGCN
import pandas as pd
import os
from tqdm import tqdm
import torch
import numpy as np

## Load Data
We can begin by loading in the user review data. For each user, we have a subset of the movies that they reviewed. We'll load each of the CSVs as dataframes, and store a dict of user IDs corresponding to their dataframes.

In [109]:
# for now we will use the first 10k rows of the data, set to None to use all data
AMOUNT_TO_LOAD = None

In [110]:
user_reviews_dir = 'user_reviews'
user_review_data = dict()

for filename in tqdm(os.listdir(user_reviews_dir)):
    if AMOUNT_TO_LOAD is not None and len(user_review_data) >= AMOUNT_TO_LOAD:
        break
    try:
        user_review_data[filename] = pd.read_csv(os.path.join(user_reviews_dir, filename), encoding='unicode_escape')
    except pd.errors.EmptyDataError:
        print(f'Empty file: {filename}')
        pass

  1%|          | 431/63111 [00:00<01:41, 615.74it/s]

Empty file: 468889434_reviews.csv


  4%|▍         | 2696/63111 [00:04<01:39, 606.82it/s]

Empty file: alinetta_reviews.csv


  9%|▊         | 5520/63111 [00:08<01:29, 644.34it/s]

Empty file: austinsiemens_reviews.csv


 18%|█▊        | 11183/63111 [00:18<01:15, 685.78it/s]

Empty file: chisvy_reviews.csv


 20%|██        | 12851/63111 [00:20<01:27, 576.57it/s]

Empty file: critics_said_reviews.csv


 23%|██▎       | 14460/63111 [00:28<03:55, 206.69it/s]

Empty file: demeguajara_reviews.csv


 25%|██▍       | 15684/63111 [00:34<03:59, 198.37it/s]

Empty file: dragospal_reviews.csv


 25%|██▌       | 15885/63111 [00:36<04:17, 183.26it/s]

Empty file: ds612_reviews.csv


 27%|██▋       | 16975/63111 [00:41<03:38, 211.10it/s]

Empty file: elinesophie_reviews.csv


 36%|███▌      | 22608/63111 [01:10<03:19, 203.28it/s]

Empty file: gypsydoll_reviews.csv


 40%|███▉      | 25018/63111 [01:22<03:05, 205.31it/s]

Empty file: iloveflorida1_reviews.csv


 41%|████      | 25682/63111 [01:25<02:51, 217.82it/s]

Empty file: iskndrz_reviews.csv


 47%|████▋     | 29577/63111 [01:45<02:34, 216.68it/s]

Empty file: juju108_reviews.csv


 54%|█████▍    | 33928/63111 [02:07<02:16, 214.24it/s]

Empty file: lilacat_reviews.csv


 55%|█████▍    | 34587/63111 [02:10<02:17, 208.19it/s]

Empty file: lolahowls911_reviews.csv


 61%|██████    | 38499/63111 [02:30<02:03, 198.71it/s]

Empty file: mediocremedia_reviews.csv


 68%|██████▊   | 42901/63111 [02:52<01:40, 200.56it/s]

Empty file: nicteeee_reviews.csv


 82%|████████▏ | 51632/63111 [03:36<00:57, 198.63it/s]

Empty file: sebastian823_reviews.csv


 94%|█████████▍| 59422/63111 [04:15<00:18, 200.01it/s]

Empty file: veera0304_reviews.csv


100%|██████████| 63111/63111 [04:35<00:00, 229.49it/s]


Now let's split the data into training, validation, and test sets. Since this is a recommender, we're gonna split by removing some of the user's reviews.

For every user, so long as the user has more than 5 reviews, remove one review for the validation set and one review for the test set.

In [111]:
print(list(user_review_data.keys())[0])

0001kidd_reviews.csv


In [112]:
# remove all values with nan in the review column
for key in tqdm(user_review_data.keys()):
    user_review_data[key] = user_review_data[key].dropna(subset=['movie_rating'])

100%|██████████| 63092/63092 [00:49<00:00, 1270.19it/s]


In [113]:
train_reviews = []
validation_reviews = []
test_reviews = []
for user_id, reviews in tqdm(user_review_data.items()):
    if len(reviews) > 50:
        validation_review_data_df = reviews.sample(15, replace=False)
        validation_review_data = validation_review_data_df.to_dict('records')
        for review in validation_review_data:
            review['user_id'] = user_id
        validation_reviews.extend(validation_review_data)
        # remove the validation reviews from the training data
        reviews = reviews.drop(validation_review_data_df.index)
        test_review_data_df = reviews.sample(15, replace=False)
        test_review_data = test_review_data_df.to_dict('records')
        for review in test_review_data:
            review['user_id'] = user_id
        test_reviews.extend(test_review_data)
        # remove the test reviews from the training data
        reviews = reviews.drop(test_review_data_df.index)
        train_review_data = reviews.to_dict('records')
        for review in train_review_data:
            review['user_id'] = user_id
        train_reviews.extend(train_review_data)
    else:
        # if the user has less than 5 reviews, we will use all of them for training
        train_review_data = reviews.to_dict('records')
        for review in train_review_data:
            review['user_id'] = user_id
        train_reviews.extend(train_review_data)

print(f'Train reviews: {len(train_reviews)}')
print(f'Validation reviews: {len(validation_reviews)}')
print(f'Test reviews: {len(test_reviews)}')

100%|██████████| 63092/63092 [02:56<00:00, 358.45it/s]

Train reviews: 25466161
Validation reviews: 751305
Test reviews: 751305





## Build the Model
Now that we have the training data, let's construct the model to train.

In [114]:
num_train_users = len(set([review['user_id'] for review in train_reviews]))
num_train_items = len(set([review['movie_id'] for review in train_reviews]))
num_total_items = len(set([review['movie_id'] for review in train_reviews + validation_reviews + test_reviews]))
num_nodes = num_train_users + num_total_items
print(f'Number of train users: {num_train_users}')
print(f'Number of train items: {num_train_items}')
print(f'Number of nodes: {num_nodes}')

Number of train users: 63087
Number of train items: 288276
Number of nodes: 353467


In [115]:
num_val_users = len(set([review['user_id'] for review in validation_reviews]))
num_val_items = len(set([review['movie_id'] for review in validation_reviews]))
num_val_nodes = num_val_users + num_val_items

In [116]:
# Let's map users to ids
movie_id_to_movie_name = dict()
for review in train_reviews + validation_reviews + test_reviews:
    movie_id_to_movie_name[review['movie_id']] = review['movie_title']

user_to_id = dict()
for i, user_id in enumerate(set([review['user_id'] for review in train_reviews + validation_reviews + test_reviews])):
    user_to_id[user_id] = i

# Let's map movies to ids
movie_to_id = dict()
for i, movie_id in enumerate(set([review['movie_id'] for review in train_reviews + validation_reviews + test_reviews])):
    movie_to_id[movie_id] = i + num_train_users

# Let's map ids to users
id_to_user = dict()
for user_id, index in user_to_id.items():
    id_to_user[index] = user_id

# Let's map ids to movies
id_to_movie = dict()
for movie_id, index in movie_to_id.items():
    id_to_movie[index] = movie_id

# Let's map movie names to movie ids
movie_name_to_movie_id = dict()
for movie_id, movie_name in movie_id_to_movie_name.items():
    movie_name_to_movie_id[movie_name] = movie_id

In [117]:
import random

def convert_review_to_edge(review):
    user_id = user_to_id[review['user_id']]
    movie_id = movie_to_id[review['movie_id']]
    edge_weight = review['movie_rating']
    if (edge_weight < 3.5 and edge_weight > 2.5):
        return None, None
    edge = (user_id, movie_id)
    edge_weight = review['movie_rating']
    return edge, edge_weight

def shuffle_edges_and_edge_weights(edges, edge_weights):
    c = list(zip(edges, edge_weights))
    random.shuffle(c)
    return zip(*c)

def convert_reviews_to_edges(reviews):
    edges = []
    edge_weights = []
    for review in tqdm(reviews):
        edge, edge_weight = convert_review_to_edge(review)
        if edge is not None:
            edges.append(edge)
            edge_weights.append(edge_weight)
    
    # Reformat the edges to be a tensor
    edges = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return edges, edge_weights

In [118]:
# Now let's create the edges between users and movies.
# The id of the user will be the index of the user in the user_to_id dict
# The id of the movie will be the index of the movie in the movie_to_id dict + the number of users

train_edges, train_edge_weights = convert_reviews_to_edges(train_reviews)
validation_edges, validation_edge_weights = convert_reviews_to_edges(validation_reviews)

print(f'Train edges: {train_edges.shape[1]}')
print(f'Validation edges: {validation_edges.shape[1]}')

100%|██████████| 25466161/25466161 [00:23<00:00, 1090616.23it/s]
100%|██████████| 751305/751305 [00:00<00:00, 1000828.49it/s]


Train edges: 20418208
Validation edges: 621225


In [119]:
import torch_geometric.data as data

# create the graph
train_graph = data.Data(
    edge_index=train_edges,
    edge_attr=torch.tensor(train_edge_weights),
    num_nodes=num_nodes
)

validation_graph = data.Data(
    edge_index=validation_edges,
    edge_attr=torch.tensor(validation_edge_weights),
    num_nodes=num_nodes
)

In [120]:
train_graph.validate(raise_on_error=True)
validation_graph.validate(raise_on_error=True)

True

In [121]:
# Let's create some negative edges
def resample_edges_for_user(user_positive_edges, user_negative_edges):
    num_negative_edges_to_add = user_positive_edges.shape[1] * 3 - user_negative_edges.shape[1]
    if (num_negative_edges_to_add <= 0):
        num_negative_edges_to_remove = -num_negative_edges_to_add
        # choose the negative edges to keep
        negative_edges_to_keep = torch.randint(user_negative_edges.shape[1], (user_negative_edges.shape[1] - num_negative_edges_to_remove,))
        # remove all the negative edges for this user
        user_negative_edges = user_negative_edges[:, negative_edges_to_keep]
    else:
        # Create new negative edges
        negative_edges_to_add = torch.tensor([[user_id] * num_negative_edges_to_add, torch.randint(num_train_users, num_train_items, (num_negative_edges_to_add,))], dtype=torch.long)
        # Add the negative edges to the negative edges for this user
        user_negative_edges = torch.cat([user_negative_edges, negative_edges_to_add], dim=1)
    return user_positive_edges, user_negative_edges
        

In [122]:
# let's compute ndcg
def compute_ndcg_at_k(relevances, k=5):
    relevances = relevances[:k]
    dcg = 0
    for i, relevance in enumerate(relevances):
        dcg += (2 ** relevance - 1) / np.log2(i + 2)
    idcg = 0
    for i, relevance in enumerate(sorted(relevances, reverse=True)):
        idcg += (2 ** relevance - 1) / np.log2(i + 2)
    return dcg / idcg

In [123]:
def get_user_positive_items(edge_index):
    """Generates dictionary of positive items for each user

    Args:
        edge_index (torch.Tensor): 2 by N list of edges

    Returns:
        dict: dictionary of positive items for each user
    """
    user_pos_items = {}
    for i in range(edge_index.shape[1]):
        user = edge_index[0][i].item()
        item = edge_index[1][i].item()
        if user not in user_pos_items:
            user_pos_items[user] = []
        user_pos_items[user].append(item)
    return user_pos_items

In [124]:
import time
def compute_recall_at_k(validation_graph, model, K):
    # get positive edges in validation set
    positive_edges = validation_graph.edge_index[:, validation_graph.edge_attr > 3.5]

    # map users to positive edges
    user_pos_items = get_user_positive_items(positive_edges)

    # get users
    users = positive_edges[0].unique()

    users = users[torch.randint(users.shape[0], (min(200, len(users)),))]
    # filter the validation edges to only the users we want to evaluate
    user_validation_edges = []
    for user in users:
        user_validation_edges.append(validation_graph.edge_index[:, validation_graph.edge_index[0] == user])
    user_validation_edges = torch.cat(user_validation_edges, dim=1)
    print(user_validation_edges.shape)

    first_user_id = users[0].item()
    user_name = id_to_user[first_user_id]
    print(f'User: {user_name}')

    # get movies
    movie_indices = torch.LongTensor([_ for _ in range(len(users) + 1, validation_graph.num_nodes)]).to(device)

    # Get positive items for each user in validation set
    truth_items = [set(user_pos_items[user.item()]) for user in users]

    first_user_truth_items = truth_items[0]
    first_user_truth_items = [id_to_movie[item] for item in first_user_truth_items]
    first_user_truth_items = [movie_id_to_movie_name[item] for item in first_user_truth_items]
    print(first_user_truth_items)

    training_edges = train_graph.edge_index

    # Get top-K recommended items for each user in validation set
    total_recall = 0
    print("Computing recommendations for {} users".format(len(users)))
    for user_index, user_id in tqdm(enumerate(users), total=len(users)):
        tick = time.time()
        all_edges = torch.tensor([(user_id, item_id) for item_id in range(num_train_users, num_train_items)], dtype=torch.long).t().contiguous()
        recommendations = model.recommend(all_edges.to(device), src_index=torch.tensor([user_id]), dst_index=torch.tensor([x for x in range(num_train_users + 1, num_train_items)]), k=3 * K)[0]
        tock = time.time()
        train_edges_for_user = training_edges[:, training_edges[0] == user_id]
        # remove all the recommendations that are in the training set
        recommendations = recommendations[~torch.isin(recommendations, train_edges_for_user[1])][:K]
        if (len(recommendations) < K):
            print("Not enough recommendations for user {}".format(user_id))
        if (user_id == first_user_id):
            first_user_recommended_items = recommendations
            first_user_recommended_items = [id_to_movie[item.item()] for item in first_user_recommended_items if item.item() > num_train_users]
            first_user_recommended_items = [movie_id_to_movie_name[item] for item in first_user_recommended_items if item in movie_id_to_movie_name]
            print(first_user_recommended_items)
        # num_intersect = 0
        truth_items_for_user = truth_items[user_index]
        # for item in recommendations:
        #     item = item.item()
        #     if item in truth_items_for_user:
        #         num_intersect += 1
        # print(num_intersect)
        num_intersect = len(set([item.item() for item in recommendations]).intersection(truth_items[user_index]))
        recall = num_intersect / len(truth_items_for_user)
        total_recall += recall
    return total_recall / len(users)



In [125]:
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Embedding, ModuleList
from torch.nn.modules.loss import _Loss

from torch_geometric.nn.conv import LGConv
from torch_geometric.typing import Adj, OptTensor, SparseTensor

In [126]:
"""Adapted from https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/lightgcn.html"""
class CustomLightGCN(torch.nn.Module):
    """From the <https://arxiv.org/abs/2002.02126>` paper.

    Args:
        num_nodes (int): The number of nodes in the graph.
        embedding_dim (int): The dimensionality of node embeddings.
        num_layers (int): The number of layers.
    """
    def __init__(
        self,
        num_nodes: int,
        embedding_dim: int,
        num_layers: int
    ):
        super().__init__()

        self.num_nodes = num_nodes
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.embedding = Embedding(num_nodes, embedding_dim)
        self.alpha = torch.tensor([1. / (num_layers + 1)] * (num_layers + 1))
        self.convs = ModuleList([GATConv(embedding_dim, embedding_dim, heads=8, dropout=0.6) for _ in range(num_layers)])
        self.linears = ModuleList([Linear(embedding_dim * 8, embedding_dim) for _ in range(num_layers)])
        torch.nn.init.xavier_uniform_(self.embedding.weight)

    def get_embedding(self, edge_index):
        x = self.embedding.weight
        out = x * self.alpha[0]

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = self.linears[i](x.view(-1, self.embedding_dim * 8))
            out = out + x * self.alpha[i + 1]

        return out


    def forward(self, edge_index):
        edge_label_index = edge_index
        out = self.get_embedding(edge_index)
        user = out[edge_label_index[0]]
        movie = out[edge_label_index[1]]
        return (user * movie).sum(dim=-1)


    def predict_link(self, edge_index, edge_label_index):
        "Predict links between nodes specified in edge_label_index."""
        pred = self(edge_index, edge_label_index).sigmoid()
        return pred.round()


    def recommend(self, edge_index, k):
        """Get top-k recommendations for nodes in src_index."""
        out_user = self.get_embedding(edge_index)
        out_movie = self.get_embedding(edge_index)
        pred = out_user @ out_movie.t()
        top_index = pred.topk(k, dim=-1).indices
        return top_index


    def link_pred_loss(self, pred, edge_label):
        """Computes the model loss for a link prediction using torch.nn.BCEWithLogitsLoss.
        
        Args:
            pred (torch.Tensor): The predictions.
            edge_label (torch.Tensor): The ground-truth edge labels.
        """
        loss_fn = torch.nn.BCEWithLogitsLoss()
        return loss_fn(pred, edge_label.to(pred.dtype))


    def recommendation_loss(self, pos_edge_rank, neg_edge_rank,
                            lambda_reg: float = 1e-4):
        """Computes the model loss for a ranking objective via the Bayesian
        Personalized Ranking (BPR) loss.

        Args:
            pos_edge_rank (torch.Tensor): Positive edge rankings.
            neg_edge_rank (torch.Tensor): Negative edge rankings.
            lambda_reg (int, optional): The L2 regularization strength
                of the Bayesian Personalized Ranking (BPR) loss.
        """
        loss_fn = BPRLoss(lambda_reg)
        return loss_fn(pos_edge_rank, neg_edge_rank, self.embedding.weight)

In [127]:
""" This is verbatim from https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/lightgcn.html. """
class BPRLoss(_Loss):
    """The Bayesian Personalized Ranking (BPR) loss."""
    __constants__ = ['lambda_reg']
    lambda_reg: float

    def __init__(self, lambda_reg: float = 0, **kwargs):
        super().__init__(None, None, "sum", **kwargs)
        self.lambda_reg = 0

    def forward(self, positives: Tensor, negatives: Tensor,
                parameters: Tensor = None) -> Tensor:
        """Compute the mean Bayesian Personalized Ranking (BPR) loss.

        Args:
            positives (Tensor): The vector of positive-pair rankings.
            negatives (Tensor): The vector of negative-pair rankings.
            parameters (Tensor, optional): The tensor of parameters which
                should be used for :math:`L_2` regularization
                (default: :obj:`None`).
        """
        n_pairs = positives.size(0)
        log_prob = F.logsigmoid(positives - negatives).mean()
        regularization = 0

        if self.lambda_reg != 0:
            regularization = self.lambda_reg * parameters.norm(p=2).pow(2)

        return (-log_prob + regularization) / n_pairs

In [129]:
import numpy as np
import math
import matplotlib.pyplot as plt

NUM_LAYERS = 1
LR = 5e-5
BATCH_SIZE = min(128, len(user_review_data))
EMBEDDING_DIM = 64
K = 10
model = LightGCN(num_nodes=num_nodes, embedding_dim=EMBEDDING_DIM, num_layers=NUM_LAYERS)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print("Running on device: {}".format(device))
print(EMBEDDING_DIM)

optim = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.95)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, milestones=[100, 200, 300, 400], gamma=0.5)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim, T_0=100)

train_positive_edges = train_graph.edge_index[:, train_graph.edge_attr >= 3.5]
train_negative_edges = train_graph.edge_index[:, train_graph.edge_attr <= 2.5]

validation_df = pd.DataFrame.from_dict(validation_reviews)
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(comment=f'LightGCN_{EMBEDDING_DIM}_layers_{NUM_LAYERS}_batch_size_{BATCH_SIZE}_lr_{LR}_num_train_users_{num_train_users}_num_train_items_{num_train_items}_recall_{K}')

for epoch in range(10001):
    # we are using BPR so we go by user
    average_loss = 0
    # We'll proceed in batches of users
    for start_idx in tqdm(range(0, num_train_users, BATCH_SIZE)):
        model.train()
        loss = torch.tensor(0.0, requires_grad=True)
        # randomly select a batch of users
        users_in_batch = torch.randperm(num_train_users)[start_idx:start_idx + BATCH_SIZE]
        for user_id in users_in_batch:
            # get all the edges specific to this user
            user_positive_edges = train_positive_edges[:, train_positive_edges[0] == user_id]
            user_negative_edges = train_negative_edges[:, train_negative_edges[0] == user_id]
            if (user_positive_edges.shape[1] == 0 or user_negative_edges.shape[1] == 0):
                continue
            # limit the number of positive edges to 5000
            if (user_positive_edges.shape[1] > 5000):
                user_positive_edges = user_positive_edges[:, :5000]
            # Get at most 15000 negative edges
            if (user_negative_edges.shape[1] > 15000):
                user_negative_edges = user_negative_edges[:, :15000]
            # resample the negative edges if we don't have enough
            user_positive_edges, user_negative_edges = resample_edges_for_user(user_positive_edges, user_negative_edges)
            # concatenate the positive and negative edges
            user_edges = torch.cat([user_positive_edges, user_negative_edges], dim=1)
            # get the rankings for this user
            user_edges = user_edges.to(device)
            user_rankings = model(user_edges)
            # divide the rankings into positive and negative rankings
            user_positive_rankings = user_rankings[:user_positive_edges.shape[1]]
            user_negative_rankings = user_rankings[user_positive_edges.shape[1]:]
            # create all pairs of positive and negative rankings
            user_positive_rankings = user_positive_rankings.unsqueeze(1).repeat(1, user_negative_rankings.shape[0])
            user_negative_rankings = user_negative_rankings.unsqueeze(0).repeat(user_positive_rankings.shape[0], 1)
            # get the user loss
            user_loss = model.recommendation_loss(user_positive_rankings, user_negative_rankings, 1e-4)
            # add the user loss to the total loss
            loss = loss + user_loss
        # divide the loss by the number of users
        loss = loss / BATCH_SIZE
        # log the loss
        # backprop
        optim.zero_grad()
        loss.backward()
        optim.step()
        writer.add_scalar("Loss/train", loss, epoch * (num_train_users // BATCH_SIZE) + start_idx // BATCH_SIZE)
        print(epoch * BATCH_SIZE + start_idx // BATCH_SIZE)
        average_loss = 0
        if (epoch * BATCH_SIZE + start_idx // BATCH_SIZE) % 100 == 0:
            # evaluate the model
            model.eval()
            # iterate over all users in the validation set
            validation_users = list(set([int(x) for x in validation_edges[0, :]]))
            # randomly select 1000 of the users
            validation_users = random.sample(validation_users, min(len(validation_users), 500))
            mean_ndcg = 0
            ndcg_scores = []
            for user in tqdm(validation_users):
                user_id = id_to_user[user]
                relevant_reviews = validation_df[validation_df['user_id'] == user_id]
                user_validation_edges = validation_edges[:, validation_edges[0] == user]
                user_validation_edges = user_validation_edges.to(device)
                user_rankings = model(user_validation_edges)
                edges_sorted = list(user_validation_edges[1, user_rankings.argsort(descending=True)])
                # use validation_df to get the relevances via the movie_id column and the movie_rating column
                relevances = []
                for edge in edges_sorted:
                    movie_id = id_to_movie[int(edge)]
                    if (movie_id in relevant_reviews['movie_id'].values):
                        relevances.append(relevant_reviews[relevant_reviews['movie_id'] == movie_id]['movie_rating'].values[0])
                    else:
                        relevances.append(0)
                # calculate the ndcg
                ndcg = compute_ndcg_at_k(relevances)
                if (math.isnan(ndcg)):
                    print(relevant_reviews)
                    input()
                mean_ndcg += ndcg
                ndcg_scores.append(ndcg)
            mean_ndcg = mean_ndcg / len(validation_users)
            print("Standard Deviation: {}".format(np.std(ndcg_scores)))
            # create a histogram of the ndcg scores, make bins for each 0.1
            ndcg_scores = np.array(ndcg_scores).squeeze()
            writer.add_histogram("hist_NDCG/val", ndcg_scores, epoch)
            # also make a histogram in matplotlib and save as png
            plt.hist(ndcg_scores, bins=np.arange(0, 1.1, 0.1))
            plt.suptitle("Validation NDCG Histogram")
            # write information about the model to the histogram
            plt.title(f"Model: LightGCN, Embedding Dim: {EMBEDDING_DIM}, Num Layers: {NUM_LAYERS}, Batch Size: {BATCH_SIZE}, LR: {LR}, Num Train Users: {num_train_users}, Num Train Items: {num_train_items}", fontsize=8, wrap=True)
            plt.xlabel("NDCG")
            plt.ylabel("Frequency")
            # save the figure in the hist_NDCG folder, with the title having the model information and the epoch number
            plt.savefig(f"hist_NDCG/val_{EMBEDDING_DIM}_{NUM_LAYERS}_{BATCH_SIZE}_{LR}_{num_train_users}_{num_train_items}_{epoch}.png")
            plt.close()
            # Also save the raw NDCG scores to a csv file, with the model information in the title, and the epoch number
            np.savetxt(f"hist_NDCG/val_{EMBEDDING_DIM}_{NUM_LAYERS}_{BATCH_SIZE}_{LR}_{num_train_users}_{num_train_items}_{epoch}.csv", ndcg_scores, delimiter=",")
            print(mean_ndcg)
            writer.add_scalar("NDCG/val", mean_ndcg, epoch * (num_train_users // BATCH_SIZE) + start_idx // BATCH_SIZE)
            recall_at_k = compute_recall_at_k(validation_graph, model, K)
            writer.add_scalar("Recall@K/val", recall_at_k, epoch * (num_train_users // BATCH_SIZE) + start_idx // BATCH_SIZE)
            print("Epoch: {}, NDCG: {}, Recall@{}: {}".format(epoch, mean_ndcg, K, recall_at_k))
            average_number_of_matches = 0
            for user_id in validation_users:
                all_edges = torch.tensor([(user_id, item_id) for item_id in range(num_train_users, num_train_items)], dtype=torch.long).t().contiguous()
                recommendations = model.recommend(all_edges.to(device), src_index=torch.tensor([user_id]), dst_index=torch.tensor([x for x in range(num_train_users + 1, num_train_items)]), k=10)[0]
                movie_names = [movie_id_to_movie_name[id_to_movie[int(recommendation)]] for recommendation in recommendations]
                true_user_reviews = user_review_data[id_to_user[user_id]]
                matches = 0
                for movie_name in movie_names:
                    if movie_name in true_user_reviews['movie_title'].values:
                        matches += 1
                average_number_of_matches += matches
            average_number_of_matches = average_number_of_matches / len(validation_users)
            print("Average number of matches: {}".format(average_number_of_matches))
            writer.add_scalar("Average number of matches", average_number_of_matches, epoch * (num_train_users // BATCH_SIZE) + start_idx // BATCH_SIZE)
            print("=====================================")

Running on device: cuda
64


  0%|          | 1/493 [00:03<25:24,  3.10s/it]

0


  0%|          | 2/493 [00:05<23:37,  2.89s/it]

1


  1%|          | 3/493 [00:08<22:27,  2.75s/it]

2


  1%|          | 4/493 [00:11<21:57,  2.69s/it]

3


  1%|          | 5/493 [00:13<22:27,  2.76s/it]

4


  1%|          | 6/493 [00:18<27:57,  3.45s/it]

5


  1%|▏         | 7/493 [00:21<26:02,  3.21s/it]

6


  2%|▏         | 8/493 [00:24<24:40,  3.05s/it]

7


  2%|▏         | 9/493 [00:26<23:52,  2.96s/it]

8


  2%|▏         | 10/493 [00:29<23:08,  2.88s/it]

9


  2%|▏         | 11/493 [00:32<22:46,  2.84s/it]

10


  2%|▏         | 12/493 [00:35<22:28,  2.80s/it]

11


  3%|▎         | 13/493 [00:37<22:25,  2.80s/it]

12


  3%|▎         | 14/493 [00:40<21:57,  2.75s/it]

13


  3%|▎         | 15/493 [00:43<22:04,  2.77s/it]

14


  3%|▎         | 16/493 [00:45<21:46,  2.74s/it]

15


  3%|▎         | 17/493 [00:48<21:33,  2.72s/it]

16


  4%|▎         | 18/493 [00:51<21:18,  2.69s/it]

17


  4%|▍         | 19/493 [00:53<21:04,  2.67s/it]

18


  4%|▍         | 20/493 [00:56<21:22,  2.71s/it]

19


  4%|▍         | 21/493 [00:59<21:19,  2.71s/it]

20


  4%|▍         | 22/493 [01:02<21:20,  2.72s/it]

21


  5%|▍         | 23/493 [01:04<21:17,  2.72s/it]

22


  5%|▍         | 24/493 [01:07<21:17,  2.72s/it]

23


  5%|▌         | 25/493 [01:10<21:11,  2.72s/it]

24


  5%|▌         | 26/493 [01:12<20:55,  2.69s/it]

25


  5%|▌         | 27/493 [01:15<21:10,  2.73s/it]

26


  6%|▌         | 28/493 [01:18<21:21,  2.76s/it]

27


  6%|▌         | 29/493 [01:21<21:24,  2.77s/it]

28


  6%|▌         | 30/493 [01:24<21:12,  2.75s/it]

29


  6%|▋         | 31/493 [01:26<20:59,  2.73s/it]

30


  6%|▋         | 32/493 [01:29<20:46,  2.70s/it]

31


  7%|▋         | 33/493 [01:32<20:38,  2.69s/it]

32


  7%|▋         | 34/493 [01:34<20:28,  2.68s/it]

33


  7%|▋         | 35/493 [01:37<20:46,  2.72s/it]

34


  7%|▋         | 36/493 [01:40<20:43,  2.72s/it]

35


  8%|▊         | 37/493 [01:42<20:27,  2.69s/it]

36


  8%|▊         | 38/493 [01:45<20:20,  2.68s/it]

37


  8%|▊         | 39/493 [01:48<20:09,  2.67s/it]

38


  8%|▊         | 40/493 [01:50<20:14,  2.68s/it]

39


  8%|▊         | 41/493 [01:53<20:30,  2.72s/it]

40


  9%|▊         | 42/493 [01:56<20:41,  2.75s/it]

41


  9%|▊         | 43/493 [01:59<20:29,  2.73s/it]

42


  9%|▉         | 44/493 [02:01<20:26,  2.73s/it]

43


  9%|▉         | 45/493 [02:04<20:20,  2.72s/it]

44


  9%|▉         | 46/493 [02:07<20:00,  2.68s/it]

45


 10%|▉         | 47/493 [02:09<20:02,  2.70s/it]

46


 10%|▉         | 48/493 [02:12<19:43,  2.66s/it]

47


 10%|▉         | 49/493 [02:15<19:41,  2.66s/it]

48


 10%|█         | 50/493 [02:17<19:39,  2.66s/it]

49


 10%|█         | 51/493 [02:20<19:34,  2.66s/it]

50


 11%|█         | 52/493 [02:23<19:21,  2.63s/it]

51


 11%|█         | 53/493 [02:25<19:39,  2.68s/it]

52


 11%|█         | 54/493 [02:28<19:37,  2.68s/it]

53


 11%|█         | 55/493 [02:31<19:41,  2.70s/it]

54


 11%|█▏        | 56/493 [02:33<19:38,  2.70s/it]

55


 12%|█▏        | 57/493 [02:36<19:31,  2.69s/it]

56


 12%|█▏        | 58/493 [02:39<19:23,  2.67s/it]

57


 12%|█▏        | 59/493 [02:41<19:26,  2.69s/it]

58


 12%|█▏        | 60/493 [02:44<19:42,  2.73s/it]

59


 12%|█▏        | 61/493 [02:47<19:35,  2.72s/it]

60


 13%|█▎        | 62/493 [02:50<19:31,  2.72s/it]

61


 13%|█▎        | 63/493 [02:53<19:49,  2.77s/it]

62


 13%|█▎        | 64/493 [02:55<19:23,  2.71s/it]

63


 13%|█▎        | 65/493 [03:00<23:05,  3.24s/it]

64


 13%|█▎        | 66/493 [03:02<21:46,  3.06s/it]

65


 14%|█▎        | 67/493 [03:05<20:54,  2.95s/it]

66


 14%|█▍        | 68/493 [03:08<20:10,  2.85s/it]

67


 14%|█▍        | 69/493 [03:10<19:39,  2.78s/it]

68


 14%|█▍        | 70/493 [03:13<19:09,  2.72s/it]

69


 14%|█▍        | 71/493 [03:15<18:59,  2.70s/it]

70


 15%|█▍        | 72/493 [03:18<18:50,  2.69s/it]

71


 15%|█▍        | 73/493 [03:21<18:43,  2.67s/it]

72


 15%|█▌        | 74/493 [03:24<19:20,  2.77s/it]

73


 15%|█▌        | 75/493 [03:26<18:48,  2.70s/it]

74


 15%|█▌        | 76/493 [03:29<18:36,  2.68s/it]

75


 16%|█▌        | 77/493 [03:32<18:26,  2.66s/it]

76


 16%|█▌        | 78/493 [03:34<18:35,  2.69s/it]

77


 16%|█▌        | 79/493 [03:37<18:28,  2.68s/it]

78


 16%|█▌        | 80/493 [03:39<18:07,  2.63s/it]

79


 16%|█▋        | 81/493 [03:42<18:18,  2.67s/it]

80


 17%|█▋        | 82/493 [03:45<18:08,  2.65s/it]

81


 17%|█▋        | 83/493 [03:48<18:10,  2.66s/it]

82


 17%|█▋        | 84/493 [03:50<18:01,  2.64s/it]

83


 17%|█▋        | 85/493 [03:53<18:09,  2.67s/it]

84


 17%|█▋        | 86/493 [03:55<17:53,  2.64s/it]

85


 18%|█▊        | 87/493 [03:58<17:38,  2.61s/it]

86


 18%|█▊        | 88/493 [04:01<17:44,  2.63s/it]

87


 18%|█▊        | 89/493 [04:03<17:55,  2.66s/it]

88


 18%|█▊        | 90/493 [04:07<18:52,  2.81s/it]

89


 18%|█▊        | 91/493 [04:09<18:57,  2.83s/it]

90


 19%|█▊        | 92/493 [04:12<18:48,  2.81s/it]

91


 19%|█▉        | 93/493 [04:15<18:25,  2.76s/it]

92


 19%|█▉        | 94/493 [04:18<18:20,  2.76s/it]

93


 19%|█▉        | 95/493 [04:20<18:10,  2.74s/it]

94


 19%|█▉        | 96/493 [04:23<18:16,  2.76s/it]

95


 20%|█▉        | 97/493 [04:26<18:12,  2.76s/it]

96


 20%|█▉        | 98/493 [04:29<18:08,  2.75s/it]

97


 20%|██        | 99/493 [04:31<17:58,  2.74s/it]

98


 20%|██        | 100/493 [04:34<18:04,  2.76s/it]

99


 20%|██        | 101/493 [04:37<17:55,  2.74s/it]

100


 21%|██        | 102/493 [04:39<17:44,  2.72s/it]

101


 21%|██        | 103/493 [04:42<17:35,  2.71s/it]

102


 21%|██        | 104/493 [04:45<17:28,  2.69s/it]

103


 21%|██▏       | 105/493 [04:47<17:25,  2.69s/it]

104


 22%|██▏       | 106/493 [04:50<17:11,  2.67s/it]

105


 22%|██▏       | 107/493 [04:53<17:19,  2.69s/it]

106


 22%|██▏       | 108/493 [04:56<17:20,  2.70s/it]

107


 22%|██▏       | 109/493 [04:58<17:16,  2.70s/it]

108


 22%|██▏       | 110/493 [05:01<17:04,  2.67s/it]

109


 23%|██▎       | 111/493 [05:04<17:29,  2.75s/it]

110


 23%|██▎       | 112/493 [05:06<17:19,  2.73s/it]

111


 23%|██▎       | 113/493 [05:09<17:36,  2.78s/it]

112


 23%|██▎       | 114/493 [05:12<17:30,  2.77s/it]

113


 23%|██▎       | 115/493 [05:15<17:36,  2.80s/it]

114


 24%|██▎       | 116/493 [05:18<17:14,  2.74s/it]

115


 24%|██▎       | 117/493 [05:22<20:58,  3.35s/it]

116


 24%|██▍       | 118/493 [05:25<19:41,  3.15s/it]

117


 24%|██▍       | 119/493 [05:28<18:48,  3.02s/it]

118


 24%|██▍       | 120/493 [05:30<18:13,  2.93s/it]

119


 25%|██▍       | 121/493 [05:33<17:38,  2.85s/it]

120


 25%|██▍       | 122/493 [05:36<17:16,  2.79s/it]

121


 25%|██▍       | 123/493 [05:38<16:49,  2.73s/it]

122


 25%|██▌       | 124/493 [05:41<16:46,  2.73s/it]

123


 25%|██▌       | 125/493 [05:44<16:42,  2.72s/it]

124


 26%|██▌       | 126/493 [05:47<16:36,  2.72s/it]

125


 26%|██▌       | 127/493 [05:49<16:29,  2.70s/it]

126


 26%|██▌       | 128/493 [05:52<16:28,  2.71s/it]

127


 26%|██▌       | 129/493 [05:55<16:25,  2.71s/it]

128


 26%|██▋       | 130/493 [05:57<16:20,  2.70s/it]

129


 27%|██▋       | 131/493 [06:00<16:25,  2.72s/it]

130


 27%|██▋       | 132/493 [06:03<16:16,  2.70s/it]

131


 27%|██▋       | 133/493 [06:05<16:15,  2.71s/it]

132


 27%|██▋       | 134/493 [06:08<16:16,  2.72s/it]

133


 27%|██▋       | 135/493 [06:11<16:24,  2.75s/it]

134


 28%|██▊       | 136/493 [06:14<16:30,  2.78s/it]

135


 28%|██▊       | 137/493 [06:17<16:23,  2.76s/it]

136


 28%|██▊       | 138/493 [06:19<16:09,  2.73s/it]

137


 28%|██▊       | 139/493 [06:22<16:05,  2.73s/it]

138


 28%|██▊       | 140/493 [06:25<15:52,  2.70s/it]

139


 29%|██▊       | 141/493 [06:27<15:51,  2.70s/it]

140


 29%|██▉       | 142/493 [06:30<15:48,  2.70s/it]

141


 29%|██▉       | 143/493 [06:33<15:43,  2.70s/it]

142


 29%|██▉       | 144/493 [06:35<15:43,  2.70s/it]

143


 29%|██▉       | 145/493 [06:38<15:23,  2.65s/it]

144


 30%|██▉       | 146/493 [06:41<15:20,  2.65s/it]

145


 30%|██▉       | 147/493 [06:43<15:18,  2.65s/it]

146


 30%|███       | 148/493 [06:46<15:15,  2.65s/it]

147


 30%|███       | 149/493 [06:49<15:20,  2.68s/it]

148


 30%|███       | 150/493 [06:51<15:19,  2.68s/it]

149


 31%|███       | 151/493 [06:54<15:23,  2.70s/it]

150


 31%|███       | 152/493 [06:57<15:15,  2.69s/it]

151


 31%|███       | 153/493 [07:00<15:23,  2.72s/it]

152


 31%|███       | 154/493 [07:02<15:42,  2.78s/it]

153


 31%|███▏      | 155/493 [07:05<15:38,  2.78s/it]

154


 32%|███▏      | 156/493 [07:08<15:27,  2.75s/it]

155


 32%|███▏      | 157/493 [07:11<15:23,  2.75s/it]

156


 32%|███▏      | 158/493 [07:14<15:31,  2.78s/it]

157


 32%|███▏      | 159/493 [07:16<15:19,  2.75s/it]

158


 32%|███▏      | 160/493 [07:19<15:16,  2.75s/it]

159


 33%|███▎      | 161/493 [07:22<15:06,  2.73s/it]

160


 33%|███▎      | 162/493 [07:24<15:06,  2.74s/it]

161


 33%|███▎      | 163/493 [07:27<14:52,  2.70s/it]

162


 33%|███▎      | 164/493 [07:30<14:47,  2.70s/it]

163


 33%|███▎      | 165/493 [07:32<14:32,  2.66s/it]

164


 34%|███▎      | 166/493 [07:35<14:41,  2.70s/it]

165


 34%|███▍      | 167/493 [07:38<14:30,  2.67s/it]

166


 34%|███▍      | 168/493 [07:40<14:38,  2.70s/it]

167


 34%|███▍      | 169/493 [07:45<17:41,  3.28s/it]

168


 34%|███▍      | 170/493 [07:48<17:02,  3.17s/it]

169


 35%|███▍      | 171/493 [07:51<16:20,  3.05s/it]

170


 35%|███▍      | 172/493 [07:53<15:41,  2.93s/it]

171


 35%|███▌      | 173/493 [07:56<15:07,  2.84s/it]

172


 35%|███▌      | 174/493 [07:59<14:50,  2.79s/it]

173


 35%|███▌      | 175/493 [08:01<14:40,  2.77s/it]

174


 36%|███▌      | 176/493 [08:04<14:22,  2.72s/it]

175


 36%|███▌      | 177/493 [08:07<14:14,  2.70s/it]

176


 36%|███▌      | 178/493 [08:09<14:12,  2.71s/it]

177


 36%|███▋      | 179/493 [08:12<14:03,  2.69s/it]

178


 37%|███▋      | 180/493 [08:15<13:57,  2.68s/it]

179


 37%|███▋      | 181/493 [08:17<14:01,  2.70s/it]

180


 37%|███▋      | 182/493 [08:20<13:44,  2.65s/it]

181


 37%|███▋      | 183/493 [08:23<13:48,  2.67s/it]

182


 37%|███▋      | 184/493 [08:25<13:48,  2.68s/it]

183


 38%|███▊      | 185/493 [08:28<13:47,  2.69s/it]

184


 38%|███▊      | 186/493 [08:31<13:53,  2.72s/it]

185


 38%|███▊      | 187/493 [08:34<13:55,  2.73s/it]

186


 38%|███▊      | 188/493 [08:36<13:57,  2.75s/it]

187


 38%|███▊      | 189/493 [08:39<13:49,  2.73s/it]

188


 39%|███▊      | 190/493 [08:42<13:35,  2.69s/it]

189


 39%|███▊      | 191/493 [08:44<13:36,  2.70s/it]

190


 39%|███▉      | 192/493 [08:47<13:27,  2.68s/it]

191


 39%|███▉      | 193/493 [08:50<13:28,  2.69s/it]

192


 39%|███▉      | 194/493 [08:52<13:25,  2.69s/it]

193


 40%|███▉      | 195/493 [08:55<13:23,  2.70s/it]

194


 40%|███▉      | 196/493 [08:58<13:26,  2.72s/it]

195


 40%|███▉      | 197/493 [09:01<13:18,  2.70s/it]

196


 40%|████      | 198/493 [09:03<13:14,  2.69s/it]

197


 40%|████      | 199/493 [09:06<13:09,  2.68s/it]

198


 41%|████      | 200/493 [09:09<13:11,  2.70s/it]

199


 41%|████      | 201/493 [09:11<13:09,  2.70s/it]

200


 41%|████      | 202/493 [09:14<13:19,  2.75s/it]

201


 41%|████      | 203/493 [09:17<13:12,  2.73s/it]

202


 41%|████▏     | 204/493 [09:20<13:17,  2.76s/it]

203


 42%|████▏     | 205/493 [09:23<13:12,  2.75s/it]

204


 42%|████▏     | 206/493 [09:25<13:21,  2.79s/it]

205


 42%|████▏     | 207/493 [09:28<13:18,  2.79s/it]

206


 42%|████▏     | 208/493 [09:31<13:20,  2.81s/it]

207


 42%|████▏     | 209/493 [09:34<13:01,  2.75s/it]

208


 43%|████▎     | 210/493 [09:36<12:39,  2.68s/it]

209


 43%|████▎     | 211/493 [09:39<12:38,  2.69s/it]

210


 43%|████▎     | 212/493 [09:42<12:33,  2.68s/it]

211


 43%|████▎     | 213/493 [09:44<12:30,  2.68s/it]

212


 43%|████▎     | 214/493 [09:47<12:28,  2.68s/it]

213


 44%|████▎     | 215/493 [09:50<12:30,  2.70s/it]

214


 44%|████▍     | 216/493 [09:52<12:18,  2.67s/it]

215


 44%|████▍     | 217/493 [09:55<12:14,  2.66s/it]

216


 44%|████▍     | 218/493 [09:58<12:07,  2.64s/it]

217


 44%|████▍     | 219/493 [10:00<12:20,  2.70s/it]

218


 45%|████▍     | 220/493 [10:03<12:34,  2.76s/it]

219


 45%|████▍     | 221/493 [10:06<12:24,  2.74s/it]

220


 45%|████▌     | 222/493 [10:09<12:32,  2.78s/it]

221


 45%|████▌     | 223/493 [10:11<12:19,  2.74s/it]

222


 45%|████▌     | 224/493 [10:14<12:22,  2.76s/it]

223


 46%|████▌     | 225/493 [10:17<12:10,  2.73s/it]

224


 46%|████▌     | 226/493 [10:22<14:43,  3.31s/it]

225


 46%|████▌     | 227/493 [10:24<13:50,  3.12s/it]

226


 46%|████▌     | 228/493 [10:27<13:16,  3.01s/it]

227


 46%|████▋     | 229/493 [10:30<12:45,  2.90s/it]

228


 47%|████▋     | 230/493 [10:33<12:45,  2.91s/it]

229


 47%|████▋     | 231/493 [10:35<12:12,  2.80s/it]

230


 47%|████▋     | 232/493 [10:38<12:01,  2.77s/it]

231


 47%|████▋     | 233/493 [10:40<11:47,  2.72s/it]

232


 47%|████▋     | 234/493 [10:43<11:47,  2.73s/it]

233


 48%|████▊     | 235/493 [10:46<11:34,  2.69s/it]

234


 48%|████▊     | 236/493 [10:49<11:35,  2.71s/it]

235


 48%|████▊     | 237/493 [10:51<11:29,  2.69s/it]

236


 48%|████▊     | 238/493 [10:54<11:22,  2.68s/it]

237


 48%|████▊     | 239/493 [10:56<11:15,  2.66s/it]

238


 49%|████▊     | 240/493 [10:59<11:16,  2.67s/it]

239


 49%|████▉     | 241/493 [11:02<11:34,  2.76s/it]

240


 49%|████▉     | 242/493 [11:05<11:16,  2.70s/it]

241


 49%|████▉     | 243/493 [11:07<11:04,  2.66s/it]

242


 49%|████▉     | 244/493 [11:10<10:59,  2.65s/it]

243


 50%|████▉     | 245/493 [11:12<10:55,  2.64s/it]

244


 50%|████▉     | 246/493 [11:15<10:52,  2.64s/it]

245


 50%|█████     | 247/493 [11:18<10:46,  2.63s/it]

246


 50%|█████     | 248/493 [11:20<10:44,  2.63s/it]

247


 51%|█████     | 249/493 [11:23<10:38,  2.62s/it]

248


 51%|█████     | 250/493 [11:26<10:44,  2.65s/it]

249


 51%|█████     | 251/493 [11:28<10:38,  2.64s/it]

250


 51%|█████     | 252/493 [11:31<10:55,  2.72s/it]

251


 51%|█████▏    | 253/493 [11:34<10:44,  2.68s/it]

252


 52%|█████▏    | 254/493 [11:36<10:37,  2.67s/it]

253


 52%|█████▏    | 255/493 [11:39<10:39,  2.69s/it]

254


 52%|█████▏    | 256/493 [11:42<10:35,  2.68s/it]

255


 52%|█████▏    | 257/493 [11:45<10:37,  2.70s/it]

256


 52%|█████▏    | 258/493 [11:47<10:26,  2.66s/it]

257


 53%|█████▎    | 259/493 [11:50<10:25,  2.67s/it]

258


 53%|█████▎    | 260/493 [11:52<10:21,  2.67s/it]

259


 53%|█████▎    | 261/493 [11:55<10:15,  2.65s/it]

260


 53%|█████▎    | 262/493 [11:58<10:16,  2.67s/it]

261


 53%|█████▎    | 263/493 [12:01<10:21,  2.70s/it]

262


 54%|█████▎    | 264/493 [12:03<10:21,  2.71s/it]

263


 54%|█████▍    | 265/493 [12:06<10:18,  2.71s/it]

264


 54%|█████▍    | 266/493 [12:09<10:12,  2.70s/it]

265


 54%|█████▍    | 267/493 [12:11<10:08,  2.69s/it]

266


 54%|█████▍    | 268/493 [12:14<10:05,  2.69s/it]

267


 55%|█████▍    | 269/493 [12:17<09:59,  2.68s/it]

268


 55%|█████▍    | 270/493 [12:19<09:56,  2.68s/it]

269


 55%|█████▍    | 271/493 [12:22<09:52,  2.67s/it]

270


 55%|█████▌    | 272/493 [12:25<09:48,  2.66s/it]

271


 55%|█████▌    | 273/493 [12:27<09:47,  2.67s/it]

272


 56%|█████▌    | 274/493 [12:30<09:52,  2.71s/it]

273


 56%|█████▌    | 275/493 [12:33<09:59,  2.75s/it]

274


 56%|█████▌    | 276/493 [12:36<10:06,  2.79s/it]

275


 56%|█████▌    | 277/493 [12:39<10:04,  2.80s/it]

276


 56%|█████▋    | 278/493 [12:42<10:05,  2.82s/it]

277


 57%|█████▋    | 279/493 [12:44<10:04,  2.82s/it]

278


 57%|█████▋    | 280/493 [12:47<09:49,  2.77s/it]

279


 57%|█████▋    | 281/493 [12:50<09:42,  2.75s/it]

280


 57%|█████▋    | 282/493 [12:54<11:33,  3.29s/it]

281


 57%|█████▋    | 283/493 [12:57<10:54,  3.11s/it]

282


 58%|█████▊    | 284/493 [13:00<10:25,  2.99s/it]

283


 58%|█████▊    | 285/493 [13:02<10:06,  2.92s/it]

284


 58%|█████▊    | 286/493 [13:05<09:45,  2.83s/it]

285


 58%|█████▊    | 287/493 [13:08<09:33,  2.78s/it]

286


 58%|█████▊    | 288/493 [13:10<09:20,  2.74s/it]

287


 59%|█████▊    | 289/493 [13:13<09:20,  2.75s/it]

288


 59%|█████▉    | 290/493 [13:16<09:08,  2.70s/it]

289


 59%|█████▉    | 291/493 [13:18<09:04,  2.70s/it]

290


 59%|█████▉    | 292/493 [13:21<09:01,  2.70s/it]

291


 59%|█████▉    | 293/493 [13:24<09:00,  2.70s/it]

292


 60%|█████▉    | 294/493 [13:27<08:58,  2.71s/it]

293


 60%|█████▉    | 295/493 [13:29<08:53,  2.69s/it]

294


 60%|██████    | 296/493 [13:32<08:47,  2.68s/it]

295


 60%|██████    | 297/493 [13:35<08:44,  2.67s/it]

296


 60%|██████    | 298/493 [13:37<08:47,  2.70s/it]

297


 61%|██████    | 299/493 [13:40<08:46,  2.71s/it]

298


 61%|██████    | 300/493 [13:43<08:40,  2.69s/it]

299


 61%|██████    | 301/493 [13:45<08:34,  2.68s/it]

300


 61%|██████▏   | 302/493 [13:48<08:38,  2.72s/it]

301


 61%|██████▏   | 303/493 [13:51<08:34,  2.71s/it]

302


 62%|██████▏   | 304/493 [13:54<08:32,  2.71s/it]

303


 62%|██████▏   | 305/493 [13:56<08:24,  2.68s/it]

304


 62%|██████▏   | 306/493 [13:59<08:23,  2.69s/it]

305


 62%|██████▏   | 307/493 [14:02<08:23,  2.71s/it]

306


 62%|██████▏   | 308/493 [14:04<08:20,  2.70s/it]

307


 63%|██████▎   | 309/493 [14:07<08:14,  2.69s/it]

308


 63%|██████▎   | 310/493 [14:10<08:09,  2.67s/it]

309


 63%|██████▎   | 311/493 [14:12<08:04,  2.66s/it]

310


 63%|██████▎   | 312/493 [14:15<07:59,  2.65s/it]

311


 63%|██████▎   | 313/493 [14:18<08:00,  2.67s/it]

312


 64%|██████▎   | 314/493 [14:20<07:57,  2.67s/it]

313


 64%|██████▍   | 315/493 [14:23<07:51,  2.65s/it]

314


 64%|██████▍   | 316/493 [14:26<07:49,  2.65s/it]

315


 64%|██████▍   | 317/493 [14:28<07:53,  2.69s/it]

316


 65%|██████▍   | 318/493 [14:31<07:51,  2.70s/it]

317


 65%|██████▍   | 319/493 [14:34<07:50,  2.70s/it]

318


 65%|██████▍   | 320/493 [14:36<07:50,  2.72s/it]

319


 65%|██████▌   | 321/493 [14:39<07:47,  2.72s/it]

320


 65%|██████▌   | 322/493 [14:42<07:47,  2.74s/it]

321


 66%|██████▌   | 323/493 [14:45<07:42,  2.72s/it]

322


 66%|██████▌   | 324/493 [14:47<07:38,  2.71s/it]

323


 66%|██████▌   | 325/493 [14:50<07:36,  2.72s/it]

324


 66%|██████▌   | 326/493 [14:53<07:28,  2.69s/it]

325


 66%|██████▋   | 327/493 [14:55<07:27,  2.70s/it]

326


 67%|██████▋   | 328/493 [14:58<07:21,  2.68s/it]

327


 67%|██████▋   | 329/493 [15:01<07:14,  2.65s/it]

328


 67%|██████▋   | 330/493 [15:03<07:11,  2.65s/it]

329


 67%|██████▋   | 331/493 [15:06<07:09,  2.65s/it]

330


 67%|██████▋   | 332/493 [15:09<07:09,  2.67s/it]

331


 68%|██████▊   | 333/493 [15:11<07:04,  2.65s/it]

332


 68%|██████▊   | 334/493 [15:14<07:05,  2.68s/it]

333


 68%|██████▊   | 335/493 [15:17<07:02,  2.67s/it]

334


 68%|██████▊   | 336/493 [15:19<07:05,  2.71s/it]

335


 68%|██████▊   | 337/493 [15:24<08:29,  3.27s/it]

336


 69%|██████▊   | 338/493 [15:27<07:59,  3.09s/it]

337


 69%|██████▉   | 339/493 [15:30<07:42,  3.00s/it]

338


 69%|██████▉   | 340/493 [15:32<07:22,  2.90s/it]

339


 69%|██████▉   | 341/493 [15:35<07:13,  2.85s/it]

340


 69%|██████▉   | 342/493 [15:38<06:59,  2.78s/it]

341


 70%|██████▉   | 343/493 [15:40<06:55,  2.77s/it]

342


 70%|██████▉   | 344/493 [15:43<06:53,  2.77s/it]

343


 70%|██████▉   | 345/493 [15:46<06:44,  2.73s/it]

344


 70%|███████   | 346/493 [15:48<06:42,  2.74s/it]

345


 70%|███████   | 347/493 [15:51<06:34,  2.70s/it]

346


 71%|███████   | 348/493 [15:54<06:33,  2.72s/it]

347


 71%|███████   | 349/493 [15:56<06:27,  2.69s/it]

348


 71%|███████   | 350/493 [15:59<06:27,  2.71s/it]

349


 71%|███████   | 351/493 [16:02<06:22,  2.69s/it]

350


 71%|███████▏  | 352/493 [16:04<06:18,  2.68s/it]

351


 72%|███████▏  | 353/493 [16:07<06:16,  2.69s/it]

352


 72%|███████▏  | 354/493 [16:10<06:15,  2.70s/it]

353


 72%|███████▏  | 355/493 [16:13<06:08,  2.67s/it]

354


 72%|███████▏  | 356/493 [16:15<06:00,  2.63s/it]

355


 72%|███████▏  | 357/493 [16:18<06:00,  2.65s/it]

356


 73%|███████▎  | 358/493 [16:20<06:00,  2.67s/it]

357


 73%|███████▎  | 359/493 [16:23<05:59,  2.68s/it]

358


 73%|███████▎  | 360/493 [16:26<05:56,  2.68s/it]

359


 73%|███████▎  | 361/493 [16:29<05:54,  2.69s/it]

360


 73%|███████▎  | 362/493 [16:31<05:47,  2.65s/it]

361


 74%|███████▎  | 363/493 [16:34<05:48,  2.68s/it]

362


 74%|███████▍  | 364/493 [16:37<05:50,  2.71s/it]

363


 74%|███████▍  | 365/493 [16:39<05:47,  2.72s/it]

364


 74%|███████▍  | 366/493 [16:42<05:43,  2.71s/it]

365


 74%|███████▍  | 367/493 [16:45<05:41,  2.71s/it]

366


 75%|███████▍  | 368/493 [16:47<05:37,  2.70s/it]

367


 75%|███████▍  | 369/493 [16:50<05:37,  2.72s/it]

368


 75%|███████▌  | 370/493 [16:53<05:34,  2.72s/it]

369


 75%|███████▌  | 371/493 [16:56<05:30,  2.71s/it]

370


 75%|███████▌  | 372/493 [16:58<05:25,  2.69s/it]

371


 76%|███████▌  | 373/493 [17:01<05:23,  2.70s/it]

372


 76%|███████▌  | 374/493 [17:04<05:18,  2.68s/it]

373


 76%|███████▌  | 375/493 [17:06<05:15,  2.67s/it]

374


 76%|███████▋  | 376/493 [17:09<05:10,  2.66s/it]

375


 76%|███████▋  | 377/493 [17:12<05:07,  2.65s/it]

376


 77%|███████▋  | 378/493 [17:14<05:07,  2.67s/it]

377


 77%|███████▋  | 379/493 [17:17<05:06,  2.69s/it]

378


 77%|███████▋  | 380/493 [17:20<05:05,  2.70s/it]

379


 77%|███████▋  | 381/493 [17:22<05:01,  2.69s/it]

380


 77%|███████▋  | 382/493 [17:25<05:01,  2.72s/it]

381


 78%|███████▊  | 383/493 [17:28<04:57,  2.70s/it]

382


 78%|███████▊  | 384/493 [17:31<04:56,  2.72s/it]

383


 78%|███████▊  | 385/493 [17:33<04:54,  2.73s/it]

384


 78%|███████▊  | 386/493 [17:36<04:48,  2.70s/it]

385


 78%|███████▊  | 387/493 [17:39<04:43,  2.68s/it]

386


 79%|███████▊  | 388/493 [17:43<05:43,  3.27s/it]

387


 79%|███████▉  | 389/493 [17:46<05:22,  3.10s/it]

388


 79%|███████▉  | 390/493 [17:49<05:08,  2.99s/it]

389


 79%|███████▉  | 391/493 [17:51<04:53,  2.88s/it]

390


 80%|███████▉  | 392/493 [17:54<04:48,  2.85s/it]

391


 80%|███████▉  | 393/493 [17:57<04:41,  2.81s/it]

392


 80%|███████▉  | 394/493 [18:00<04:35,  2.78s/it]

393


 80%|████████  | 395/493 [18:02<04:29,  2.75s/it]

394


 80%|████████  | 396/493 [18:05<04:26,  2.75s/it]

395


 81%|████████  | 397/493 [18:08<04:19,  2.70s/it]

396


 81%|████████  | 398/493 [18:10<04:15,  2.69s/it]

397


 81%|████████  | 399/493 [18:13<04:10,  2.66s/it]

398


 81%|████████  | 400/493 [18:16<04:09,  2.68s/it]

399


 81%|████████▏ | 401/493 [18:18<04:07,  2.69s/it]

400


 82%|████████▏ | 402/493 [18:21<04:03,  2.68s/it]

401


 82%|████████▏ | 403/493 [18:24<03:59,  2.67s/it]

402


 82%|████████▏ | 404/493 [18:26<03:57,  2.66s/it]

403


 82%|████████▏ | 405/493 [18:29<03:53,  2.66s/it]

404


 82%|████████▏ | 406/493 [18:31<03:50,  2.65s/it]

405


 83%|████████▎ | 407/493 [18:34<03:48,  2.66s/it]

406


 83%|████████▎ | 408/493 [18:37<03:46,  2.67s/it]

407


 83%|████████▎ | 409/493 [18:40<03:44,  2.67s/it]

408


 83%|████████▎ | 410/493 [18:42<03:40,  2.66s/it]

409


 83%|████████▎ | 411/493 [18:45<03:42,  2.71s/it]

410


 84%|████████▎ | 412/493 [18:48<03:35,  2.66s/it]

411


 84%|████████▍ | 413/493 [18:50<03:37,  2.72s/it]

412


 84%|████████▍ | 414/493 [18:53<03:32,  2.69s/it]

413


 84%|████████▍ | 415/493 [18:56<03:29,  2.69s/it]

414


 84%|████████▍ | 416/493 [18:58<03:27,  2.69s/it]

415


 85%|████████▍ | 417/493 [19:01<03:27,  2.73s/it]

416


 85%|████████▍ | 418/493 [19:04<03:22,  2.70s/it]

417


 85%|████████▍ | 419/493 [19:07<03:20,  2.71s/it]

418


 85%|████████▌ | 420/493 [19:09<03:16,  2.69s/it]

419


 85%|████████▌ | 421/493 [19:12<03:13,  2.69s/it]

420


 86%|████████▌ | 422/493 [19:15<03:11,  2.70s/it]

421


 86%|████████▌ | 423/493 [19:17<03:10,  2.72s/it]

422


 86%|████████▌ | 424/493 [19:20<03:05,  2.69s/it]

423


 86%|████████▌ | 425/493 [19:23<03:05,  2.73s/it]

424


 86%|████████▋ | 426/493 [19:26<03:02,  2.72s/it]

425


 87%|████████▋ | 427/493 [19:28<02:59,  2.72s/it]

426


 87%|████████▋ | 428/493 [19:31<02:56,  2.71s/it]

427


 87%|████████▋ | 429/493 [19:34<02:53,  2.72s/it]

428


 87%|████████▋ | 430/493 [19:36<02:50,  2.70s/it]

429


 87%|████████▋ | 431/493 [19:39<02:47,  2.71s/it]

430


 88%|████████▊ | 432/493 [19:42<02:45,  2.71s/it]

431


 88%|████████▊ | 433/493 [19:45<02:49,  2.82s/it]

432


 88%|████████▊ | 434/493 [19:48<02:46,  2.82s/it]

433


 88%|████████▊ | 435/493 [19:50<02:39,  2.76s/it]

434


 88%|████████▊ | 436/493 [19:53<02:35,  2.73s/it]

435


 89%|████████▊ | 437/493 [19:58<03:08,  3.36s/it]

436


 89%|████████▉ | 438/493 [20:01<02:56,  3.22s/it]

437


 89%|████████▉ | 439/493 [20:03<02:45,  3.07s/it]

438


 89%|████████▉ | 440/493 [20:06<02:37,  2.98s/it]

439


 89%|████████▉ | 441/493 [20:09<02:32,  2.94s/it]

440


 90%|████████▉ | 442/493 [20:12<02:28,  2.91s/it]

441


 90%|████████▉ | 443/493 [20:15<02:22,  2.86s/it]

442


 90%|█████████ | 444/493 [20:17<02:20,  2.87s/it]

443


 90%|█████████ | 445/493 [20:20<02:15,  2.82s/it]

444


 90%|█████████ | 446/493 [20:23<02:11,  2.79s/it]

445


 91%|█████████ | 447/493 [20:26<02:06,  2.74s/it]

446


 91%|█████████ | 448/493 [20:28<02:03,  2.74s/it]

447


 91%|█████████ | 449/493 [20:31<01:59,  2.72s/it]

448


 91%|█████████▏| 450/493 [20:34<01:56,  2.70s/it]

449


 91%|█████████▏| 451/493 [20:36<01:52,  2.69s/it]

450


 92%|█████████▏| 452/493 [20:39<01:49,  2.67s/it]

451


 92%|█████████▏| 453/493 [20:42<01:46,  2.66s/it]

452


 92%|█████████▏| 454/493 [20:44<01:43,  2.64s/it]

453


 92%|█████████▏| 455/493 [20:47<01:40,  2.65s/it]

454


 92%|█████████▏| 456/493 [20:49<01:37,  2.64s/it]

455


 93%|█████████▎| 457/493 [20:52<01:35,  2.66s/it]

456


 93%|█████████▎| 458/493 [20:55<01:33,  2.66s/it]

457


 93%|█████████▎| 459/493 [20:58<01:32,  2.73s/it]

458


 93%|█████████▎| 460/493 [21:00<01:28,  2.69s/it]

459


 94%|█████████▎| 461/493 [21:03<01:27,  2.73s/it]

460


 94%|█████████▎| 462/493 [21:06<01:25,  2.75s/it]

461


 94%|█████████▍| 463/493 [21:09<01:23,  2.77s/it]

462


 94%|█████████▍| 464/493 [21:11<01:19,  2.76s/it]

463


 94%|█████████▍| 465/493 [21:15<01:19,  2.85s/it]

464


 95%|█████████▍| 466/493 [21:17<01:15,  2.80s/it]

465


 95%|█████████▍| 467/493 [21:20<01:11,  2.76s/it]

466


 95%|█████████▍| 468/493 [21:22<01:07,  2.72s/it]

467


 95%|█████████▌| 469/493 [21:25<01:05,  2.74s/it]

468


 95%|█████████▌| 470/493 [21:28<01:02,  2.71s/it]

469


 96%|█████████▌| 471/493 [21:31<00:59,  2.71s/it]

470


 96%|█████████▌| 472/493 [21:33<00:56,  2.68s/it]

471


 96%|█████████▌| 473/493 [21:36<00:53,  2.68s/it]

472


 96%|█████████▌| 474/493 [21:39<00:50,  2.67s/it]

473


 96%|█████████▋| 475/493 [21:41<00:48,  2.68s/it]

474


 97%|█████████▋| 476/493 [21:44<00:45,  2.66s/it]

475


 97%|█████████▋| 477/493 [21:46<00:42,  2.64s/it]

476


 97%|█████████▋| 478/493 [21:49<00:39,  2.64s/it]

477


 97%|█████████▋| 479/493 [21:52<00:36,  2.64s/it]

478


 97%|█████████▋| 480/493 [21:54<00:34,  2.65s/it]

479


 98%|█████████▊| 481/493 [21:57<00:31,  2.65s/it]

480


 98%|█████████▊| 482/493 [22:00<00:29,  2.66s/it]

481


 98%|█████████▊| 483/493 [22:02<00:26,  2.67s/it]

482


 98%|█████████▊| 484/493 [22:05<00:23,  2.66s/it]

483


 98%|█████████▊| 485/493 [22:08<00:21,  2.64s/it]

484


 99%|█████████▊| 486/493 [22:10<00:18,  2.67s/it]

485


 99%|█████████▉| 487/493 [22:13<00:15,  2.65s/it]

486


 99%|█████████▉| 488/493 [22:16<00:13,  2.65s/it]

487


 99%|█████████▉| 489/493 [22:18<00:10,  2.66s/it]

488


 99%|█████████▉| 490/493 [22:21<00:07,  2.66s/it]

489


100%|█████████▉| 491/493 [22:24<00:05,  2.67s/it]

490


100%|█████████▉| 492/493 [22:28<00:03,  3.26s/it]

491


100%|██████████| 493/493 [22:31<00:00,  2.74s/it]


492


  0%|          | 1/493 [00:02<22:04,  2.69s/it]

128


  0%|          | 2/493 [00:05<21:33,  2.63s/it]

129


  1%|          | 3/493 [00:08<22:01,  2.70s/it]

130


  1%|          | 4/493 [00:10<21:50,  2.68s/it]

131


  1%|          | 5/493 [00:13<21:51,  2.69s/it]

132


  1%|          | 6/493 [00:16<21:45,  2.68s/it]

133


  1%|▏         | 7/493 [00:18<21:37,  2.67s/it]

134


  2%|▏         | 8/493 [00:21<21:32,  2.67s/it]

135


  2%|▏         | 9/493 [00:23<21:20,  2.65s/it]

136


  2%|▏         | 10/493 [00:26<21:28,  2.67s/it]

137


  2%|▏         | 11/493 [00:29<21:24,  2.67s/it]

138


  2%|▏         | 12/493 [00:32<21:27,  2.68s/it]

139


  3%|▎         | 13/493 [00:34<21:25,  2.68s/it]

140


  3%|▎         | 14/493 [00:37<21:27,  2.69s/it]

141


  3%|▎         | 15/493 [00:40<21:14,  2.67s/it]

142


  3%|▎         | 16/493 [00:42<21:16,  2.68s/it]

143


  3%|▎         | 17/493 [00:45<21:16,  2.68s/it]

144


  4%|▎         | 18/493 [00:48<21:15,  2.69s/it]

145


  4%|▍         | 19/493 [00:50<21:05,  2.67s/it]

146


  4%|▍         | 20/493 [00:53<21:13,  2.69s/it]

147


  4%|▍         | 21/493 [00:56<21:24,  2.72s/it]

148


  4%|▍         | 22/493 [00:59<21:16,  2.71s/it]

149


  5%|▍         | 23/493 [01:02<22:23,  2.86s/it]

150


  5%|▍         | 24/493 [01:05<22:19,  2.86s/it]

151


  5%|▌         | 25/493 [01:07<22:14,  2.85s/it]

152


  5%|▌         | 26/493 [01:10<22:03,  2.83s/it]

153


  5%|▌         | 27/493 [01:13<22:24,  2.89s/it]

154


  5%|▌         | 27/493 [01:15<21:40,  2.79s/it]


KeyboardInterrupt: 

In [None]:
validation_users = list(set([int(x) for x in validation_edges[0, :]]))
validation_df[validation_df.user_id == id_to_user[0]]

In [None]:
validation_edges[:, validation_edges[0] == 0]

In [None]:
def get_user_positive_items(edge_index):
    """Generates dictionary of positive items for each user

    Args:
        edge_index (torch.Tensor): 2 by N list of edges

    Returns:
        dict: dictionary of positive items for each user
    """
    user_pos_items = {}
    for i in range(edge_index.shape[1]):
        user = edge_index[0][i].item()
        item = edge_index[1][i].item()
        if user not in user_pos_items:
            user_pos_items[user] = []
        user_pos_items[user].append(item)
    return user_pos_items

In [None]:
print()