In this project, we train a graph neural network (GNN) to perform link prediction on a heterogeneous graph built from the Spotify Million Playlist Dataset.

# Import libraries

In [5]:
!pip install numpy iprogress tqdm networkx torch_geometric
import os
import json
import numpy as np
from tqdm import tqdm
import networkx as nx
import torch_geometric



In [6]:
# temporary imports (delete later)
from pprint import pprint
import pickle
import sys
import deepsnap

# http://192.168.0.103:8888/?token=<redacted>  (never commit Jupyter access tokens)

# Configuration

In [9]:
# config: all paths are built relative to the dataset root directory.
base = "spotify_million_playlist_dataset"
pickles = base + "/pickles"

# Choose between the small example subset and the full dataset explicitly
# instead of overwriting one set of assignments with another.
use_example_subset = True

if use_example_subset:
    # example dataset
    dataset_path = base + "/example"
    pickled_graph = pickles + "/G_example.pkl"
    pickled_dataset = pickles + "/dataset_example.pkl"
else:
    # full dataset
    dataset_path = base + "/data"
    pickled_graph = pickles + "/G.pkl"
    pickled_dataset = pickles + "/dataset.pkl"

# Load datasets

In [None]:
G = None


def load_graph(dataset_path=dataset_path):
    """Build an undirected NetworkX graph from the playlist JSON files.

    Every ``*.json`` file in ``dataset_path`` is read (in sorted order, so the
    node insertion order is reproducible); each file must contain a top-level
    ``"playlists"`` list. Other directory entries (e.g. ``.ipynb_checkpoints``)
    are skipped.

    Nodes carry a ``node_type`` attribute:
      * ``"playlist"`` (named ``spotify:playlist:<pid>``, with ``num_followers``)
      * ``"track"`` (keyed by ``track_uri``, with ``duration``)
      * ``"album"`` (keyed by ``album_uri``)
      * ``"artist"`` (keyed by ``artist_uri``)

    Edges carry an ``edge_type`` attribute: ``"track-playlist"``,
    ``"track-album"`` or ``"track-artist"``.

    Side effect: the built graph is also stored in the module-level ``G``.

    Returns:
        The populated ``nx.Graph``.
    """
    global G
    filenames = sorted(
        name for name in os.listdir(dataset_path) if name.endswith(".json")
    )
    G = nx.Graph()
    for filename in tqdm(filenames, unit="files"):
        with open(os.path.join(dataset_path, filename), encoding="utf-8") as json_file:
            playlists = json.load(json_file)["playlists"]
        for playlist in playlists:
            playlist_name = f"spotify:playlist:{playlist['pid']}"
            G.add_node(playlist_name, node_type="playlist", num_followers=playlist["num_followers"])
            for track in playlist["tracks"]:
                # Re-adding an existing node only refreshes its attributes,
                # so tracks/albums/artists shared across playlists stay unique.
                G.add_node(track["track_uri"], node_type="track", duration=track["duration_ms"])
                G.add_node(track["album_uri"], node_type="album")
                G.add_node(track["artist_uri"], node_type="artist")

                G.add_edge(track["track_uri"], playlist_name, edge_type="track-playlist")
                G.add_edge(track["track_uri"], track["album_uri"], edge_type="track-album")
                G.add_edge(track["track_uri"], track["artist_uri"], edge_type="track-artist")
    return G

def nx2hetero(graph_getter):
    """Convert the playlist graph into a PyTorch Geometric ``HeteroData`` object.

    Args:
        graph_getter: zero-argument callable returning an ``nx.Graph`` whose
            nodes have a ``node_type`` attribute and whose edges have an
            ``edge_type`` attribute, as produced by ``load_graph``.

    Returns:
        ``HeteroData`` with one node store per node type and three edge types:
          * ``("playlist", "contains", "track")``
          * ``("album", "contains", "track")``
          * ``("artist", "writes", "track")``

        Each node store's ``x`` is a float column vector of shape ``(N, 1)``:
          * playlists: ``num_followers``
          * tracks: ``duration``
          * albums and artists: the constant 1

        Every ``edge_index`` is a ``(2, E)`` long tensor indexing into the
        per-type node stores.
    """
    import torch
    from torch_geometric.data import HeteroData

    G = graph_getter()

    # PyG keeps a separate index space per node type, so every node gets a
    # position *within its own type* (in NetworkX iteration order).
    index_by_type = {}
    features_by_type = {}
    for name, attrs in G.nodes(data=True):
        node_type = attrs["node_type"]
        type_index = index_by_type.setdefault(node_type, {})
        type_index[name] = len(type_index)
        if node_type == "playlist":
            feature = attrs["num_followers"]
        elif node_type == "track":
            feature = attrs["duration"]
        else:
            feature = 1
        features_by_type.setdefault(node_type, []).append(feature)

    # edge_type attribute -> (source node type, relation, destination node type)
    relations = {
        "track-playlist": ("playlist", "contains", "track"),
        "track-album": ("album", "contains", "track"),
        "track-artist": ("artist", "writes", "track"),
    }
    pairs_by_relation = {relation: [] for relation in relations.values()}
    for u, v, attrs in G.edges(data=True):
        relation = relations.get(attrs["edge_type"])
        if relation is None:
            continue  # unknown edge types are ignored
        src_type, _, dst_type = relation
        # nx.Graph is undirected: (u, v) may come out in either order, so
        # orient each edge by node type rather than by tuple position.
        if G.nodes[u]["node_type"] != src_type:
            u, v = v, u
        pairs_by_relation[relation].append(
            (index_by_type[src_type][u], index_by_type[dst_type][v])
        )

    hetero = HeteroData()
    for node_type, features in features_by_type.items():
        hetero[node_type].x = torch.tensor(features, dtype=torch.float).view(-1, 1)
    for relation, pairs in pairs_by_relation.items():
        # view(-1, 2) keeps an empty edge list at shape (0, 2) -> (2, 0).
        hetero[relation].edge_index = (
            torch.tensor(pairs, dtype=torch.long).view(-1, 2).t().contiguous()
        )
    return hetero

