In this project, we will train a GNN to perform link prediction on a heterogenous graph from the Spotify Million Playlists dataset.

# Import libraries

In [13]:
!pip install numpy iprogress tqdm networkx torch_geometric torch_scatter
import os
import json
import torch
import numpy as np
from tqdm import tqdm
import networkx as nx
import torch_geometric
import torch_scatter



In [14]:
# temporary imports (delete later)
from pprint import pprint
import pickle
import sys
import deepsnap

# http://192.168.0.103:8888/?token=klobasa

# Configuration

In [9]:
# config
base = "spotify_million_playlist_dataset"
pickles = base + "/pickles"

# full dataset
dataset_path = base + "/data"
pickled_graph = pickles + "/G.pkl"
pickled_dataset = pickles + "/dataset.pkl"

# example dataset (override above)
dataset_path = base + "/example"
pickled_graph = pickles + "/G_example.pkl"
pickled_dataset = pickles + "/dataset_example.pkl"

# Load datasets

In [None]:
G = None
def load_graph(dataset_path=dataset_path):
    global G
    filenames = os.listdir(dataset_path)
    G = nx.Graph()
    for i in tqdm(range(len(filenames)), unit="files"):
        with open(os.path.join(dataset_path, filenames[i])) as json_file:
            playlists = json.load(json_file)["playlists"]
            for playlist in playlists:
                playlist_name = f"spotify:playlist:{playlist['pid']}"
                G.add_node(playlist_name, node_type="playlist", num_followers=playlist["num_followers"])
                for track in playlist["tracks"]:
                    G.add_node(track["track_uri"], node_type="track", duration=track["duration_ms"])
                    G.add_node(track["album_uri"], node_type="album")
                    G.add_node(track["artist_uri"], node_type="artist")

                    G.add_edge(track["track_uri"], playlist_name, edge_type="track-playlist")
                    G.add_edge(track["track_uri"], track["album_uri"], edge_type="track-album")
                    G.add_edge(track["track_uri"], track["artist_uri"], edge_type="track-artist")
    return G

def nx2hetero(graph_getter):
	G = graph_getter()
	node_types = set([node[1]["node_type"] for node in G.nodes(data=True)])
	nodes_by_type = dict()
	for node_type in node_types:
		nodes_by_type[node_type] = [node[1] for node in list(G.nodes(data=True)) if node[1]["node_type"] == node_type][:10]
	nodes_by_type

	# build node index
	playlists = []
	tracks = []
	num_artists = 0
	num_albums = 0
	for node in G.nodes(data=True):
		t = node[1]["node_type"]
		if t == "playlist":
			playlists += [node[1]["num_followers"]]
		elif t == "track":
			tracks += [node[1]["duration"]]
		elif t == "artist":
			num_artists += 1
		elif t == "album":
			num_albums += 1

	# build edge_index
	playlist_track = []
	album_track = []
	artist_track = []
	for node in G.edges(data=True):
		if G[node[0]][node[1]]["edge_type"] == "track-playlist":
			playlist_track += [(name2int[node[0]], name2int[node[1]])]
		elif G[node[0]][node[1]]["edge_type"] == "track-album":
			album_track += [(name2int[node[0]], name2int[node[1]])]
		elif G[node[0]][node[1]]["edge_type"] == "track-artist":
			artist_track += [(name2int[node[0]], name2int[node[1]])]

	# construct HeteroData
	hetero = HeteroData()
	# hetero["playlist"].x = torch.tensor(playlists)
	# hetero["track"].x = torch.tensor(tracks)
	# hetero["artist"].x = torch.tensor([1 for _ in range(num_artists)])
	# hetero["album"].x = torch.tensor([1 for _ in range(num_albums)])
	hetero["nodes"].x = torch.tensor([1 for _ in range(len(G.nodes))])

	hetero["playlist", "contains", "track"].edge_index = torch.tensor(playlist_track).t()
	hetero["album", "contains", "track"].edge_index = torch.tensor(album_track).t()
	hetero["artist", "writes", "track"].edge_index = torch.tensor(artist_track).t()
	return hetero

