In this project, we will train a GNN to perform link prediction on a heterogenous graph from the Spotify Million Playlists dataset.

# Import libraries

In [248]:
!pip install numpy iprogress tqdm networkx torch_geometric
import os
import json
import torch
import pickle
from tqdm import tqdm
import networkx as nx
import torch_geometric



In [249]:
# temporary imports (delete later)
from pprint import pprint
import random
import numpy as np
# http://192.168.0.103:8888/?token=klobasa

# Configuration

In [250]:
# config
base = "spotify_million_playlist_dataset"
pickles = base + "/pickles"

# full dataset
dataset_path = base + "/data"
pickled_graph = pickles + "/G.pkl"
pickled_datasets = pickles + "/datasets.pkl"
pickled_ghetero = pickles + "/ghetero.pkl"

# example dataset (override above)
dataset_path = base + "/example"
pickled_graph = pickles + "/G_example.pkl"
pickled_datasets = pickles + "/datasets_example.pkl"
pickled_ghetero = pickles + "/ghetero_example.pkl"

# Load datasets

In [251]:
def load_graph(dataset_path=dataset_path):
    """Load a nx.Graph from disk."""
    filenames = os.listdir(dataset_path)
    G = nx.DiGraph()
    for i in tqdm(range(len(filenames)), unit="files"):
        with open(os.path.join(dataset_path, filenames[i])) as json_file:
            playlists = json.load(json_file)["playlists"]
            for playlist in playlists:
                playlist_name = f"spotify:playlist:{playlist['pid']}"
                G.add_node(playlist_name, node_type="playlist", num_followers=playlist["num_followers"])
                for track in playlist["tracks"]:
                    G.add_node(track["track_uri"], node_type="track", duration=track["duration_ms"])
                    G.add_node(track["album_uri"], node_type="album")
                    G.add_node(track["artist_uri"], node_type="artist")

                    G.add_edge(track["track_uri"], playlist_name, edge_type="track-playlist")
                    G.add_edge(track["track_uri"], track["album_uri"], edge_type="track-album")
                    G.add_edge(track["track_uri"], track["artist_uri"], edge_type="track-artist")
    return G

def nx2hetero(graph_getter):
    G = graph_getter
    """Convert a nx.Graph into a torch_geometric.data.HeteroData object."""
    ids_by_type = {
        "playlist": {},
        "track": {},
        "artist": {},
        "album": {}
    }
    
    def node_id(node_type, id):
        d = ids_by_type[node_type]
        if id not in d:
            d[id] = len(d)
        return d[id]

    node_features_by_type = {
        "playlist": [],
        "track": [],
        "artist": [],
        "album": []
    }
    for node in G.nodes(data=True):
        t = node[1]["node_type"]
        node_id(t, node[0])
        if t == "playlist":
            node_features_by_type["playlist"] += [node[1]["num_followers"]]
        elif t == "track":
            node_features_by_type["track"] += [node[1]["duration"]]
        elif t == "artist":
            node_features_by_type["artist"] += [1]
        elif t == "album":
            node_features_by_type["album"] += [1]

    edge_index_by_type = {
        ("track", "contains", "playlist"): [],
        ("track", "includes", "album"): [],
        ("track", "authors", "artist"): []
    }
    for edge in G.edges(data=True):
        if G[edge[0]][edge[1]]["edge_type"] == "track-playlist":
            s_id = node_id("track", edge[0])
            d_id = node_id("playlist", edge[1])
            edge_index_by_type[("track", "contains", "playlist")] += [(s_id, d_id)]
        elif G[edge[0]][edge[1]]["edge_type"] == "track-album":
            s_id = node_id("track", edge[0])
            d_id = node_id("album", edge[1])
            edge_index_by_type[("track", "includes", "album")] += [(s_id, d_id)]
        elif G[edge[0]][edge[1]]["edge_type"] == "track-artist":
            s_id = node_id("track", edge[0])
            d_id = node_id("artist", edge[1])
            edge_index_by_type[("track", "authors", "artist")] += [(s_id, d_id)]

    # construct HeteroData
    hetero = torch_geometric.data.HeteroData()

    # add initial node features
    hetero["playlist"].x = torch.FloatTensor(node_features_by_type["playlist"]).reshape(-1,1)
    hetero["track"].x = torch.FloatTensor(node_features_by_type["track"]).reshape(-1,1)
    hetero["artist"].x = torch.FloatTensor(node_features_by_type["artist"]).reshape(-1,1)
    hetero["album"].x = torch.FloatTensor(node_features_by_type["album"]).reshape(-1,1)

    # add edge indices
    hetero["track", "contains", "playlist"].edge_index = torch.tensor(edge_index_by_type[("track", "contains", "playlist")]).t()
    hetero["track", "includes", "album"].edge_index = torch.tensor(edge_index_by_type[("track", "includes", "album")]).t()
    hetero["track", "authors", "artist"].edge_index = torch.tensor(edge_index_by_type[("track", "authors", "artist")]).t()

    # post-processing
    hetero = torch_geometric.transforms.ToUndirected()(hetero)
    hetero = torch_geometric.transforms.RemoveIsolatedNodes()(hetero)
    assert hetero.validate()
    return hetero