def graph_to_dataset(graph_getter):
    """Wrap the graph from ``graph_getter`` in a DeepSNAP link-prediction dataset.

    ``graph_getter`` is called with no arguments and must return the graph.
    The dataset is built with ``task="link_pred"`` and
    ``edge_train_mode="disjoint"``. It is then split transductively with ratios
    0.7 / 0.1 / 0.2, and the result of ``GraphDataset.split`` is returned (the
    caller unpacks it as train / validation / test).
    """
    print("Converting graph to DeepSnap dataset ...")
    source_graph = graph_getter()
    link_dataset = deepsnap.dataset.GraphDataset(
        [source_graph],
        task="link_pred",
        edge_train_mode="disjoint",
    )
    train_val_test_ratios = [0.7, 0.1, 0.2]
    return link_dataset.split(transductive=True, split_ratio=train_val_test_ratios)

def get_cached(var, pickled_filename, generator, *args, ignore_cache=True):
    """Return the object named ``var``, reusing a cached copy when allowed.

    When ``ignore_cache`` is False, lookup proceeds in order:
      1. a module-level global called ``var`` that is not ``None``;
      2. the pickle file at ``pickled_filename``.

    If caching is ignored or neither cache exists, ``generator(*args)`` is
    called. Its result is pickled to ``pickled_filename``, creating parent
    directories if needed.

    Whatever is returned is also bound to the module-level global ``var``.
    Only that name is touched; other globals such as ``G`` are left alone.

    Note: ``ignore_cache`` defaults to True, so by default the object is
    always regenerated.
    """
    if not ignore_cache and globals().get(var) is not None:
        print(f"{var} already loaded :)")
        return globals()[var]

    if not ignore_cache and os.path.exists(pickled_filename):
        print(f"Loading {var} from pickle ...")
        # pickle.load can execute arbitrary code: only load pickles this
        # notebook wrote itself, never files from untrusted sources.
        with open(pickled_filename, "rb") as pickle_file:
            obj = pickle.load(pickle_file)
        globals()[var] = obj
        return obj

    if ignore_cache:
        print(f"Ignoring cache for {var}, generating anew ...")
    else:
        print(f"Pickled {var} not found, generating anew ...")
    obj = generator(*args)
    # Cache in memory before writing to disk so an I/O failure does not
    # throw away an expensive result.
    globals()[var] = obj

    directory = os.path.dirname(pickled_filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # Dump to a temp file and atomically move it into place, so an interrupted
    # dump never leaves a truncated pickle that later loads would choke on.
    tmp_filename = pickled_filename + ".tmp"
    with open(tmp_filename, "wb") as pickle_file:
        pickle.dump(obj, pickle_file)
    os.replace(tmp_filename, pickled_filename)
    print(f"{var} generated, pickle saved to {pickled_filename}")
    return obj

import time

# Time graph construction + dataset split. perf_counter is monotonic and
# meant for measuring intervals, unlike the wall clock returned by time.time().
start = time.perf_counter()
dataset = get_cached(
    "dataset",
    pickled_dataset,
    graph_to_dataset,
    lambda: get_cached("G", pickled_graph, load_graph),
)
dataset_train, dataset_val, dataset_test = dataset
elapsed_seconds = time.perf_counter() - start
print("Finished loading data.")

# Write the load time to status.txt.
with open("status.txt", "w", encoding="utf-8") as f:
    f.write(f"Finished loading data in {elapsed_seconds:.1f} seconds.")

In [14]:
import gc

# Delete the module-level name `G` if it is defined, then force a garbage
# collection pass so memory that is no longer referenced can be reclaimed.
try:
    del G
except NameError:
    pass  # G was never defined or has already been deleted
gc.collect()

68992

# Preprocessing