In this project, we will train a graph neural network (GNN) to perform link prediction on a heterogeneous graph built from the Spotify Million Playlist Dataset.

# Import libraries

In [1]:
import sys  
sys.path.insert(0, '/home/yon/jupyter-server/mlg/src/')

import loader
import config
import model as M
import preprocessing
from pprint import pprint
import torch
import random
import torch_geometric
import numpy as np
import time

# Test Run

In [5]:
!rm spotify_million_playlist_dataset/pickles/G_example.pkl

In [6]:
# Obtain the heterogeneous graph and its train/val/test splits from the project loader.
# The recorded output reports that missing pickles are regenerated and saved under
# spotify_million_playlist_dataset/pickles/.
# NOTE(review): the first positional argument (True) presumably selects the small
# example subset (the saved pickle is named G_example.pkl) — confirm against loader.
# NOTE(review): the recorded output of this cell ends with
# "TypeError: 'NodeDataView' object is not an iterator", so the regeneration path
# appears to fail; possibly next() is applied to G.nodes(data=True) without iter()
# — confirm in the loader/preprocessing code before relying on the cells below.
ghetero = loader.get_ghetero(True, config)
data_train, data_val, data_test = loader.get_datasets(True, config)

Pickled ghetero not found, generating anew ...
Pickled G not found, generating anew ...


100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.26files/s]


G generated, pickle saved to spotify_million_playlist_dataset/pickles/G_example.pkl


TypeError: 'NodeDataView' object is not an iterator

In [None]:
# Seed the Python, NumPy and PyTorch generators so the random train mask below
# (and any other draws made from these generators) is repeatable after a kernel restart.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Share of playlist nodes flagged as training nodes.
TRAIN_FRACTION = 0.8

# create training mask for playlist nodes: a random TRAIN_FRACTION of them is set to True
num_playlists = ghetero["playlist"].x.shape[0]
train_mask = torch.zeros(num_playlists, dtype=torch.bool)
train_mask[torch.randperm(num_playlists)[:int(num_playlists * TRAIN_FRACTION)]] = True

ghetero["playlist"].train_mask = train_mask

# Every playlist node gets the constant label 1 (int64), built directly as a tensor
# instead of going through a Python list and the legacy LongTensor constructor.
ghetero["playlist"].y = torch.ones(num_playlists, dtype=torch.long)

# NOTE(review): 64 is presumably the hidden/embedding width of the model — confirm in model.py.
model = M.HeteroModel(64, ghetero.x_dict, ghetero.metadata())
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Supervision edges for the ("track", "contains", "playlist") relation of the training split.
edge_label_index = data_train["track", "contains", "playlist"].edge_label_index
edge_label = data_train["track", "contains", "playlist"].edge_label

# Mini-batch loader over the supervision edges: samples up to 20 neighbours in the
# first hop and 10 in the second around each batch of 128 labelled edges.
# NOTE(review): edge_label is supplied together with neg_sampling_ratio=2.0; check the
# LinkNeighborLoader documentation for how pre-existing labels are treated when extra
# negatives are sampled — confirm this matches what the loss in M.train expects.
train_loader = torch_geometric.loader.LinkNeighborLoader(
    data=data_train,
    num_neighbors=[20, 10],
    neg_sampling_ratio=2.0,
    edge_label_index=(("track", "contains", "playlist"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

In [None]:
# Display the training split's HeteroData summary (node feature shapes and edge stores per type).
data_train

HeteroData(
  playlist={ x=[1000, 1] },
  track={ x=[35289, 1] },
  artist={ x=[10091, 1] },
  album={ x=[20469, 1] },
  (track, contains, playlist)={
    edge_index=[2, 37146],
    edge_label=[15919],
    edge_label_index=[2, 15919]
  },
  (track, includes, album)={ edge_index=[2, 35289] },
  (track, authors, artist)={ edge_index=[2, 35289] },
  (playlist, rev_contains, track)={ edge_index=[2, 37146] },
  (album, rev_includes, track)={ edge_index=[2, 35289] },
  (artist, rev_authors, track)={ edge_index=[2, 35289] }
)

In [None]:
import tqdm

# Total number of passes over train_loader.
epoch = 100

for i in range(epoch):
    # One call to the project's training routine; tqdm.tqdm is handed over as the
    # batch wrapper (presumably to draw the per-epoch progress bar — see M.train).
    loss = M.train(
        model,
        train_loader,
        optimizer,
        batch_wrapper=tqdm.tqdm,
    )
    epoch_number = i + 1  # 1-based counter for the log line
    print(f"Epoch {epoch_number}/{epoch}, Loss: {loss:.4f}")

100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:06<00:00, 19.92it/s]


Epoch 1/100, Loss: 758.6188


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:06<00:00, 20.09it/s]


Epoch 2/100, Loss: 758.5325


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:06<00:00, 20.25it/s]


Epoch 3/100, Loss: 758.5125


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:06<00:00, 19.80it/s]


Epoch 4/100, Loss: 758.5638


100%|██████████████████████████████████████████████████████████████████████| 125/125 [00:06<00:00, 20.11it/s]


Epoch 5/100, Loss: 758.5302


 22%|███████████████▉                                                       | 28/125 [00:01<00:04, 19.68it/s]


KeyboardInterrupt: 