def get_cached(var, pickled_filename, fallback, ignore_cache=False):
    """Get a variable from cache.
    
    First, check global memory (variable `var`).
    If not found, check pickle (file `pickled_filename`).
    If not found, generate anew (use `fallback` function).
    """
    if not ignore_cache and var in globals():
        print(f"Using {var} from global memory ...")
        return globals()[var]
    elif not ignore_cache and os.path.exists(pickled_filename):
        print(f"Loading {var} from pickle ...")
        return pickle.load(open(pickled_filename, "rb"))
    else:
        print(f"Pickled {var} not found, generating anew ...")
        obj = fallback()
        pickle.dump(obj, open(pickled_filename, "wb"))
        print(f"{var} generated, pickle saved to {pickled_filename}")
        return obj

In [252]:
# !rm ../spotify_million_playlist_dataset/pickles/datasets_example.pkl
# !rm ../spotify_million_playlist_dataset/pickles/G_example.pkl
# !rm ../spotify_million_playlist_dataset/pickles/ghetero_example.pkl
# !ls ../spotify_million_playlist_dataset/pickles/
# del G
# del ghetero
# del datasets

In [253]:
def ghetero2datasets(ghetero):
    """Split the dataset into train, validation and test sets."""
    transform = torch_geometric.transforms.RandomLinkSplit(
        num_val=0.1,
        num_test=0.1,
        disjoint_train_ratio=0.3,
        neg_sampling_ratio=2.0,
        add_negative_train_samples=False,
        edge_types=("track", "contains", "playlist"),
        rev_edge_types=("playlist", "rev_contains", "track"),
    )

    return transform(ghetero)  # 3-tuple: data_train, data_val, data_test

In [254]:
get_g = lambda: get_cached("G", pickled_graph, fallback=load_graph)
get_ghetero = lambda: get_cached("ghetero", pickled_ghetero, fallback=lambda: nx2hetero(get_g()))
get_datasets = lambda: get_cached("datasets", pickled_datasets, fallback=lambda: ghetero2datasets(get_ghetero()))

ghetero = get_ghetero()
datasets = get_datasets()
data_train, data_val, data_test = datasets
print("Finished loading data.")

Using ghetero from global memory ...
Using datasets from global memory ...
Finished loading data.


# Processing

In [255]:
# create training mask for playlist nodes
train_mask = torch.zeros(ghetero["playlist"].x.shape[0], dtype=torch.bool)
train_mask[torch.randperm(train_mask.shape[0])[:int(train_mask.shape[0]*0.8)]] = True

ghetero["playlist"].train_mask = train_mask

ghetero["playlist"].y = torch.LongTensor([1]*ghetero["playlist"].x.shape[0])

In [256]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = torch_geometric.nn.SAGEConv((-1, -1), hidden_channels, normalize=True)
        self.conv2 = torch_geometric.nn.SAGEConv((-1, -1), hidden_channels, normalize=True)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