def graph_to_dataset(graph_getter):
    # hetero = nx2hetero(graph_getter)
    print("Converting graph to DeepSnap dataset ...")
    graph = graph_getter()
    dataset = deepsnap.dataset.GraphDataset(
        [graph], task="link_pred", edge_train_mode="disjoint"
    )
    return dataset.split(
        transductive=True, split_ratio=[0.7, 0.1, 0.2]
    )

def get_cached(var, pickled_filename, generator, *args, ignore_cache=True):
    global G
    if not ignore_cache and var in globals():
        print(f"{var} already loaded :)")
        return globals()[var]
    elif not ignore_cache and os.path.exists(pickled_filename):
        print(f"Loading {var} from pickle ...")
        return pickle.load(open(pickled_filename, "rb"))
    else:
        print(f"Pickled {var} not found, generating anew ...")
        obj = generator(*args)
        G = obj
        pickle.dump(obj, open(pickled_filename, "wb"))
        print(f"{var} generated, pickle saved to {pickled_filename}")
        return obj

import time
start = time.time()
# hetero = get_cached("hetero", pickles + "/HeteroGraph_nodesHaveNames.pkl", nx2hetero, lambda: get_cached("G", pickled_graph, load_graph))
dataset = get_cached("dataset", pickled_dataset, graph_to_dataset, lambda: get_cached("G", pickled_graph, load_graph))
dataset_train, dataset_val, dataset_test = dataset
print("Finished loading data.")

with open("status.txt", "w") as f:
	f.write(f"Finished loading data in {time.time() - start}.")

In [14]:
import gc
try:
    del G  # fuckup variable here
except: pass
gc.collect()

68992

In [11]:
!ls spotify_million_playlist_dataset/pickles/HeteroGraph_nodesHaveNames.pkl

chad_G.bak		       HeteroGraph_nodesHaveNames.pkl
chad_G_with_edge_types.pkl     top-G-10000.pkl
dataset_example.pkl	       top-G-1000.pkl
G_example.pkl		       top-G-100.pkl
G_no_node_type.pkl	       top-G-5000.pkl
G.pkl			       top-G-500.pkl
HeteroData.pkl		       yon_dataset_example.pkl
HeteroGraph_nodesAreNodes.pkl  yon_G_example.pkl


In [15]:

dataset = pickle.load(open("spotify_million_playlist_dataset/pickles/HeteroGraph_nodesHaveNames.pkl", "rb"))

# Processing

