***
***

# Assignment 2 - Extend My Playlist - 60 points

- This assignment will see you use a dataset of Spotify playlists, and build a GNN to build a playlist-track RecSys (Recommendation System). We'll use a type of layer called LightGCN for this task.
- You will almost certainly need a GPU for this assignment. The free tier GPUs of colab should suffice. So write your code with that in mind.
- The dataset you have been shared is a small part of a much larger dataset so the free tier colab GPU can handle it. Do **NOT** share the dataset with other people. We'll release details about the dataset after the assignment is over if you would like access to the full dataset.


## Marking Scheme
1. Code (45 points)
2. Quality of Recommendation (15 points)
You will receive a bonus 10 points for being able to provide a visualisation of the embeddings your model learned. Hints for this are at the end of the assignment.
You will receive a bonus 5 points for well-documented code. So use your comments well, and remember to follow PEP guidelines.


### Details
- We have given a weak boilerplate code for you, with instructions and points allotted for each section. Following the instructions will be the easiest way to get the most grades.
- We are not testing your literature review skils here. The dataset, modules to use, layers you'll need are all given below. Don't waste time hunting for better ones, as there is no objective way to measure the brilliance of your results.
- 15 points are for the quality of the recommendations as tested on our actual playlists. Each of the mentors will run your model on one of their playlists, and rate your model out of 5. So for three mentors, that makes 15 points. This is a rather subjective, however every mentor will grade every notebook for this part of the criterion, so it will even out.

### Boilerplate -
0. Setting up
1. Data class creations
2. Data loading and graph creation
3. Subgraphing
4. Dataset creation
5. Model design
6. Training
7. Inference

### What score to expect?
- Model that was not trained (15 points)
- Model that has been trained (30 points)
- Trained model that can do inference on new data (45 points)
- Trained model that provides recommendations which match the vibe of the playlist (60 points)

## Tentative Pass Criteria
40 points are needed to pass this assignment. Which means your model needs to be trained and be able to perform inference for your submission to count as a pass.

***

# Code Starts Here



## Section 0 - Installing modules and Loading the Dataset (0 points)

This section is already done for you and you do not need to change anything here. Go through it to understand what kind of data is available and what modules are pre-installed for you.

### Installation of Libraries

In [1]:
# If you are running this code locally, remember to comment this section after you have installed the code.
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
!pip install torch-geometric
!pip install umap-learn

Looking in links: https://data.pyg.org/whl/torch-1.13.1+cu116.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_scatter-2.1.1%2Bpt113cu116-cp310-cp310-linux_x86_64.whl (9.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.4/9.4 MB[0m [31m77.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.1+pt113cu116
Looking in links: https://data.pyg.org/whl/torch-1.13.1+cu116.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_sparse-0.6.17%2Bpt113cu116-cp310-cp310-linux_x86_64.whl (4.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m50.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.17+pt113cu116
Collecting torch-geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━

### General imports

In [2]:
import json
import pprint
from pathlib import Path as Data_Path
import os
from os.path import isfile, join
import pickle
import random

import numpy as np
import networkx as nx
import pandas as pd
from sklearn.metrics import f1_score, roc_auc_score
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm.notebook import tqdm

### ML imports

In [3]:
# Import relevant ML libraries
from typing import Optional, Union

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Embedding, ModuleList, Linear
import torch.nn.functional as F

import torch_geometric
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
from torch.nn.modules.loss import _Loss

from torch_geometric.nn.conv import LGConv
from torch_geometric.typing import Adj, OptTensor, SparseTensor

print(f"Torch version: {torch.__version__}; Torch-cuda version: {torch.version.cuda};\
         Torch Geometric version: {torch_geometric.__version__}.")



Torch version: 2.0.1+cu118; Torch-cuda version: 11.8;         Torch Geometric version: 2.3.1.


### Seeding

In [4]:
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

### Loading Data

