In this project, we will train a GNN to perform link prediction on a heterogenous graph from the Spotify Million Playlists dataset.

# Import libraries

In [75]:
import sys  
sys.path.insert(0, '/home/yon/jupyter-server/mlg/src/')

import loader
import config
import model as M
import preprocessing
from pprint import pprint
import torch
import random
import torch_geometric
import numpy as np
import time

# Model

In [76]:
from torcheval.metrics import BinaryAccuracy

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = torch_geometric.nn.SAGEConv((-1, -1), hidden_channels, normalize=True)
        self.conv2 = torch_geometric.nn.SAGEConv((-1, -1), hidden_channels, normalize=True)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

class LinkPredictor(torch.nn.Module):
    def forward(self, x_track, x_playlist, track_playlist_edge):
        track_embedding = x_track[track_playlist_edge[0]]
        playlist_embedding = x_playlist[track_playlist_edge[1]]

        #print(playlist_embedding)

        # Apply dot-product to get a prediction per supervision edge:
        return (playlist_embedding * track_embedding).sum(dim=-1)

class HeteroModel(torch.nn.Module):
    def __init__(self, hidden_channels, node_features, metadata):
        super().__init__()
        # Since the dataset does not come with rich features, we also learn two
        # embedding matrices for users and movies:

        self.node_lin = {
            k: torch.nn.Linear(v.shape[1], hidden_channels) for k, v in node_features.items()
        }

        for _, v in self.node_lin.items():
            torch.nn.init.xavier_uniform_(v.weight)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = torch_geometric.nn.to_hetero(self.gnn, metadata=metadata)

        self.classifier = LinkPredictor()

    def forward(self, data):
        x_dict = {
            k: self.node_lin[k](v) for k, v in data.x_dict.items()
        }

        x_dict = self.gnn(x_dict, data.edge_index_dict)
        pred = self.classifier(
            x_dict["track"],
            x_dict["playlist"],
            data["track", "contains", "playlist"].edge_label_index,
        )
        return pred

    def reset_parameters(self):
        for _, v in self.node_lin.items():
            torch.nn.init.xavier_uniform_(v.weight)
        self.gnn.reset_parameters()

def dummy_generator(source):
    for e in source:
        yield e

def train(model, train_loader, optimizer, batch_wrapper=dummy_generator):
    model.train()

    accuracy = 0

    total_examples = total_loss = 0
    for i, batch in enumerate(batch_wrapper(train_loader)):
        optimizer.zero_grad()
        
        out = model(batch)
        truth = batch["track", "contains", "playlist"].edge_label


        if(i % 10 == 0):
            #print(out[:10])
            #print(batch["track", "contains", "playlist"].edge_label[:10])
            pass
        loss = torch.nn.functional.mse_loss(
            out, truth
        )
        loss.backward()
        optimizer.step()

        metric = BinaryAccuracy()
        metric.update(out, truth)
        accuracy += metric.compute() * len(out)

        total_examples += len(out)
        total_loss += float(loss) * len(out)

    return total_loss / total_examples, accuracy / total_examples

# Test Run

In [77]:
!rm spotify_million_playlist_dataset/pickles/G_example.pkl

In [78]:
ghetero = loader.get_ghetero(True, config)
data_train, data_val, data_test = loader.get_datasets(True, config)

Pickled ghetero not found, generating anew ...
Pickled G not found, generating anew ...


100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.07s/files]


G generated, pickle saved to spotify_million_playlist_dataset/pickles/G_example.pkl
[1, False, 39, 40, 11, 9631768, 29]
ghetero generated, pickle saved to spotify_million_playlist_dataset/pickles/ghetero_example.pkl
Pickled datasets not found, generating anew ...
Loading ghetero from pickle ...
datasets generated, pickle saved to spotify_million_playlist_dataset/pickles/datasets_example.pkl


In [79]:
# create training mask for playlist nodes
train_mask = torch.zeros(ghetero["playlist"].x.shape[0], dtype=torch.bool)
train_mask[torch.randperm(train_mask.shape[0])[:int(train_mask.shape[0]*0.8)]] = True

ghetero["playlist"].train_mask = train_mask

ghetero["playlist"].y = torch.LongTensor([1]*ghetero["playlist"].x.shape[0])

model = HeteroModel(64, ghetero.x_dict, ghetero.metadata())
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

edge_label_index = data_train["track", "contains", "playlist"].edge_label_index
edge_label = data_train["track", "contains", "playlist"].edge_label
train_loader = torch_geometric.loader.LinkNeighborLoader(
    data=data_train,
    num_neighbors=[20, 10],
    neg_sampling_ratio=2.0,
    edge_label_index=(("track", "contains", "playlist"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

In [80]:
data_train

HeteroData(
  [1mplaylist[0m={ x=[1000, 1] },
  [1mtrack[0m={ x=[35289, 1] },
  [1martist[0m={ x=[10091, 1] },
  [1malbum[0m={ x=[20469, 1] },
  [1m(track, contains, playlist)[0m={
    edge_index=[2, 37146],
    edge_label=[15919],
    edge_label_index=[2, 15919]
  },
  [1m(track, includes, album)[0m={ edge_index=[2, 35289] },
  [1m(track, authors, artist)[0m={ edge_index=[2, 35289] },
  [1m(playlist, rev_contains, track)[0m={ edge_index=[2, 37146] },
  [1m(album, rev_includes, track)[0m={ edge_index=[2, 35289] },
  [1m(artist, rev_authors, track)[0m={ edge_index=[2, 35289] }
)

In [81]:
import tqdm
epoch = 100

for i in range(epoch):
    loss, accuracy = train(model, train_loader, optimizer, batch_wrapper=tqdm.tqdm)
    print(f"Epoch {i+1}/{epoch}, Loss: {loss:.4f}, Accuracy {accuracy:.4f}")

100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:05<00:00, 24.86it/s]


Epoch 1/100, Loss: 0.2746, Accuracy 0.6667


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.27it/s]


Epoch 2/100, Loss: 0.2120, Accuracy 0.6809


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.06it/s]


Epoch 3/100, Loss: 0.2087, Accuracy 0.6950


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.25it/s]


Epoch 4/100, Loss: 0.2060, Accuracy 0.6809


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.61it/s]


Epoch 5/100, Loss: 0.2047, Accuracy 0.6667


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.26it/s]


Epoch 6/100, Loss: 0.2049, Accuracy 0.7163


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.09it/s]


Epoch 7/100, Loss: 0.2038, Accuracy 0.6738


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.25it/s]


Epoch 8/100, Loss: 0.2025, Accuracy 0.7092


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.38it/s]


Epoch 9/100, Loss: 0.1998, Accuracy 0.6879


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.28it/s]


Epoch 10/100, Loss: 0.1984, Accuracy 0.6950


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.17it/s]


Epoch 11/100, Loss: 0.1989, Accuracy 0.6596


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.37it/s]


Epoch 12/100, Loss: 0.1981, Accuracy 0.7021


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.18it/s]


Epoch 13/100, Loss: 0.1978, Accuracy 0.6809


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:05<00:00, 25.00it/s]


Epoch 14/100, Loss: 0.1961, Accuracy 0.6667


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.46it/s]


Epoch 15/100, Loss: 0.1962, Accuracy 0.6667


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:04<00:00, 25.32it/s]


Epoch 16/100, Loss: 0.1965, Accuracy 0.6950


 79%|████████████████████████████████████████████████████████▏              | 99/125 [00:03<00:01, 25.10it/s]