In [16]:
class LightGCNConv(torch_geometric.nn.conv.MessagePassing):
    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):  
        super(LightGCNConv, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

    def forward(self, x, edge_index, size = None):
        out = self.propagate(edge_index, x=(x, x))
        return out

    def message(self, x_j):
        out = x_j
        return out

    def aggregate(self, inputs, index, dim_size = None):
        node_dim = self.node_dim
        out = torch_scatter.scatter(inputs, index, dim=node_dim, reduce='mean')
        return out

In [None]:
dataset.node_label_index

In [None]:
k = len(phrase_encoding)
res = dict()
for offset in range(len(words) - k + 1):
	phrase = [words[offset+i] if phrase_encoding[i] else "/" for i in range(k)]
	phrase = "-".join(phrase)
	if phrase in res: res[phrase] += 1
	else: res[phrase] = 1
return res

In [24]:
class LightGCN(torch.nn.Module):
    def __init__(self, train_data, num_layers, emb_size=16, initialize_with_words=False):
        super(LightGCN, self).__init__()
        self.convs = torch.nn.ModuleList()
        assert (num_layers >= 1), 'Number of layers is not >=1'
        for l in range(num_layers):
            self.convs.append(LightGCNConv(input_dim, input_dim))

        # Initialize using custom embeddings if provided
        num_nodes = train_data.node_label_index.size()[0]
        self.embeddings = torch.nn.Embedding(num_nodes, emb_size)
        if initialize_with_words:
            self.embeddings.weight.data.copy_(train_data.node_features)
        
        self.loss_fn = torch.nn.BCELoss()
        self.num_layers = num_layers
        self.emb_size = emb_size
        self.num_modes = num_nodes

    def forward(self, data):
        edge_index, edge_label_index, node_label_index = data.edge_index, data.edge_label_index, data.node_label_index
        layer_embeddings = []
        
        x = self.embeddings(node_label_index)
        mean_layer = x

        # We take an average of ever layer's node embeddings
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            mean_layer += x

        mean_layer /= 4

        # Prediction head is simply dot product
        nodes_first = torch.index_select(x, 0, edge_label_index[0,:].long())
        nodes_second = torch.index_select(x, 0, edge_label_index[1,:].long())

        # Since we don't want a rank output, we create a sigmoid of the dot product
        out = torch.sum(nodes_first * nodes_second, dim=-1) # FOR RANKING
        pred = torch.sigmoid(out)

        return torch.flatten(pred)

    def loss(self, pred, label):
        return self.loss_fn(pred, label)

In [25]:
args = {
    'device' : 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_layers' : 3,
    'emb_size' : 32,
    'epochs' : 1,
    'weight_decay': 1e-5,
    'lr': 0.01,
    'epochs': 300
}

datasets = {
    'train': dataset,
    'val': dataset,
    'test': dataset,
}

input_dim = datasets['train'].num_node_features
print(input_dim, args)

{'playlist': 1, 'track': 1, 'artist': 1, 'album': 1} {'device': 'cuda', 'num_layers': 3, 'emb_size': 32, 'epochs': 300, 'weight_decay': 1e-05, 'lr': 0.01}


In [26]:
datasets['train'].to(args['device'])
datasets['val'].to(args['device'])
datasets['test'].to(args['device'])

HeteroData(
  [1mplaylist[0m={ x=[1000] },
  [1mtrack[0m={ x=[35289] },
  [1martist[0m={ x=[10091] },
  [1malbum[0m={ x=[20469] },
  [1m(playlist, contains, track)[0m={ edge_index=[2, 66331] },
  [1m(album, contains, track)[0m={ edge_index=[2, 35289] },
  [1m(artist, writes, track)[0m={ edge_index=[2, 35289] }
)

In [27]:
losses = []

def train(model, optimizer, args):
    val_max = 0
    best_model = model

    for epoch in range(1, args['epochs'] + 1):
        datasets['train'].to(args["device"])
        model.train()
        optimizer.zero_grad()
        pred = model(datasets['train'])
        loss = model.loss(pred, datasets['train'].edge_label.type(pred.dtype))

        loss.backward()
        optimizer.step()

        log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}, Loss: {:.5f}, Val Loss: {:.5f}'
        score_train, train_loss = test(model, 'train', args)
        score_val, val_loss = test(model, 'val', args)
        score_test, test_loss = test(model, 'test', args)

        losses.append((train_loss, val_loss))

        print(log.format(epoch, score_train, score_val, score_test, train_loss, val_loss))
        if val_max < score_val:
            val_max = score_val
            best_model = copy.deepcopy(model)

    return best_model

def test(model, mode, args):
    model.eval()
    score = 0
    loss_score = 0

    data = datasets[mode]
    data.to(args["device"])

    pred = model(data)
    loss = model.loss(pred, data.edge_label.type(pred.dtype))
    score += roc_auc_score(data.edge_label.flatten().cpu().numpy(), pred.flatten().data.cpu().numpy())
    loss_score += loss.item()

    return score, loss_score

In [28]:
model = LightGCN(datasets['train'], args['num_layers'], emb_size=args['emb_size']).to(args['device'])
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

best_model = train(model, optimizer, args)
log = "Train: {:.4f}, Val: {:.4f}, Test: {:.4f}, Val Loss: {:.5f}, Test Loss: {:.5f}"
best_train_roc, train_loss = test(best_model, 'train', args)
best_val_roc, val_loss = test(best_model, 'val', args)
best_test_roc, test_loss = test(best_model, 'test', args)
print(log.format(best_train_roc, best_val_roc, best_test_roc, train_loss, val_loss, test_loss))

AttributeError: 'HeteroData' has no attribute 'node_label_index'