- The dataset is provided to you as a set of 20 files with ~20000 playlists from the 2010s. These are stored in json files.
- Extract the zip file we send in your current working directory. The data files should be in A2/data/*. You should see 20 files titled ```f'mpd.slice.{x}-{x+999}.json'``` in the data directory.
This is a sample file

### Useful Classes


In [6]:
"""
The original data was stored in JSON files, which contain playlists, which themselves contain tracks.
Thus, we define three classes:
  Track    : contains information for a specific track (its id, name, etc.)
  Playlist : contains information for a specific playlist (its id, name, etc. as well as a list of Tracks)
  JSONFile : contains the loaded json file and stores a dictionary of all of the Playlists

Note: If you want to use Artist info, you may want to create an Artist class
"""

class Track:
  """
  Simple class for a track, containing its attributes:
    1. URI (a unique id)
    2. Name
    3. Artist info (URI and name)
    4. Parent playlist
  """

  def __init__(self, track_dict, playlist):
    self.uri = track_dict["track_uri"]
    self.name = track_dict["track_name"]
    self.artist_uri = track_dict["artist_uri"]
    self.artist_name = track_dict["artist_name"]
    self.playlist = playlist

  def __str__(self):
    return f"Track {self.uri} called {self.name} by {self.artist_uri} ({self.artist_name}) in playlist {self.playlist}."

  def __repr__(self):
    return f"Track {self.uri}"

class Playlist:
  """
  Simple class for a playlist, containing its attributes:
    1. Name (playlist and its associated index (the data is a subset, so every index hasn't been shared with you))
    2. Title (playlist title in the Spotify dataset)
    3. Loaded dictionary from the raw json for the playlist
    4. Dictionary of tracks (track_uri : Track), populated by .load_tracks()
    5. List of artists uris
  """

  def __init__(self, json_data, index):

    self.name = f"playlist_{index}"
    self.title = json_data["name"]
    self.data = json_data

    self.tracks = {}
    self.artists = []

  def load_tracks(self):
    """ Call this function to load all of the tracks in the json data for the playlist."""

    tracks_list = self.data["tracks"]
    self.tracks = {x["track_uri"] : Track(x, self.name) for x in tracks_list}
    self.artists = [x["artist_uri"] for x in tracks_list]

  def __str__(self):
    return f"Playlist {self.name} with {len(self.tracks)} tracks loaded."

  def __repr__(self):
    return f"Playlist {self.name}"

class JSONFile:
  """
  Simple class for a JSON file, containing its attributes:
    1. File Name
    2. Index to begin numbering playlists at
    3. Loaded dictionary from the raw json for the full file
    4. Dictionary of playlists (name : Playlist), populated by .process_file()
  """

  def __init__(self, data_path, file_name, start_index):

    self.file_name = file_name
    self.start_index = start_index

    with open(join(data_path, file_name)) as json_file:
      json_data = json.load(json_file)
    self.data = json_data

    self.playlists = {}

  def process_file(self):
    """ Call this function to load all of the playlists in the json data."""

    for i, playlist_json in enumerate(self.data["playlists"]):
      playlist = Playlist(playlist_json, self.start_index + i)
      playlist.load_tracks()
      self.playlists[playlist.name] = playlist

  def __str__(self):
    return f"JSON {self.file_name} has {len(self.playlists)} playlists loaded."

  def __repr__(self):
    return self.file_name


In [7]:
import pickle

with open("playlists_idx.dat","rb") as f:
    playlists = pickle.load(f)

with open("tracks_idx.dat","rb") as f:
    tracks = pickle.load(f)
n_playlists = len(playlists)
n_tracks = len(tracks)


In [23]:
graph_data = torch.load('playlist_data.pt')

Split the data into train-test-split sets.

In [24]:
transform = RandomLinkSplit(
    is_undirected=True,
    add_negative_train_samples=False,
    neg_sampling_ratio=0,
    num_val=0.15, num_test=0.15
)
train_split, val_split, test_split = transform(graph_data)

In [25]:
# note these are stored as float32, we need them to be int64 for future training

# Edge index: message passing edges
train_split.edge_index = train_split.edge_index.type(torch.int64)
val_split.edge_index = val_split.edge_index.type(torch.int64)
test_split.edge_index = test_split.edge_index.type(torch.int64)
# Edge label index: supervision edges
train_split.edge_label_index = train_split.edge_label_index.type(torch.int64)
val_split.edge_label_index = val_split.edge_label_index.type(torch.int64)
test_split.edge_label_index = test_split.edge_label_index.type(torch.int64)

print(f"Train set has {train_split.edge_label_index.shape[1]} positives upervision edges")
print(f"Validation set has {val_split.edge_label_index.shape[1]} positive supervision edges")
print(f"Test set has {test_split.edge_label_index.shape[1]} positive supervision edges")

print(f"Train set has {train_split.edge_index.shape[1]} message passing edges")
print(f"Validation set has {val_split.edge_index.shape[1]} message passing edges")
print(f"Test set has {test_split.edge_index.shape[1]} message passing edges")

Train set has 219346 positives upervision edges
Validation set has 47002 positive supervision edges
Test set has 47002 positive supervision edges
Train set has 438692 message passing edges
Validation set has 438692 message passing edges
Test set has 532696 message passing edges


### The GCN class
As mentioned, the primary model type we need here is a GCN. Use this boilerplate code to define your GCN code. This is merely a template, you can change it up if you think something else would be more intuitive.

In [26]:
class GCN(torch.nn.Module):
    """
      Here we adapt the LightGCN model from Torch Geometric for our purposes. We allow
      for customizable convolutional layers, custom embeddings. In addition, we deifne some
      additional custom functions.

    """

    def __init__(
        self,
        num_nodes: int,
        embedding_dim: int,
        num_layers: int,
        alpha: Optional[Union[float, Tensor]] = None,
        alpha_learnable = False,
        conv_layer = "LGC",
        name = None,
        **kwargs,
    ):
        super().__init__()
        alpha_string = "alpha" if alpha_learnable else ""
        self.name = f"LGCN_{conv_layer}_{num_layers}_e{embedding_dim}_nodes{num_nodes}_{alpha_string}"
        self.num_nodes = num_nodes
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        if alpha_learnable == True:
          alpha_vals = torch.rand(num_layers+1)
          alpha = nn.Parameter(alpha_vals/torch.sum(alpha_vals))
          print(f"Alpha learnable, initialized to: {alpha.softmax(dim=-1)}")
        else:
          if alpha is None:
              alpha = 1. / (num_layers + 1)

          if isinstance(alpha, Tensor):
              assert alpha.size(0) == num_layers + 1
          else:
              alpha = torch.tensor([alpha] * (num_layers + 1))

        self.register_buffer('alpha', alpha)

        self.embedding = Embedding(num_nodes, embedding_dim)

        # initialize convolutional layers
        self.conv_layer = conv_layer
        if conv_layer == "LGC":
          self.convs = ModuleList([LGConv(**kwargs) for _ in range(num_layers)])

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.embedding.weight)
        for conv in self.convs:
            conv.reset_parameters()

    def get_embedding(self, edge_index: Adj) -> Tensor:
        x = self.embedding.weight

        weights = self.alpha.softmax(dim=-1)
        out = x * weights[0]

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            if self.conv_layer == "GAT":
              x = self.linears[i](x)
            out = out + x * weights[i + 1]

        return out

    def initialize_embeddings(self, data):
      # initialize with the data node features
        self.embedding.weight.data.copy_(data.node_feature)


    def forward(self, edge_index: Adj,
                edge_label_index: OptTensor = None) -> Tensor:
        if edge_label_index is None:
            if isinstance(edge_index, SparseTensor):
                edge_label_index = torch.stack(edge_index.coo()[:2], dim=0)
            else:
                edge_label_index = edge_index

        out = self.get_embedding(edge_index)

        return self.predict_link_embedding(out, edge_label_index)

    def predict_link(self, edge_index: Adj, edge_label_index: OptTensor = None,
                     prob: bool = False) -> Tensor:

        pred = self(edge_index, edge_label_index).sigmoid()
        return pred if prob else pred.round()

    def predict_link_embedding(self, embed: Adj, edge_label_index: Adj) -> Tensor:

        embed_src = embed[edge_label_index[0]]
        embed_dst = embed[edge_label_index[1]]
        return (embed_src * embed_dst).sum(dim=-1)


    def recommend(self, edge_index: Adj, src_index: OptTensor = None,
                  dst_index: OptTensor = None, k: int = 1) -> Tensor:
        out_src = out_dst = self.get_embedding(edge_index)

        if src_index is not None:
            out_src = out_src[src_index]

        if dst_index is not None:
            out_dst = out_dst[dst_index]

        pred = out_src @ out_dst.t()
        top_index = pred.topk(k, dim=-1).indices

        if dst_index is not None:  # Map local top-indices to original indices.
            top_index = dst_index[top_index.view(-1)].view(*top_index.size())

        return top_index


    def link_pred_loss(self, pred: Tensor, edge_label: Tensor,
                       **kwargs) -> Tensor:
        loss_fn = torch.nn.BCEWithLogitsLoss(**kwargs)
        return loss_fn(pred, edge_label.to(pred.dtype))


    def recommendation_loss(self, pos_edge_rank: Tensor, neg_edge_rank: Tensor,
                            lambda_reg: float = 1e-4, **kwargs) -> Tensor:
        r"""Computes the model loss for a ranking objective via the Bayesian
        Personalized Ranking (BPR) loss."""
        loss_fn = BPRLoss(lambda_reg, **kwargs)
        return loss_fn(pos_edge_rank, neg_edge_rank, self.embedding.weight)

    def bpr_loss(self, pos_scores, neg_scores):
      return - torch.log(torch.sigmoid(pos_scores - neg_scores)).mean()

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.num_nodes}, '
                f'{self.embedding_dim}, num_layers={self.num_layers})')


You'll need a new kind of Loss function for this task, so here's a class for implementing BPR, Bayesian Personalised Ranking Loss. We've implemented this for you. Hopefully it helps. You can change function names and stuff, however you are forbidden from changing the mathematical formula -

\begin{equation*}
    \text{BPR Loss}(i) = \frac{1}{|\mathcal{E}(i)|} \underset{{(i, j_{+}) \in \mathcal{E}(i)}}{\sum} \log \sigma \left( \text{score}(i, j_+) - \text{score}(i, j_-) \right)
\end{equation*}

In [27]:
class BPRLoss(_Loss):
    r"""The Bayesian Personalized Ranking (BPR) loss.

    The BPR loss is a pairwise loss that encourages the prediction of an
    observed entry to be higher than its unobserved counterparts
    (see `here <https://arxiv.org/abs/2002.02126>`__).

    .. math::
        L_{\text{BPR}} = - \sum_{u=1}^{M} \sum_{i \in \mathcal{N}_u}
        \sum_{j \not\in \mathcal{N}_u} \ln \sigma(\hat{y}_{ui} - \hat{y}_{uj})
        + \lambda \vert\vert \textbf{x}^{(0)} \vert\vert^2

    where :math:`lambda` controls the :math:`L_2` regularization strength.
    We compute the mean BPR loss for simplicity.

    Args:
        lambda_reg (float, optional): The :math:`L_2` regularization strength
            (default: 0).
        **kwargs (optional): Additional arguments of the underlying
            :class:`torch.nn.modules.loss._Loss` class.
    """
    __constants__ = ['lambda_reg']
    lambda_reg: float

    def __init__(self, lambda_reg: float = 0, **kwargs):
        super().__init__(None, None, "sum", **kwargs)
        self.lambda_reg = lambda_reg

    def forward(self, positives: Tensor, negatives: Tensor,
                parameters: Tensor = None) -> Tensor:
        r"""Compute the mean Bayesian Personalized Ranking (BPR) loss.

        .. note::

            The i-th entry in the :obj:`positives` vector and i-th entry
            in the :obj:`negatives` entry should correspond to the same
            entity (*.e.g*, user), as the BPR is a personalized ranking loss.

        Args:
            positives (Tensor): The vector of positive-pair rankings.
            negatives (Tensor): The vector of negative-pair rankings.
            parameters (Tensor, optional): The tensor of parameters which
                should be used for :math:`L_2` regularization
                (default: :obj:`None`).
        """
        n_pairs = positives.size(0)
        log_prob = F.logsigmoid(positives - negatives).sum()
        regularization = 0

        if self.lambda_reg != 0:
            regularization = self.lambda_reg * parameters.norm(p=2).pow(2)

        return (-log_prob + regularization) / n_pairs

If you only trained the model to minimise loss however, your model would perform pretty poorly. Why so? Assign every node the same embedding, and you would technically be minimizing the loss. The solution to that usually involves negative sampling, paying attention to the nodes that do not have an edge between them and hence must not be similar.
Here are some negative sampling approaches for you to try out. We've done the first one, you do the rest -
1. Random, no positive check: for each positive edge coming from a playlist $p_i$, randomly draw a track $t_j$ from the full set of track nodes such that ($p_i$, $t_j$) is the negative edge. For computational efficiency, we don't check if ($p_i$, $t_j$) is actually a negative edge, though probabilistically it is very likely.
2. Random, positive check: for each positive edge coming from a playlist $p_i$, randomly draw a track $t_j$ from the full set of track nodes such that ($p_i$, $t_j$) is the negative edge. We ensure that ($p_i$, $t_j$) is not a positive edge.
3. Hard: for each positive edge coming from a playlist $p_i$, randomly draw a track $t_j$ from the top $k$ proportion of tracks, ranked by dot product similarity to $p_i$. For epoch 0, $k = 1$ and we lower it at each subsequent iteration.

In [28]:
def sample_negative_edges_nocheck(data, num_playlists, num_tracks, device = None):
  # note computationally inefficient to check that these are indeed negative edges
  playlists = data.edge_label_index[0, :]
  tracks = torch.randint(num_playlists, num_playlists + num_tracks - 1, size = data.edge_label_index[1, :].size())

  if playlists.get_device() != -1: # on gpu
    tracks = tracks.to(device)

  neg_edge_index = torch.stack((playlists, tracks), dim = 0)
  neg_edge_label = torch.zeros(neg_edge_index.shape[1])

  if neg_edge_index.get_device() != -1: # on gpu
    neg_edge_label = neg_edge_label.to(device)

  return neg_edge_index, neg_edge_label

def sample_negative_edges(data, num_playlists, num_tracks, device=None):
    positive_playlists, positive_tracks = data.edge_label_index

    # Create a mask tensor with the shape (num_playlists, num_tracks)
    mask = torch.zeros(num_playlists, num_tracks, device=device, dtype=torch.bool)
    mask[positive_playlists, positive_tracks - num_playlists] = True

    # Flatten the mask tensor and get the indices of the negative edges
    flat_mask = mask.flatten()
    negative_indices = torch.where(~flat_mask)[0]

    # Sample negative edges from the negative_indices tensor
    sampled_negative_indices = negative_indices[
        torch.randint(0, negative_indices.size(0), size=(positive_playlists.size(0),), device=device)
    ]

    # Convert the indices back to playlists and tracks tensors
    playlists = torch.floor_divide(sampled_negative_indices, num_tracks)
    tracks = torch.remainder(sampled_negative_indices, num_tracks)
    tracks = tracks + num_playlists

    neg_edge_index = torch.stack((playlists, tracks), dim=0)
    neg_edge_label = torch.zeros(neg_edge_index.shape[1], device=device)

    return neg_edge_index, neg_edge_label

def sample_hard_negative_edges(data, model, num_playlists, num_tracks, device=None, batch_size=500, frac_sample = 1):
    with torch.no_grad():
        embeddings = model.get_embedding(data.edge_index)
        playlists_embeddings = embeddings[:num_playlists].to(device)
        tracks_embeddings = embeddings[num_playlists:].to(device)

    positive_playlists, positive_tracks = data.edge_label_index
    num_edges = positive_playlists.size(0)

    # Create a boolean mask for all the positive edges
    positive_mask = torch.zeros(num_playlists, num_tracks, device=device, dtype=torch.bool)
    positive_mask[positive_playlists, positive_tracks - num_playlists] = True

    neg_edges_list = []
    neg_edge_label_list = []

    for batch_start in range(0, num_edges, batch_size):
        batch_end = min(batch_start + batch_size, num_edges)

        batch_scores = torch.matmul(
            playlists_embeddings[positive_playlists[batch_start:batch_end]], tracks_embeddings.t()
        )

        # Set the scores of the positive edges to negative infinity
        batch_scores[positive_mask[positive_playlists[batch_start:batch_end]]] = -float("inf")

        # Select the top k highest scoring negative edges for each playlist in the current batch
        # do 0.99 to filter out all pos edges which will be at the end
        _, top_indices = torch.topk(batch_scores, int(frac_sample * 0.99 * num_tracks), dim=1)
        selected_indices = torch.randint(0, int(frac_sample * 0.99 *num_tracks), size = (batch_end - batch_start, ))
        top_indices_selected = top_indices[torch.arange(batch_end - batch_start), selected_indices] + n_playlists

        # Create the negative edges tensor for the current batch
        neg_edges_batch = torch.stack(
            (positive_playlists[batch_start:batch_end], top_indices_selected), dim=0
        )
        neg_edge_label_batch = torch.zeros(neg_edges_batch.shape[1], device=device)

        neg_edges_list.append(neg_edges_batch)
        neg_edge_label_list.append(neg_edge_label_batch)

    # Concatenate the batch tensors
    neg_edges = torch.cat(neg_edges_list, dim=1)
    neg_edge_label = torch.cat(neg_edge_label_list)

    return neg_edges, neg_edge_label

Using only a loss seems like a bad idea however. Most of the time it can be quite hard to interpret loss alone. So let's just also keep an eye on the Recall@K, a popular one for RecSys.
For a playlist $i$, $P^k_i$ represents the set of the top $k$ predicted tracks for $i$ and $T_i$ the ground truth of connected tracks to playlist $i$, then we calculate
$$
\text{recall}^k_i = \frac{| P^k_i \cap T_i | }{|T_i|}.
$$
If $T_i = 0$, then we assign this value to 1 (We don't really want to run into ZeroDivisionError). Note, if $T_i \subset P_i^k$, then the recall is equal to 1. Hence, our choice of $k$ matters a lot.

You'll probably want to try out multiple values of k but maybe you can start in the ballpark of 2-3% of the total tracks
Note: When you evaluate this metric on the val or test set, make sure to filter the message passing edges from consideration, as the model can directly observe these.

In [29]:
def recall_at_k(data, model, k = 300, batch_size = 64, device = None):
    with torch.no_grad():
        embeddings = model.get_embedding(data.edge_index)
        playlists_embeddings = embeddings[:n_playlists]
        tracks_embeddings = embeddings[n_playlists:]

    hits_list = []
    relevant_counts_list = []

    for batch_start in range(0, n_playlists, batch_size):
        batch_end = min(batch_start + batch_size, n_playlists)
        batch_playlists_embeddings = playlists_embeddings[batch_start:batch_end]

        # Calculate scores for all possible item pairs
        scores = torch.matmul(batch_playlists_embeddings, tracks_embeddings.t())

        # Set the scores of message passing edges to negative infinity
        mp_indices = ((data.edge_index[0] >= batch_start) & (data.edge_index[0] < batch_end)).nonzero(as_tuple=True)[0]
        scores[data.edge_index[0, mp_indices] - batch_start, data.edge_index[1, mp_indices] - n_playlists] = -float("inf")

        # Find the top k highest scoring items for each playlist in the batch
        _, top_k_indices = torch.topk(scores, k, dim=1)

        # Ground truth supervision edges
        ground_truth_edges = data.edge_label_index

        # Create a mask to indicate if the top k items are in the ground truth supervision edges
        mask = torch.zeros(scores.shape, device=device, dtype=torch.bool)
        gt_indices = ((ground_truth_edges[0] >= batch_start) & (ground_truth_edges[0] < batch_end)).nonzero(as_tuple=True)[0]
        mask[ground_truth_edges[0, gt_indices] - batch_start, ground_truth_edges[1, gt_indices] - n_playlists] = True

        # Check how many of the top k items are in the ground truth supervision edges
        hits = mask.gather(1, top_k_indices).sum(dim=1)
        hits_list.append(hits)

        # Calculate the total number of relevant items for each playlist in the batch
        relevant_counts = torch.bincount(ground_truth_edges[0, gt_indices] - batch_start, minlength=batch_end - batch_start)
        relevant_counts_list.append(relevant_counts)

    # Compute recall@k
    hits_tensor = torch.cat(hits_list, dim=0)
    relevant_counts_tensor = torch.cat(relevant_counts_list, dim=0)
    # Handle division by zero case
    recall_at_k = torch.where(
        relevant_counts_tensor != 0,
        hits_tensor.true_divide(relevant_counts_tensor),
        torch.ones_like(hits_tensor)
    )
    # take average
    recall_at_k = torch.mean(recall_at_k)

    if recall_at_k.numel() == 1:
        return recall_at_k.item()
    else:
        raise ValueError("recall_at_k contains more than one item.")

Finally, let's calculate the ROC AUC score for the binary predictions, which provides a measure of the efficiency of our model at distinguishing `true' track-playlist edges from non-existing ones. The higher this score, the better (a perfect score is achieved when ROC AUC = 1). This is done for you here -

In [30]:
def metrics(labels, preds):
  roc = roc_auc_score(labels.flatten().cpu().numpy(), preds.flatten().data.cpu().numpy())
  return roc

Finally. It's time to write your train and test functions for training and evaluation.

In [40]:
# Train
def train(datasets, model, optimizer, loss_fn, args, neg_samp = "random"):
  print(f"Beginning training for {model.name}")

  train_data = datasets["train"]
  val_data = datasets["val"]

  stats = {
      'train': {
        'loss': [],
        'roc' : []
      },
      'val': {
        'loss': [],
        'recall': [],
        'roc' : []
      }

  }
  val_neg_edge, val_neg_label = None, None
  for epoch in range(args["epochs"]): # loop over each epoch
    model.train()
    optimizer.zero_grad()

    # obtain negative sample
    if neg_samp == "random":
      neg_edge_index, neg_edge_label = sample_negative_edges(train_data, n_playlists, n_tracks, args["device"])
    elif neg_samp == "hard":
      if epoch % 5 == 0:
        neg_edge_index, neg_edge_label = sample_hard_negative_edges(
            train_data, model, n_playlists, n_tracks, args["device"], batch_size = 500,
            frac_sample = 1 - (0.5 * epoch / args["epochs"])
        )
    # calculate embedding
    embed = model.get_embedding(train_data.edge_index)
    # calculate pos, negative scores using embedding
    pos_scores = model.predict_link_embedding(embed, train_data.edge_label_index)
    neg_scores = model.predict_link_embedding(embed, neg_edge_index)

    # concatenate pos, neg scores together and evaluate loss
    scores = torch.cat((pos_scores, neg_scores), dim = 0)
    labels = torch.cat((train_data.edge_label, neg_edge_label), dim = 0)

    # calculate loss function
    if loss_fn == "BCE":
      loss = model.link_pred_loss(scores, labels)
    elif loss_fn == "BPR":
      loss = model.recommendation_loss(pos_scores, neg_scores, lambda_reg = 0)

    train_roc = metrics(labels, scores)

    loss.backward()
    optimizer.step()

    val_loss, val_roc, val_neg_edge, val_neg_label = test(
        model, val_data, loss_fn, neg_samp, args, epoch, val_neg_edge, val_neg_label
    )

    stats['train']['loss'].append(loss)
    stats['train']['roc'].append(train_roc)
    stats['val']['loss'].append(val_loss)
    stats['val']['roc'].append(val_roc)

    print(f"Epoch {epoch}; Train loss {loss}; Val loss {val_loss}; Train ROC {train_roc}; Val ROC {val_roc}")

    if epoch % 10 == 0:
      # calculate recall @ K
      val_recall = recall_at_k(val_data, model, k = 300, device = args["device"])
      print(f"Val recall {val_recall}")
      stats['val']['recall'].append(val_recall)

    if epoch % 20 == 0:

      # save embeddings for future visualization
      path = os.path.join("model_embeddings", model.name)
      if not os.path.exists(path):
        os.makedirs(path)
      torch.save(model.embedding.weight, os.path.join("model_embeddings", model.name, f"{model.name}_{loss_fn}_{neg_samp}_{epoch}.pt"))

  pickle.dump(stats, open(f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pkl", "wb"))
  return stats

def test(model, data, loss_fn, neg_samp, args, epoch = 0, neg_edge_index = None, neg_edge_label = None):

  model.eval()
  with torch.no_grad(): # want to save RAM

    # conduct negative sampling
    if neg_samp == "random":
      neg_edge_index, neg_edge_label = sample_negative_edges(data, n_playlists, n_tracks, args["device"])
    elif neg_samp == "hard":
      if epoch % 5 == 0 or neg_edge_index is None:
        neg_edge_index, neg_edge_label = sample_hard_negative_edges(
            data, model, n_playlists, n_tracks, args["device"], batch_size = 500,
            frac_sample = 1 - (0.5 * epoch / args["epochs"])
        )
    # obtain model embedding
    embed = model.get_embedding(data.edge_index)
    # calculate pos, neg scores using embedding
    pos_scores = model.predict_link_embedding(embed, data.edge_label_index)
    neg_scores = model.predict_link_embedding(embed, neg_edge_index)
    # concatenate pos, neg scores together and evaluate loss
    scores = torch.cat((pos_scores, neg_scores), dim = 0)
    labels = torch.cat((data.edge_label, neg_edge_label), dim = 0)
    # calculate loss
    if loss_fn == "BCE":
      loss = model.link_pred_loss(scores, labels)
    elif loss_fn == "BPR":
      loss = model.recommendation_loss(pos_scores, neg_scores, lambda_reg = 0)

    roc = metrics(labels, scores)

  return loss, roc, neg_edge_index, neg_edge_label

## Start your training

Maybe put your splits in a dict for easy access?

In [32]:
datasets = {
    'train':train_split,
    'val':val_split,
    'test': test_split
}

In [33]:
training_args = {
    'device' : 'cuda' if torch.cuda.is_available() else 'cpu',  # For the love of God, please do use a GPU though.
    'num_layers' :  3,
    'emb_size' : 64,
    'weight_decay': 1e-5,
    'lr': 0.01,
    'epochs': 301
}

Initialise your model with the parameters and nodes and whatever.
The Adam optimizer should work just fine for the task, so initialise that well as well.

In [34]:
# initialize model and and optimizer
num_nodes = n_playlists + n_tracks
model = GCN(
    num_nodes = num_nodes, num_layers = training_args['num_layers'],
    embedding_dim = training_args["emb_size"], conv_layer = "LGC"
)
optimizer = torch.optim.Adam(model.parameters(), lr=training_args['lr'], weight_decay=training_args['weight_decay'])

It's time to put this data on a GPU.

In [36]:
playlists_idx = torch.Tensor(playlists).type(torch.int64).to(training_args["device"])
tracks_idx =torch.Tensor(tracks).type(torch.int64).to(training_args["device"])
datasets['train'].to(training_args['device'])
datasets['val'].to(training_args['device'])
datasets['test'].to(training_args['device'])
model.to(training_args["device"])

GCN(15720, 64, num_layers=3)

In [37]:
# create directory to save model_stats
MODEL_STATS_DIR = "model_stats"
if not os.path.exists(MODEL_STATS_DIR):
  os.makedirs(MODEL_STATS_DIR)

In [41]:
train(datasets, model, optimizer, "BPR", training_args, neg_samp = "random")

Beginning training for LGCN_LGC_3_e64_nodes15720_
Epoch 0; Train loss 0.6931236982345581; Val loss 0.6928630471229553; Train ROC 0.8702160849953565; Val ROC 0.8717797894578827
Val recall 0.4857821762561798
Epoch 1; Train loss 0.6928468942642212; Val loss 0.6915112137794495; Train ROC 0.8877885140693493; Val ROC 0.8693705630858567
Epoch 2; Train loss 0.6914324760437012; Val loss 0.6882978677749634; Train ROC 0.8821124433905835; Val ROC 0.866725034959949
Epoch 3; Train loss 0.6880381107330322; Val loss 0.6826019287109375; Train ROC 0.8822445885885494; Val ROC 0.8692822820071768
Epoch 4; Train loss 0.6821249127388; Val loss 0.6741589307785034; Train ROC 0.8824766573768758; Val ROC 0.8702847625547763
Epoch 5; Train loss 0.6733341813087463; Val loss 0.6629157662391663; Train ROC 0.8822659871109941; Val ROC 0.8681156879937503
Epoch 6; Train loss 0.6614587903022766; Val loss 0.6486085653305054; Train ROC 0.8823733994879279; Val ROC 0.8681933964095525
Epoch 7; Train loss 0.6465133428573608; Va

{'train': {'loss': [tensor(0.6931, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.6928, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.6914, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.6880, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.6821, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.6733, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.6615, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.6465, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.6285, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.6074, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.5838, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.5584, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.5314, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.5038, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.4770, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.4496, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.4246, devic

In [44]:
test(model, datasets['test'], "BPR", neg_samp = "random", args=training_args)

(tensor(0.2821, device='cuda:0'),
 0.9173670171259903,
 tensor([[ 2430,  8856,  4467,  ...,  7205,  1476,  3158],
         [14253, 11805, 15036,  ..., 14861, 13440, 13529]], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'))

In [45]:
def init_model(conv_layer, args, alpha = False):
  num_nodes = n_playlists + n_tracks
  model = GCN(
      num_nodes = num_nodes, num_layers = args['num_layers'],
      embedding_dim = args["emb_size"], conv_layer = conv_layer,
      alpha_learnable = alpha
  )
  model.to(args["device"])
  optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
  return model, optimizer

In [46]:
## For example:

# using BPR loss
loss_fn = "BPR"

# using hard sampling
neg_samp = "hard"

# for LGConv:
training_args['epochs'] = 301
training_args['num_layers'] = 4
model, optimizer = init_model("LGC", training_args)
lgc_stats_hard = train(datasets, model, optimizer, loss_fn, training_args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

Beginning training for LGCN_LGC_4_e64_nodes15720_
Epoch 0; Train loss 0.6931362152099609; Val loss 0.6931349635124207; Train ROC 0.6313463344373743; Val ROC 0.7565399200854975
Val recall 0.3842117190361023
Epoch 1; Train loss 0.6931306719779968; Val loss 0.6929486989974976; Train ROC 0.8622771591233875; Val ROC 0.8299746437062402
Epoch 2; Train loss 0.6929353475570679; Val loss 0.6919710040092468; Train ROC 0.8462058723873402; Val ROC 0.8219991887118722
Epoch 3; Train loss 0.6918988823890686; Val loss 0.6894924640655518; Train ROC 0.8355215425444318; Val ROC 0.8211186955186817
Epoch 4; Train loss 0.6892750859260559; Val loss 0.685011088848114; Train ROC 0.8341301212581467; Val ROC 0.8206672900709813
Epoch 5; Train loss 0.6846597194671631; Val loss 0.6784350872039795; Train ROC 0.8317401287476761; Val ROC 0.8176610739916004
Epoch 6; Train loss 0.6775448322296143; Val loss 0.669202983379364; Train ROC 0.8313078036763918; Val ROC 0.8171446183988965
Epoch 7; Train loss 0.6677597761154175; 

## Section 3 - Training Visualisation

In [47]:
def plot_train_val_loss(stats_dict):
    train_loss = stats_dict['train_loss']
    val_loss = stats_dict['val_loss']

    # Create a range of epochs for the x-axis
    epochs = range(1, len(train_loss) + 1)

    # Plot the training and validation loss
    plt.figure(figsize=(8, 6))
    plt.plot(epochs, train_loss, label='Training Loss')
    plt.plot(epochs, val_loss, label='Validation Loss')

    # Customize the plot
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Show the plot
    plt.show()