class LinkPredictor(torch.nn.Module):
    def forward(self, x_track, x_playlist, track_playlist_edge):
        track_embedding = x_track[track_playlist_edge[0]]
        playlist_embedding = x_playlist[track_playlist_edge[1]]

        # Apply dot-product to get a prediction per supervision edge:
        return (playlist_embedding * track_embedding).sum(dim=-1)

class HeteroModel(torch.nn.Module):
    def __init__(self, hidden_channels, node_features, metadata):
        super().__init__()
        # Since the dataset does not come with rich features, we also learn two
        # embedding matrices for users and movies:
        
        self.node_lin = {
            k: torch.nn.Linear(v.shape[1], hidden_channels) for k, v in node_features.items()
        }

        for _, v in self.node_lin.items():
            torch.nn.init.xavier_uniform_(v.weight)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = torch_geometric.nn.to_hetero(self.gnn, metadata=metadata)

        self.classifier = LinkPredictor()

    def forward(self, data):
        x_dict = {
            k: self.node_lin[k](v) for k, v in data.x_dict.items()
        }
        
        x_dict = self.gnn(x_dict, data.edge_index_dict)
        pred = self.classifier(
            x_dict["track"],
            x_dict["playlist"],
            data["track", "contains", "playlist"].edge_label_index,
        )
        return pred

    def reset_parameters(self):
        for _, v in self.node_lin.items():
            torch.nn.init.xavier_uniform_(v.weight)
        self.gnn.reset_parameters()

model = HeteroModel(64, ghetero.x_dict, ghetero.metadata())
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [257]:
edge_label_index = data_train["track", "contains", "playlist"].edge_label_index
edge_label = data_train["track", "contains", "playlist"].edge_label
train_loader = torch_geometric.loader.LinkNeighborLoader(
    data=data_train,
    num_neighbors=[20, 10],
    neg_sampling_ratio=2.0,
    edge_label_index=(("track", "contains", "playlist"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

In [258]:
def train():
    model.train()

    total_examples = total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to('cuda:0')
        batch_size = 100
        out = model(batch)
        loss = torch.nn.functional.cross_entropy(
            out, batch["track", "contains", "playlist"].edge_label
        )
        loss.backward()
        optimizer.step()

        total_examples += batch_size
        print(f'Loss: {loss:.4f}')
        total_loss += float(loss) * batch_size

    return total_loss / total_examples

In [259]:
train()

Loss: 761.8414
Loss: 759.3619
Loss: 762.1827
Loss: 761.2241
Loss: 761.4955
Loss: 761.5721
Loss: 761.2287
Loss: 760.9959
Loss: 760.6590
Loss: 760.1034
Loss: 760.3151
Loss: 759.1165
Loss: 760.4759
Loss: 759.1173
Loss: 760.6943
Loss: 759.2424
Loss: 759.7942
Loss: 760.5875
Loss: 760.9835
Loss: 761.6185
Loss: 759.3360
Loss: 761.9995
Loss: 761.4413
Loss: 760.8193
Loss: 760.9785
Loss: 760.3000
Loss: 761.7486
Loss: 759.7693
Loss: 759.5469
Loss: 759.5927
Loss: 760.6370
Loss: 761.2018
Loss: 758.3416
Loss: 760.1608
Loss: 759.6035
Loss: 760.1858
Loss: 760.1282
Loss: 758.9980
Loss: 760.2808
Loss: 758.8284
Loss: 759.9775
Loss: 758.7100
Loss: 760.3772
Loss: 758.8898
Loss: 758.7626
Loss: 761.6842
Loss: 760.8819
Loss: 759.4440
Loss: 760.4081
Loss: 759.8478
Loss: 758.4083
Loss: 761.3316
Loss: 760.0054
Loss: 760.4619
Loss: 760.2502
Loss: 759.9586
Loss: 758.6992
Loss: 758.6745
Loss: 756.6503
Loss: 759.1357
Loss: 758.2160
Loss: 762.2259
Loss: 758.0843
Loss: 759.9832
Loss: 759.8146
Loss: 762.2816
Loss: 757.

755.9876741943359