In [None]:
import pickle
# pickle.load expects an open binary file object, not a path string, so each
# file is opened explicitly and closed again as soon as it has been read.
# SECURITY: unpickling can execute arbitrary code. Only load files that come
# from a trusted source.
with open("path/to/graph.pkl", "rb") as graph_file:
    graph = pickle.load(graph_file)
with open("path/to/model.pkl", "rb") as model_file:
    model = pickle.load(model_file)

In [None]:
import torch

def unqiue_in_tracks(tracks, type):
    """Collect the distinct values stored under one key across a list of tracks.

    Args:
        tracks: Iterable of dict-like track records (anything with ``.get``).
        type: Key to read from each track, e.g. ``"artist_id"``. The name
            shadows the builtin but is kept so existing keyword callers work.

    Returns:
        set: Every value found under ``type`` that is not ``None``. Tracks
        that lack the key (or store ``None``) contribute nothing.
    """
    # Compare against None instead of relying on truthiness: 0 is a valid id
    # and must still be counted.
    values = (t.get(type) for t in tracks)
    return {v for v in values if v is not None}

def add_playlist(graph, playlist):
    """Append a playlist node and its track edges to ``graph`` in place.

    Args:
        graph: Graph object indexed by node type (``graph["playlist"]``,
            ``graph["track"]``) and by edge-type triple. The stores used here
            must expose ``x`` (node features) or ``edge_index`` respectively.
        playlist: Dict describing the playlist. ``"playlist_id"`` is required.
            ``"num_followers"``, ``"collaborative"``, ``"num_edits"`` and
            ``"tracks"`` are optional and default to 0 / False / 0 / [].
            Every track must carry ``"track_id"``, ``"artist_id"`` and
            ``"album_id"``; ``"duration_ms"`` is optional (default 0).

    Returns:
        None. ``graph`` is modified in place; the caller's ``playlist`` dict
        is not modified (a merged copy is used internally).

    Raises:
        AssertionError: If ``"playlist_id"`` is missing, or a track lacks one
            of its required keys. Validation happens before any mutation.
    """
    contains_key = ("track", "contains", "playlist")
    includes_key = ("track", "includes", "album")

    assert "playlist_id" in playlist, "playlist is missing 'playlist_id'"
    # NOTE(review): playlist_id is used directly as the playlist node index in
    # the edges below, while the new feature row is appended at the end of
    # graph["playlist"].x. The caller presumably has to make these agree.
    # TODO confirm.
    playlist_id = playlist["playlist_id"]

    # Fill in defaults for optional fields without touching the caller's dict.
    key_defaults = {
        "num_followers": 0,
        "collaborative": False,
        "num_edits": 0,
        "tracks": []
    }
    playlist = {**key_defaults, **playlist}
    tracks = playlist["tracks"]

    # Validate every track up front so that a malformed entry cannot leave the
    # graph half-updated.
    for track in tracks:
        for key in ("track_id", "artist_id", "album_id"):
            assert key in track, f"track is missing '{key}'"

    # Aggregate statistics that make up the playlist's feature row.
    playlist["num_tracks"] = len(tracks)
    playlist["num_artists"] = len(unqiue_in_tracks(tracks, "artist_id"))
    playlist["num_albums"] = len(unqiue_in_tracks(tracks, "album_id"))
    playlist["duration_ms"] = sum(t.get("duration_ms", 0) for t in tracks)

    # Feature order is kept exactly as before so that the new row lines up
    # with the rows already stored in graph["playlist"].x.
    playlist_embedding = torch.LongTensor([
        playlist["num_followers"],
        playlist["collaborative"],
        playlist["num_albums"],
        playlist["num_tracks"],
        playlist["num_edits"],
        playlist["duration_ms"],
        playlist["num_artists"],
    ]).reshape(1, -1)
    graph["playlist"].x = torch.cat((graph["playlist"].x, playlist_embedding), dim=0)

    if tracks:
        # Edges are concatenated along dim=1, so the new block must be shaped
        # (2, num_new_edges): row 0 holds the source tracks, row 1 the targets.
        # All edges are built first and concatenated once instead of once per
        # track.
        track_ids = [t["track_id"] for t in tracks]
        contains_edges = torch.LongTensor([track_ids, [playlist_id] * len(tracks)])
        graph[contains_key].edge_index = torch.cat(
            (graph[contains_key].edge_index, contains_edges), dim=1
        )

        # Keep the per-track ordering: artist edge first, then album edge.
        sources, targets = [], []
        for t in tracks:
            sources += [t["track_id"], t["track_id"]]
            targets += [t["artist_id"], t["album_id"]]
        # NOTE(review): artist edges are written into the
        # ("track", "includes", "album") store. This looks like a copy-paste
        # slip, but the intended artist edge type is not visible here, so the
        # destination is left unchanged. Confirm against the graph schema.
        graph[includes_key].edge_index = torch.cat(
            (graph[includes_key].edge_index, torch.LongTensor([sources, targets])), dim=1
        )

    # Candidate edges for prediction: every track paired with the new playlist,
    # in the same (2, N) layout as edge_index (row 0 = track index, row 1 =
    # playlist index).
    num_tracks = len(graph["track"].x)
    graph[contains_key].edge_label_index = torch.stack((
        torch.arange(num_tracks, dtype=torch.long),
        torch.full((num_tracks,), playlist_id, dtype=torch.long),
    ))

In [None]:
# Example playlist: two tracks, each with its own artist and album.
playlist = dict(
    playlist_id=0,
    tracks=[
        dict(track_id=0, artist_id=10, album_id=100, duration_ms=1000),
        dict(track_id=1, artist_id=11, album_id=101, duration_ms=1000),
    ],
)

# Add the playlist to the graph and set its track connections for prediction.
add_playlist(graph, playlist)

In [None]:
# Run the model on the graph with autograd switched off.
with torch.no_grad():
    pred = model(graph)

    # Positions of the 10 largest entries along dim 1, taken from the first row.
    top_10 = torch.topk(pred, k=10, dim=1).indices[0]

    # Rows of the track feature matrix at those positions, as plain Python lists.
    track_ids = graph["track"].x[top_10].tolist()

In [None]:
# Show the values collected in `track_ids` by the prediction cell.
print(track_ids)