# Package Imports

In [1]:
#from pymongo import MongoClient
from collections import defaultdict, Counter
from tqdm.auto import tqdm
import networkx as nx
import random
import math
import pickle
import torch
import pandas as pd
import torch_geometric
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric.utils.convert import to_networkx, from_networkx
from torch_geometric.utils import to_undirected, is_undirected
import numpy as np
from networkx import to_dict_of_dicts
from torch_geometric.loader import NeighborLoader
import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(torch.__version__)
print(torch_geometric.__version__)

2.2.2+cu118
2.5.3


# Load Data

## Load User Track Graph

In [3]:
import pickle

# Prerequisite: the contents of dataset.rar must already be extracted to data/MRecury_data/.

# Load the pickled dataset.
# SECURITY NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
with open('data/MRecury_data/dataset.pickle', 'rb') as f:
    dataset = pickle.load(f)

# Pull the individual parts out of the loaded object (indexed by string keys):
full_graph = dataset['full']
train_graph = dataset['train']
test_graph = dataset['test']
users_mapping = dataset['users']
# The artist/track mapping stored in the pickle ('artist-tracks') is intentionally not used:
#artist_tracks_mapping = dataset['artist-tracks']
# A corrected mapping is read from CSV instead (it is read again in the preprocessing section).
artist_tracks_mapping = pd.read_csv('data/new_artist_tracks_mapping_df.csv')


# Node and edge views of the full and train graphs.
# NOTE(review): presumably NetworkX graphs (the notebook prints "Graph with ... nodes") - confirm.
full_nodes = full_graph.nodes()
full_edges = full_graph.edges()

train_nodes = train_graph.nodes()
train_edges = train_graph.edges()



In [4]:
def create_dataframe_from_graph(graph, num_users=3307):
    """Flatten the user -> song adjacency of a listening graph into a DataFrame.

    Args:
        graph: Mapping-like graph where ``graph[user_id]`` maps each song ID to a
            dict with the keys ``'scrobbles'``, ``'pos'`` and ``'date'``.
        num_users: User IDs ``0 .. num_users - 1`` are scanned. Defaults to 3307,
            the number of users in this dataset.

    Returns:
        pandas.DataFrame with one row per (user, song) edge and the columns
        ``User_ID``, ``Song_ID``, ``Scrobbles``, ``Position`` and ``Date``.
        User IDs that are not present in ``graph`` are skipped.
    """
    # Column-wise accumulators, one entry per user-song edge
    user_ids = []
    song_ids = []
    scrobbles = []
    positions = []
    dates = []

    for user_id in tqdm(range(num_users)):
        if user_id not in graph:
            continue
        # Every neighbour of a user node is a song; the edge data holds the listening info
        for song_id, songs_info in graph[user_id].items():
            user_ids.append(user_id)
            song_ids.append(song_id)
            scrobbles.append(songs_info['scrobbles'])
            positions.append(songs_info['pos'])
            dates.append(songs_info['date'])

    graph_df = pd.DataFrame({
        'User_ID': user_ids,
        'Song_ID': song_ids,
        'Scrobbles': scrobbles,
        'Position': positions,
        'Date': dates
    })

    return graph_df

full_graph_df = create_dataframe_from_graph(full_graph)
train_graph_df = create_dataframe_from_graph(train_graph)
test_graph_df = create_dataframe_from_graph(test_graph)

100%|██████████| 3307/3307 [00:01<00:00, 2226.33it/s]
100%|██████████| 3307/3307 [00:01<00:00, 2755.49it/s]
100%|██████████| 3307/3307 [00:00<00:00, 18159.01it/s]


In [5]:
full_graph_df

Unnamed: 0,User_ID,Song_ID,Scrobbles,Position,Date
0,0,3307,370,inf,
1,0,3308,357,inf,
2,0,3309,349,inf,
3,0,3310,347,inf,
4,0,3311,346,inf,
...,...,...,...,...,...
3018204,3306,3815,1,inf,
3018205,3306,14156,1,800.0,"Saturday 23 Jan 2010, 3:01pm"
3018206,3306,12191,1,inf,
3018207,3306,173432,1,inf,


In [8]:
# Largest Song_ID found in each split
max_song_id = full_graph_df['Song_ID'].max()
max_song_id_train = train_graph_df['Song_ID'].max()
max_song_id_test = test_graph_df['Song_ID'].max()

for split_label, split_max in (
    ("full dataset", max_song_id),
    ("training set", max_song_id_train),
    ("test set", max_song_id_test),
):
    print(f"The maximum Song_ID in the {split_label} is: {split_max}")

# All three splits must share the same maximum Song_ID
maxima_agree = max_song_id == max_song_id_train and max_song_id_train == max_song_id_test
print(
    "The maximum Song_ID is consistent across all datasets."
    if maxima_agree
    else "Warning: The maximum Song_ID is not consistent across all datasets."
)


The maximum Song_ID in the full dataset is: 255320
The maximum Song_ID in the training set is: 255320
The maximum Song_ID in the test set is: 255320
The maximum Song_ID is consistent across all datasets.


In [6]:
train_graph_df

Unnamed: 0,User_ID,Song_ID,Scrobbles,Position,Date
0,0,6235,22,196.0,"Wednesday 22 Sep 2021, 10:41am"
1,0,6346,21,197.0,"Wednesday 22 Sep 2021, 10:37am"
2,0,6460,20,198.0,"Wednesday 22 Sep 2021, 10:33am"
3,0,6347,21,199.0,"Wednesday 22 Sep 2021, 10:30am"
4,0,6117,23,200.0,"Wednesday 22 Sep 2021, 10:26am"
...,...,...,...,...,...
2564903,3306,222754,1,inf,
2564904,3306,8500,1,inf,
2564905,3306,3815,1,inf,
2564906,3306,12191,1,inf,


In [7]:
test_graph_df

Unnamed: 0,User_ID,Song_ID,Scrobbles,Position,Date
0,0,188713,8,130,"Thursday 23 Sep 2021, 9:31am"
1,0,8573,10,131,"Thursday 23 Sep 2021, 9:22am"
2,0,4256,53,132,"Thursday 23 Sep 2021, 9:19am"
3,0,4521,47,133,"Thursday 23 Sep 2021, 9:15am"
4,0,4522,47,134,"Thursday 23 Sep 2021, 9:12am"
...,...,...,...,...,...
453296,3306,202755,1,854,"Tuesday 19 Jan 2010, 10:17pm"
453297,3306,9790,1,857,"Tuesday 19 Jan 2010, 10:05pm"
453298,3306,169252,1,860,"Tuesday 19 Jan 2010, 9:51pm"
453299,3306,181742,1,862,"Tuesday 19 Jan 2010, 9:43pm"


## Load Social Graph

In [5]:
def load_social(file_users, file_edges, users_ids):
    """Build the directed user-user graph restricted to users known to this dataset.

    Args:
        file_users: Path to a tab-separated file with two columns (id, user name).
        file_edges: Path to a space-separated file with two columns (origin id, destination id).
        users_ids: Mapping from user name to the node ID used in this notebook.

    Returns:
        nx.DiGraph whose nodes are the mapped IDs of all users found in ``file_users``
        that are also keys of ``users_ids``, with one edge per line of ``file_edges``
        whose endpoints are both such users.
    """
    df_users = pd.read_csv(file_users, sep='\t', names=['id', 'user'])
    df_edges = pd.read_csv(file_edges, sep=' ', names=['origin', 'destination'])

    # Translate the file's own user IDs into this notebook's node IDs.
    # Iterating over the columns directly avoids building a Series per row (iterrows).
    old_new = {}
    for old_id, user in tqdm(zip(df_users['id'], df_users['user']), total=len(df_users)):
        if user in users_ids:
            old_new[old_id] = users_ids[user]

    social_graph = nx.DiGraph()
    social_graph.add_nodes_from(old_new.values())

    # Keep only edges where both endpoints are known users
    edge_columns = zip(df_edges['origin'], df_edges['destination'])
    for origin, destination in tqdm(edge_columns, total=len(df_edges)):
        if origin in old_new and destination in old_new:
            social_graph.add_edge(old_new[origin], old_new[destination])
    return social_graph

social_graph = load_social('data/MRecury_data/lastfm.nodes', 'data/MRecury_data/lastfm.edges', users_mapping)

100%|██████████| 136420/136420 [00:01<00:00, 82771.63it/s]
100%|██████████| 1685524/1685524 [00:16<00:00, 101293.56it/s]


# Transformation into PyG Graph

## Graph Data Inspection for Transformation

In [6]:
# NOTE: this rebinds `dataset`, which previously held the object loaded from the pickle.
dataset = train_graph

# Node counts: every node that is not a user is treated as a track node
num_nodes = len(train_nodes)
num_user_nodes = len(users_mapping)
num_track_nodes = num_nodes - num_user_nodes

print(f'Dataset: {dataset}:')
print('======================')

# len() of the graph equals its node count, so it cannot be used as a "number of graphs".
print(f'Number of nodes: {num_nodes}')
print(f'Number of edges: {len(train_edges)}')
print(f'Number of User Nodes: {num_user_nodes}')  # one entry per user in users_mapping
print(f'Number of Tracks Nodes: {num_track_nodes} (one too much because of wrong original Data)')
print(f'Number of unique Songs: {len(artist_tracks_mapping)} - Correct Unique Song Numbers')


#print(f'Number of features: {dataset.num_features}')
#print(f'Number of classes: {dataset.num_classes}')

Dataset: Graph with 255321 nodes and 2564908 edges:
Number of nodes: 255321
Number of edges: 2564908
Number of User Nodes: 3307
Number of Tracks Nodes: 252014 (one too much because of wrong original Data)
Number of unique Songs: 252013 - Correct Unique Song Numbers


## Train Graph Preprocessing (Skip on Rerun)
Cleaning the Train Graph of missing songs and resetting Indices of each Node for Pyg Graph Init

Can be Skipped on Re-Runs for the same dataset (eg. Train)

In [7]:
## Initialize Mapping

In [8]:
artist_tracks_mapping = pd.read_csv('data/new_artist_tracks_mapping_df.csv')
artist_tracks_mapping = artist_tracks_mapping.rename(columns = {'Song_Node_ID':'Song_ID'})
artist_tracks_mapping

Unnamed: 0,Artist,Song_Name,Song_ID
0,Black Kids,I'm Not Gonna Teach Your Boyfriend How to Danc...,3307
1,Black Kids,Hit The Heartbrakes,3308
2,Black Kids,I've Underestimated My Charm (Again),3309
3,Black Kids,Partie Traumatic,3310
4,Black Kids,I'm Making Eyes at You,3311
...,...,...,...
252008,Jamie Lancaster,Boys Don't Cry,255203
252009,Sleeperstar,I Was Wrong,255208
252010,Anthony Naples,Mad Disrespect,255228
252011,Irene,Stardust,255253


In [9]:
# Merge the train_graph_df with artist_tracks_mapping on the "Song_ID" column
train_graph_df = pd.merge(train_graph_df, artist_tracks_mapping, on="Song_ID", how="left")

# Display the structure of the merged DataFrame
train_graph_df

Unnamed: 0,User_ID,Song_ID,Scrobbles,Position,Date,Artist,Song_Name
0,0,6235,22,196.0,"Wednesday 22 Sep 2021, 10:41am",Boniface,Making Peace With Suburbia
1,0,6346,21,197.0,"Wednesday 22 Sep 2021, 10:37am",Boniface,Stay Home
2,0,6460,20,198.0,"Wednesday 22 Sep 2021, 10:33am",Boniface,It's A Joke
3,0,6347,21,199.0,"Wednesday 22 Sep 2021, 10:30am",Boniface,Wake Me Back Up
4,0,6117,23,200.0,"Wednesday 22 Sep 2021, 10:26am",Boniface,Ghosts
...,...,...,...,...,...,...,...
2564903,3306,222754,1,inf,,Killswitch Engage,Never Again
2564904,3306,8500,1,inf,,Kings of Leon,The Runner
2564905,3306,3815,1,inf,,Kings of Leon,Use Somebody
2564906,3306,12191,1,inf,,Kiss,Love Gun


In [10]:
## Number of distinct Song_IDs in the train DataFrame (the left merge does not change it).
## The result is 252014, one more than the 252013 songs in the new mapping: one song was
## duplicated in the original graph/mapping and has no entry in the new mapping, so it has to be removed.
len(train_graph_df["Song_ID"].unique())

252014

### check for NaNs in Songname or ID
One song, `Banda UÓ - Cavalo de Fogo` (row 74634 of the old mapping, `Song_Node_ID` 68691), appeared twice with different IDs in the old mapping and was therefore correctly excluded from the new mapping. It still shows up here because it exists in the original train DataFrame, which is the left side of the merge. The left join finds no match for it in the new mapping, so its `Artist` and `Song_Name` are NaN and the row has to be deleted.

In [11]:
train_graph_df["Song_ID"][train_graph_df["Song_ID"].isnull()]

Series([], Name: Song_ID, dtype: int64)

In [12]:
train_graph_df[train_graph_df["Song_Name"].isnull()]

Unnamed: 0,User_ID,Song_ID,Scrobbles,Position,Date,Artist,Song_Name
2157727,2614,68691,22,inf,,,


In [13]:
train_graph_df.loc[2157727]

User_ID       2614
Song_ID      68691
Scrobbles       22
Position       inf
Date          None
Artist         NaN
Song_Name      NaN
Name: 2157727, dtype: object

In [14]:
## identify this song by ID:
#train_graph_df[train_graph_df["Song_ID"] == 68691]

###  remove empty songs / duplicate Song_ID

In [15]:
# Remove every interaction whose song has no entry in the new mapping: after the left
# merge those rows have a null Song_Name. Filtering on that condition (instead of a
# hard-coded row label) keeps the cell correct if the row order or the data changes.
# reset_index() keeps the previous row labels in an "index" column, as before.
unmapped_song_rows = train_graph_df["Song_Name"].isnull()
train_graph_df = train_graph_df[~unmapped_song_rows].reset_index()
train_graph_df

Unnamed: 0,index,User_ID,Song_ID,Scrobbles,Position,Date,Artist,Song_Name
0,0,0,6235,22,196.0,"Wednesday 22 Sep 2021, 10:41am",Boniface,Making Peace With Suburbia
1,1,0,6346,21,197.0,"Wednesday 22 Sep 2021, 10:37am",Boniface,Stay Home
2,2,0,6460,20,198.0,"Wednesday 22 Sep 2021, 10:33am",Boniface,It's A Joke
3,3,0,6347,21,199.0,"Wednesday 22 Sep 2021, 10:30am",Boniface,Wake Me Back Up
4,4,0,6117,23,200.0,"Wednesday 22 Sep 2021, 10:26am",Boniface,Ghosts
...,...,...,...,...,...,...,...,...
2564902,2564903,3306,222754,1,inf,,Killswitch Engage,Never Again
2564903,2564904,3306,8500,1,inf,,Kings of Leon,The Runner
2564904,2564905,3306,3815,1,inf,,Kings of Leon,Use Somebody
2564905,2564906,3306,12191,1,inf,,Kiss,Love Gun


In [16]:
## no more empty song names
train_graph_df["Song_Name"][train_graph_df["Song_Name"].isnull()]

Series([], Name: Song_Name, dtype: object)

### Reset Song Indices of Song IDS:



In [17]:
missing_id = 68691

# Sorted original Song_IDs without the unmapped song; their position becomes the new ID,
# so the new IDs run 0..N-1 in the same order as the original IDs.
kept_song_ids = [
    original_id
    for original_id in np.sort(train_graph_df['Song_ID'].unique())
    if original_id != missing_id
]
New_Song_ID_dict = {original_id: new_id for new_id, original_id in enumerate(kept_song_ids)}

# Next free ID (equals the number of songs that received a new ID)
new_ID_Counter = len(New_Song_ID_dict)

#New_Song_ID_dict


In [18]:
# Apply the mapping to the dataset
train_graph_df['Song_ID'] = train_graph_df['Song_ID'].map(New_Song_ID_dict)

In [19]:
train_graph_df

Unnamed: 0,index,User_ID,Song_ID,Scrobbles,Position,Date,Artist,Song_Name
0,0,0,2928,22,196.0,"Wednesday 22 Sep 2021, 10:41am",Boniface,Making Peace With Suburbia
1,1,0,3039,21,197.0,"Wednesday 22 Sep 2021, 10:37am",Boniface,Stay Home
2,2,0,3153,20,198.0,"Wednesday 22 Sep 2021, 10:33am",Boniface,It's A Joke
3,3,0,3040,21,199.0,"Wednesday 22 Sep 2021, 10:30am",Boniface,Wake Me Back Up
4,4,0,2810,23,200.0,"Wednesday 22 Sep 2021, 10:26am",Boniface,Ghosts
...,...,...,...,...,...,...,...,...
2564902,2564903,3306,219446,1,inf,,Killswitch Engage,Never Again
2564903,2564904,3306,5193,1,inf,,Kings of Leon,The Runner
2564904,2564905,3306,508,1,inf,,Kings of Leon,Use Somebody
2564905,2564906,3306,8884,1,inf,,Kiss,Love Gun


In [20]:
np.sort(train_graph_df['Song_ID'].unique())

array([     0,      1,      2, ..., 252010, 252011, 252012], dtype=int64)

In [21]:
# The sorted unique IDs are consecutive from 0 exactly when the i-th smallest ID equals i
sorted_new_ids = np.sort(train_graph_df['Song_ID'].unique())
consecutive = all(
    new_id == expected_id for expected_id, new_id in enumerate(sorted_new_ids)
)

# Print the result
print(
    "New Song IDs run consecutively from 0."
    if consecutive
    else "New Song IDs do not run consecutively from 0."
)


New Song IDs run consecutively from 0.


In [22]:
train_graph_df["Song_ID"].unique().max()

252012

### check number of songs
All songs from the full data still appear somewhere in the train data, but not every user-song "listens to" edge is kept.
For example, user 0 may have listened to track 3307 in the full graph while that edge is missing from the train data. As long as the edge remains for another user (e.g. user 1), song 3307 is still seen at least once in the training data. In other words, every song is listened to by at least one user in the training data.

In [23]:
unique_song_ids= train_graph_df["Song_ID"].unique()
print("Length:", len(unique_song_ids))
print("Max Value:", train_graph_df["Song_ID"].unique().max())

Length: 252013
Max Value: 252012


### Save / Load Clean TrainGraph DF

In [24]:
import os

clean_csv_path = 'data/train_graph_df_clean.csv'

if os.path.exists(clean_csv_path):
    # A cleaned frame was saved by an earlier run: use it instead of the in-memory one
    train_graph_df = pd.read_csv(clean_csv_path)
else:
    # First run: persist the cleaned frame (without the row index)
    train_graph_df.to_csv(clean_csv_path, index=False)

## Import Clean TrainGraph DF

In [26]:
train_graph_df[train_graph_df["Song_Name"] == "Oblique City"]

Unnamed: 0,index,User_ID,Song_ID,Scrobbles,Position,Date,Artist,Song_Name
38940,38940,5,185405,2,inf,,Phoenix,Oblique City
65536,65536,8,185405,2,inf,,Phoenix,Oblique City
275193,275193,169,185405,10,inf,,Phoenix,Oblique City
355907,355907,267,185405,69,inf,,Phoenix,Oblique City
949778,949778,1038,185405,30,inf,,Phoenix,Oblique City
1092864,1092864,1205,185405,25,inf,,Phoenix,Oblique City
1364441,1364441,1594,185405,8,inf,,Phoenix,Oblique City
1523905,1523905,1672,185405,1,inf,,Phoenix,Oblique City
1581941,1581941,1678,185405,2,inf,,Phoenix,Oblique City
1961853,1961853,2290,185405,103,inf,,Phoenix,Oblique City


In [27]:
train_graph_df

Unnamed: 0,index,User_ID,Song_ID,Scrobbles,Position,Date,Artist,Song_Name
0,0,0,2928,22,196.0,"Wednesday 22 Sep 2021, 10:41am",Boniface,Making Peace With Suburbia
1,1,0,3039,21,197.0,"Wednesday 22 Sep 2021, 10:37am",Boniface,Stay Home
2,2,0,3153,20,198.0,"Wednesday 22 Sep 2021, 10:33am",Boniface,It's A Joke
3,3,0,3040,21,199.0,"Wednesday 22 Sep 2021, 10:30am",Boniface,Wake Me Back Up
4,4,0,2810,23,200.0,"Wednesday 22 Sep 2021, 10:26am",Boniface,Ghosts
...,...,...,...,...,...,...,...,...
2564902,2564903,3306,219446,1,inf,,Killswitch Engage,Never Again
2564903,2564904,3306,5193,1,inf,,Kings of Leon,The Runner
2564904,2564905,3306,508,1,inf,,Kings of Leon,Use Somebody
2564905,2564906,3306,8884,1,inf,,Kiss,Love Gun


## Node Data Setup

### Select All Unique Users

In [28]:
user_ids = len(train_graph_df["User_ID"].unique())
user_ids

3307

### Select All Unique Song Nodes

In [29]:
song_ids = len(train_graph_df["Song_ID"].unique())
song_ids

252013

### Select All Unique Artists

In [30]:
artist_ids = len(train_graph_df["Artist"].unique())
artist_ids

28120

In [31]:
unique_artists = train_graph_df["Artist"].unique()

In [32]:
# Artist node IDs must start at 0: PyG indexes every node type from 0, so the
# artist->song edge mapping would be wrong with any other starting value.
# Each artist gets the position of its first appearance in `unique_artists` as node ID.
artist_nodes_dic = {artist: node_id for node_id, artist in enumerate(unique_artists)}

# All assigned artist node IDs, in ascending order
artist_nodes = list(range(len(unique_artists)))

# Next free artist node ID (one past the last assigned ID)
artist_node_id = len(artist_nodes)



In [None]:
artist_nodes_dic

In [34]:
len(artist_nodes)

28120

## Edge Indices Setup

### User-Tracks Edges
Collect Edge Information from User-Tracks

#### Reset Song_ID Index

##### With edge attributes as List

In [35]:
unique_user_ids = len(train_graph_df['User_ID'].unique())
unique_user_ids

3307

#### Create User-Track Edge Indices & Attributes

In [36]:
def create_user_track_edge_index_and_attributes(graph_df):
    """
    Build the user->song edge list and the matching scrobble counts.

    Args:
    - graph_df: DataFrame with the columns "User_ID", "Song_ID" and "Scrobbles",
      one row per user-song edge.

    Returns:
    - user_song_edge_index (list): [user_id, song_id] pairs, grouped per user in the
      order produced by grouping on "User_ID".
    - user_song_edge_scrobbel_attributes (list): Scrobble count of each edge, in the
      same order as the edge list.
    """
    edge_pairs = []
    edge_scrobbles = []

    per_user = graph_df[["User_ID", "Song_ID", "Scrobbles"]].groupby("User_ID")
    for user_id, user_rows in per_user:
        for song_id in user_rows['Song_ID']:
            # Debug aid: report any edge that points at song index 252013
            # (used to track down a wrong node ID that broke the sampler).
            if song_id == 252013:
                print(song_id, user_id)
            edge_pairs.append([user_id, song_id])

        # Scrobbles are read in the same row order as the song IDs above
        edge_scrobbles.extend(user_rows['Scrobbles'])

    return edge_pairs, edge_scrobbles
        
    


In [37]:
user_song_edge_index, user_song_edge_scrobbel_attributes = create_user_track_edge_index_and_attributes(train_graph_df)

In [38]:
print(len(user_song_edge_index), len(user_song_edge_scrobbel_attributes))

2564907 2564907


In [39]:
# Convert the list of edge attributes to a tensor
user_song_edge_attr_tensor = torch.tensor(user_song_edge_scrobbel_attributes, dtype=torch.long)
user_song_edge_attr_tensor.t().size()

torch.Size([2564907])

In [40]:
# Convert the list of edge index to a tensor
user_song_edge_index = torch.tensor(user_song_edge_index , dtype=torch.long)
user_song_edge_index.t().size()

torch.Size([2, 2564907])

### User User Edges
Collect User User Edge Information

#### User User Edge Index

In [41]:
def create_user_user_edge_index(social_graph, num_users=3307):
    """
    Create the edge index for user-user relationships in a social graph.

    Args:
    - social_graph: Graph-like object (e.g. a NetworkX graph) where
      ``social_graph[user]`` iterates over the neighbours of ``user``.
    - num_users (int): User node IDs ``0 .. num_users - 1`` are scanned.
      Defaults to 3307, the number of users in this dataset.

    Returns:
    - user_user_edge_index (list): [source, target] pairs, one per edge and in one
      direction only. Reverse edges are not added here; they are created later
      inside PyG. Users that are not nodes of ``social_graph`` contribute no edges.
    """
    user_user_edge_index = []

    for user_node in range(num_users):
        # A user without an entry in the social graph simply has no friendship edges
        if user_node not in social_graph:
            continue
        for neighbour in social_graph[user_node]:
            user_user_edge_index.append([user_node, neighbour])
    return user_user_edge_index
    
user_user_edge_index = create_user_user_edge_index(social_graph)

In [42]:
user_user_edge_index = torch.tensor(user_user_edge_index , dtype=torch.long)
user_user_edge_index.t().size()

torch.Size([2, 142919])

###  Artist-Track Edges

#### Creating a Artist-Track Dictionary to feed into the edge Data

In [None]:
artist_nodes_dic

In [44]:
train_graph_df[["Artist","Song_ID","Song_Name"]]

Unnamed: 0,Artist,Song_ID,Song_Name
0,Boniface,2928,Making Peace With Suburbia
1,Boniface,3039,Stay Home
2,Boniface,3153,It's A Joke
3,Boniface,3040,Wake Me Back Up
4,Boniface,2810,Ghosts
...,...,...,...
2564902,Killswitch Engage,219446,Never Again
2564903,Kings of Leon,5193,The Runner
2564904,Kings of Leon,508,Use Somebody
2564905,Kiss,8884,Love Gun


In [45]:
def artist_to_song(df):
    """Group songs by artist.

    Args:
        df: DataFrame with the columns 'Artist', 'Song_ID' and 'Song_Name'. The same
            song may appear in many rows.

    Returns:
        dict mapping each artist to a nested dict ``{song_id: song_name}``. Artists
        and songs keep the order of their first appearance in ``df``.
    """
    artist_to_songs = {}

    # Walk the three columns in parallel instead of using iterrows(), which builds a
    # Series for every row and is very slow on millions of rows.
    for artist, song_id, song_name in zip(df['Artist'], df['Song_ID'], df['Song_Name']):
        artist_to_songs.setdefault(artist, {})[song_id] = song_name

    return artist_to_songs

In [46]:
artist_to_songs_dic = artist_to_song(train_graph_df)


In [None]:
artist_to_songs_dic

In [48]:
def create_song_artist_edge_index(artist_tracks_mapping_dic, artist_node_ids=None):
    """Create directed artist -> song edges.

    Args:
        artist_tracks_mapping_dic: dict mapping each artist to a dict
            ``{song_node_id: song_name}``.
        artist_node_ids: Optional dict mapping each artist to its artist node ID.
            When omitted, the notebook-level ``artist_nodes_dic`` is used, which is
            what the original implementation always did.

    Returns:
        list of ``[artist_node_id, song_node_id]`` pairs, one per song, in the
        iteration order of ``artist_tracks_mapping_dic``.
    """
    if artist_node_ids is None:
        # Backward-compatible default: fall back to the global artist -> node ID mapping
        artist_node_ids = artist_nodes_dic

    artist_song_edge_index = []

    for artist, songs in artist_tracks_mapping_dic.items():
        artist_node_id = artist_node_ids[artist]

        # Only the song node IDs are needed; the song names are ignored
        for song_node_id in songs:
            artist_song_edge_index.append([artist_node_id, song_node_id])

    return artist_song_edge_index
            
artist_song_edge_index = create_song_artist_edge_index(artist_to_songs_dic)
#artist_song_edge_index

In [49]:
artist_song_edge_index = torch.tensor(artist_song_edge_index , dtype=torch.long)
artist_song_edge_index.t().size()

torch.Size([2, 252013])

In [50]:
artist_song_edge_index

tensor([[     0,   2928],
        [     0,   3039],
        [     0,   3153],
        ...,
        [ 28117, 117102],
        [ 28118, 115993],
        [ 28119,  98291]])

## Check All Edges
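The raw edge lists are still directed at this point (every check below reports `False`); reverse edges are added later by `T.ToUndirected`.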

In [51]:
print(len(user_song_edge_index))
print(len(user_user_edge_index))
print(len(artist_song_edge_index))

2564907
142919
252013


In [52]:
# `is_undirected` expects an edge index of shape [2, num_edges]. The tensors built above
# are [num_edges, 2], so they are transposed here; passing them untransposed makes the
# check compare the wrong dimension and its result is meaningless.
# NOTE(review): for the user->song and artist->song relations the two columns index
# different node types, so "undirected" is only well defined for user->user - confirm intent.
print(f"The Edge {user_song_edge_index} is undirected: {is_undirected(user_song_edge_index.t().contiguous())}.")
print(f"The Edge {user_user_edge_index} is undirected: {is_undirected(user_user_edge_index.t().contiguous())}.")
print(f"The Edge {artist_song_edge_index} is undirected: {is_undirected(artist_song_edge_index.t().contiguous())}.")

The Edge tensor([[     0,   2928],
        [     0,   3039],
        [     0,   3153],
        ...,
        [  3306,    508],
        [  3306,   8884],
        [  3306, 170124]]) is undirected: False.
The Edge tensor([[   0,  763],
        [   0, 1435],
        [   0,  122],
        ...,
        [3306,  326],
        [3306,  926],
        [3306,  700]]) is undirected: False.
The Edge tensor([[     0,   2928],
        [     0,   3039],
        [     0,   3153],
        ...,
        [ 28117, 117102],
        [ 28118, 115993],
        [ 28119,  98291]]) is undirected: False.


# Training Graph
## Prepare Data 
Prepare the data for graph creation. The full graph (it could also be called the "rich" graph) needs additional information such as song tags and audio features, which are not present in the base graph.

## load tags and audio features

## Clean Audio df from missing ID

In [53]:
# Load Track Tag
audio_df = pd.read_csv('data/final_audio_df.csv')
train_graph_df = pd.read_csv('data/train_graph_df_clean.csv')


In [54]:
def _gaps_in_song_ids(frame):
    """Return the integers between the smallest and largest Song_ID that `frame` lacks."""
    ids = frame['Song_ID']
    return set(range(ids.min(), ids.max() + 1)) - set(ids)


# Gaps in the Song_ID range of both frames before renumbering
audio_missing_ids = _gaps_in_song_ids(audio_df)
print(f"Missing Song_IDs in audio_df: {audio_missing_ids}")

train_graph_missing_ids = _gaps_in_song_ids(train_graph_df)
print(f"Missing Song_IDs in train_graph_df: {train_graph_missing_ids}")

# Close the gaps in audio_df: order the rows by Song_ID and renumber them 0..len-1
print("Removing missing Song_IDs from audio_df and resetting its other IDS starting from 0, so they are consecutive...")
audio_df = audio_df.sort_values('Song_ID').reset_index(drop=True)
audio_df['Song_ID'] = range(len(audio_df))

# Same gap checks after renumbering
audio_missing_ids = _gaps_in_song_ids(audio_df)
print(f"Missing Song_IDs in audio_df: {audio_missing_ids}")

train_graph_missing_ids = _gaps_in_song_ids(train_graph_df)
print(f"Missing Song_IDs in train_graph_df: {train_graph_missing_ids}")

# Both frames should now describe the same number of distinct songs
print(f"Number of unique Song_IDs in audio_df: {audio_df['Song_ID'].nunique()}")
print(f"Number of unique Song_IDs in train_graph_df: {train_graph_df['Song_ID'].nunique()}")


Missing Song_IDs in audio_df: {68691}
Missing Song_IDs in train_graph_df: set()
Removing missing Song_IDs from audio_df and resetting its other IDS starting from 0, so they are consecutive...
Missing Song_IDs in audio_df: set()
Missing Song_IDs in train_graph_df: set()
Number of unique Song_IDs in audio_df: 252013
Number of unique Song_IDs in train_graph_df: 252013


## Process Tags
Replace missing (NaN) tags with a placeholder token and encode each song's tags as integer indices, keeping the rows in ascending `Song_ID` order.


In [55]:
# Songs without any tags get the placeholder token 'unknown'. This is the same lowercase
# token that literal 'None' tags are replaced with later in this notebook, so missing
# tags share a single vocabulary entry instead of being split into 'Unknown' and 'unknown'.
audio_df.loc[audio_df['Song_Tags'].isnull(), 'Song_Tags'] = 'unknown'

In [56]:
# Keep the rows ordered by ascending Song_ID
audio_df = audio_df.sort_values(by='Song_ID', ascending=True)

# Swap the literal text "None" for a special token, then force every entry to be a string
audio_df['Song_Tags'] = audio_df['Song_Tags'].str.replace('None', 'unknown').astype(str)

# One list of tags per song (tags are comma separated)
tags = audio_df['Song_Tags'].str.split(',')

# Collect every tag occurrence and count the distinct tags
all_tags = []
for song_tag_list in tags:
    all_tags.extend(song_tag_list)
unique_tags = set(all_tags)
print(f"Number of unique tags: {len(unique_tags)}")

Number of unique tags: 29103


In [57]:
from sklearn.preprocessing import LabelEncoder

audio_df = audio_df.sort_values(by='Song_ID', ascending=True)

# One list of tags per song (tags are comma separated)
tags = audio_df['Song_Tags'].str.split(',')

# Every tag occurrence across all songs
all_tags = []
for song_tag_list in tags:
    all_tags.extend(song_tag_list)

# Fit the encoder on all tags; a tag's index is its position in the encoder's classes
label_encoder = LabelEncoder()
label_encoder.fit(all_tags)
tag_to_index = dict(zip(label_encoder.classes_, range(len(label_encoder.classes_))))

In [58]:
from torch.nn.utils.rnn import pad_sequence

# Replace every tag by its integer index, song by song
tags_indices = tags.apply(lambda song_tag_list: [tag_to_index[tag] for tag in song_tag_list])

# Songs have different numbers of tags: pad the shorter rows with -1 to a common length
per_song_index_tensors = [torch.tensor(song_indices) for song_indices in tags_indices]
tags_indices_padded = pad_sequence(per_song_index_tensors, batch_first=True, padding_value=-1)

In [59]:
tags_indices_tensor = tags_indices_padded

In [60]:
tags_indices_tensor.shape

torch.Size([252013, 3])

In [61]:
from collections import Counter

# Collect every tag occurrence per artist, walking the songs in ascending Song_ID order.
# The columns are iterated directly (no iterrows) and the loop variables are named so
# that the per-song `tags` Series from the earlier cells is not overwritten.
songs_by_id = audio_df.sort_values(by='Song_ID', ascending=True)
artist_tags = {}
for artist_name, song_tag_string in zip(songs_by_id['Artist'], songs_by_id['Song_Tags']):
    artist_tags.setdefault(artist_name, []).extend(song_tag_string.split(','))

# For each artist, keep the (up to) 3 most common tags
artist_top_tags = {
    artist_name: [tag for tag, count in Counter(tag_list).most_common(3)]
    for artist_name, tag_list in artist_tags.items()
}

# Convert artist tags to indices
artist_tags_indices = {
    artist_name: [tag_to_index[tag] for tag in top_tags]
    for artist_name, top_tags in artist_top_tags.items()
}

# Row i of the feature tensor must describe artist node i. Artist node IDs come from
# `artist_nodes_dic` (order of first appearance in train_graph_df), which differs from the
# order in which artists appear in audio_df, so the rows are emitted in node-ID order.
artists_in_node_order = sorted(artist_nodes_dic, key=artist_nodes_dic.get)
artists_without_tags = [a for a in artists_in_node_order if a not in artist_tags_indices]
if artists_without_tags:
    # Such artists get a row made up entirely of the padding value
    print(f"Warning: {len(artists_without_tags)} artist nodes have no tags in audio_df.")

artist_tag_rows = [
    torch.tensor(artist_tags_indices.get(artist_name, []), dtype=torch.long)
    for artist_name in artists_in_node_order
]

# Pad with -1 so that every artist row has the same length
artist_tags_indices_padded = pad_sequence(artist_tag_rows, batch_first=True, padding_value=-1)

# Convert to tensor
artist_tags_tensor = artist_tags_indices_padded


In [62]:
artist_tags_tensor.shape

torch.Size([28120, 3])

In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, to_hetero
from torch_geometric.data import HeteroData

# Song feature tensors, one row per row of audio_df.
# The audio features are taken by position: columns 5..15 of audio_df (11 columns).
# NOTE(review): this positional slice depends on the column order of final_audio_df.csv -
# confirm these are exactly the audio feature columns.
audio_features_tensor = torch.tensor(audio_df.iloc[:, 5:16].values, dtype=torch.float)
# Padded tag indices per song (padding value -1), built in the tag-processing cells
tags_indices_tensor = tags_indices_padded

In [64]:
audio_features_tensor.shape, tags_indices_tensor.shape

(torch.Size([252013, 11]), torch.Size([252013, 3]))

# Create Training Graph
Including Audio features and  Track Tags Data

In [65]:
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T

In [66]:
# Assemble the feature-rich heterogeneous graph (users, songs, artists).
data = HeteroData()


################### NODES ###################

# Give every node type a contiguous index range.
# user_ids / song_ids / artist_ids come from earlier cells and are passed
# straight to torch.arange.
node_counts = {"users": user_ids, "songs": song_ids, "artists": artist_ids}
for node_type, node_count in node_counts.items():
    data[node_type].node_id = torch.arange(node_count)

# Song features: audio descriptors and padded tag indices.
song_store = data['songs']
song_store.x_audio = audio_features_tensor
song_store.x_tag = tags_indices_tensor

# Artist features: padded tag indices.
data['artists'].x_tag = artist_tags_tensor


################### EDGES ###################

# The source tensors are transposed and made contiguous before being stored
# as edge_index.
edge_sources = {
    ('users', 'listens_to', 'songs'): user_song_edge_index,
    ('users', 'is_friends_with', 'users'): user_user_edge_index,
    ('artists', 'makes', 'songs'): artist_song_edge_index,
}
for edge_type, edge_pairs in edge_sources.items():
    data[edge_type].edge_index = edge_pairs.t().contiguous()

# Only the listening relation carries edge attributes.
data['users', 'listens_to', 'songs'].edge_attr = user_song_edge_attr_tensor

# Add reverse relations for message passing; with merge=False they are kept as
# separate 'rev_*' edge types instead of being merged into the originals.
add_reverse_edges = T.ToUndirected(merge=False)
data = add_reverse_edges(data)


In [67]:
# Summary statistics of the HeteroData graph held in `data`.
# The f-strings use double quotes on the outside so the single-quoted keys inside
# the braces also parse on Python versions older than 3.12.
print('========================Nodes==============================')

print(f"Number of total nodes: {data.num_nodes}")
print(f"Number of user nodes: {data['users'].num_nodes}")
print(f"Number of song nodes: {data['songs'].num_nodes}")
print(f"Number of artist nodes: {data['artists'].num_nodes}")

print('========================Edges==============================')

# Forward relations as they were added to the graph.
print(f"Number of total edges: {data.num_edges}")
print(f"Number of Listening edges: {data['listens_to'].num_edges}")
print(f"Number of Friends edges: {data['is_friends_with'].num_edges}")
print(f"Number of Artist Makes Songs edges: {data['makes'].num_edges}")

print('========================Reverse_Edges==============================')

# 'rev_*' relations created by T.ToUndirected(merge=False).
print(f"Number of total edges: {data.num_edges}")
print(f"Number of Listening edges: {data['rev_listens_to'].num_edges}")
print(f"Number of Friends edges: {data['rev_is_friends_with'].num_edges}")
print(f"Number of Artist Makes Songs edges: {data['rev_makes'].num_edges}")

print('========================Degree==============================')

print(f"Average node degree: {data.num_edges / data.num_nodes:.2f}")
print(f"Has isolated nodes: {data.has_isolated_nodes()}")
print(f"Has self-loops: {data.has_self_loops()}")

print('========================Directed==============================')

print(f"Is undirected: {data.is_undirected()}")

Number of total nodes: 283440
Number of user nodes: 3307
Number of song nodes: 252013
Number of artist nodes: 28120
Number of total edges: 5919678
Number of Listening edges: 2564907
Number of Friends edges: 142919
Number of Artist Makes Songs edges: 252013
Number of total edges: 5919678
Number of Listening edges: 2564907
Number of Friends edges: 142919
Number of Artist Makes Songs edges: 252013
Average node degree: 20.89
Has isolated nodes: False
Has self-loops: False
Is undirected: True


In [68]:
import os

# Disk cache for the rich 3-node graph: written on the first run, reused afterwards.
# NOTE: when the file already exists, the freshly built `data` is REPLACED by the
# cached copy. Delete the file after changing any upstream cell.
file_path = 'data/pyg_data/train_hetero_data_3_nodes_rich.pt'
if not os.path.exists(file_path):
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    torch.save(data, file_path)
else:
    # A HeteroData object is a pickled Python object, not a plain tensor/state_dict,
    # so it cannot be restored under weights_only=True (the default since PyTorch 2.6).
    # Unpickling executes code: only load files produced by this notebook.
    data = torch.load(file_path, weights_only=False)

# Co-Listening Graph 2 Nodes
For structural and content-based embeddings / community creation

In [70]:
# Assemble the 2-node-type co-listening graph (users and songs only, no features).
data = HeteroData()


################### NODES ###################
# user_ids / song_ids come from earlier cells and are passed straight to torch.arange.
for node_type, node_count in (("users", user_ids), ("songs", song_ids)):
    data[node_type].node_id = torch.arange(node_count)

################### EDGES ###################

# Listening and friendship relations; no edge attributes are attached here.
co_listening_edges = {
    ('users', 'listens_to', 'songs'): user_song_edge_index,
    ('users', 'is_friends_with', 'users'): user_user_edge_index,
}
for edge_type, edge_pairs in co_listening_edges.items():
    data[edge_type].edge_index = edge_pairs.t().contiguous()

# Add reverse relations for message passing, kept as separate edge types (merge=False).
add_reverse_edges = T.ToUndirected(merge=False)
data = add_reverse_edges(data)


In [71]:
# Disk cache for the 2-node co-listening graph: written on the first run, reused afterwards.
# NOTE: when the file already exists, the freshly built `data` is REPLACED by the
# cached copy. Delete the file after changing any upstream cell.
file_path = 'data/pyg_data/train_hetero_data_2_nodes_co.pt'
if not os.path.exists(file_path):
    print('Saving HeteroData object to file...')
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    torch.save(data, file_path)
    print('HeteroData object saved.')
else:
    print('Loading HeteroData object from file...')
    # A HeteroData object is a pickled Python object, so it cannot be restored under
    # weights_only=True (the default since PyTorch 2.6). Unpickling executes code:
    # only load files produced by this notebook.
    data = torch.load(file_path, weights_only=False)
    print('HeteroData object loaded.')



Saving HeteroData object to file...
HeteroData object saved.


# Base Graph 3 Nodes - Shallow:
Build Graph - 3 Node Graph


This graph carries no node features: audio features, track tags, and artist tags are all missing.

It is for debugging and plotting purposes only. Its data is not needed from this point on.

In [None]:
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T

# Feature-less 3-node-type graph (users, songs, artists) used for inspection and plotting.
data = HeteroData()

# Contiguous index range per node type; the *_ids values come from earlier cells
# and are passed straight to torch.arange.
for node_type, node_count in (
    ("users", user_ids),
    ("songs", song_ids),
    ("artists", artist_ids),
):
    data[node_type].node_id = torch.arange(node_count)

# Every relation is supplied as a full edge tensor, transposed and made contiguous
# before being stored as edge_index.
shallow_edges = {
    ('users', 'listens_to', 'songs'): user_song_edge_index,
    ('users', 'is_friends_with', 'users'): user_user_edge_index,
    ('artists', 'makes', 'songs'): artist_song_edge_index,
}
for edge_type, edge_pairs in shallow_edges.items():
    data[edge_type].edge_index = edge_pairs.t().contiguous()

# Only the listening relation carries edge attributes.
data['users', 'listens_to', 'songs'].edge_attr = user_song_edge_attr_tensor

# Add reverse relations for message passing, kept as separate edge types (merge=False).
add_reverse_edges = T.ToUndirected(merge=False)
data = add_reverse_edges(data)

#### Graph Inspection

In [None]:
# Summary statistics of the HeteroData graph held in `data`.
# The f-strings use double quotes on the outside so the single-quoted keys inside
# the braces also parse on Python versions older than 3.12.
print('========================Nodes==============================')

print(f"Number of total nodes: {data.num_nodes}")
print(f"Number of user nodes: {data['users'].num_nodes}")
print(f"Number of song nodes: {data['songs'].num_nodes}")
print(f"Number of artist nodes: {data['artists'].num_nodes}")

print('========================Edges==============================')

# Forward relations as they were added to the graph.
print(f"Number of total edges: {data.num_edges}")
print(f"Number of Listening edges: {data['listens_to'].num_edges}")
print(f"Number of Friends edges: {data['is_friends_with'].num_edges}")
print(f"Number of Artist Makes Songs edges: {data['makes'].num_edges}")

print('========================Reverse_Edges==============================')

# 'rev_*' relations created by T.ToUndirected(merge=False).
print(f"Number of total edges: {data.num_edges}")
print(f"Number of Listening edges: {data['rev_listens_to'].num_edges}")
print(f"Number of Friends edges: {data['rev_is_friends_with'].num_edges}")
print(f"Number of Artist Makes Songs edges: {data['rev_makes'].num_edges}")

print('========================Degree==============================')

print(f"Average node degree: {data.num_edges / data.num_nodes:.2f}")
print(f"Has isolated nodes: {data.has_isolated_nodes()}")
print(f"Has self-loops: {data.has_self_loops()}")

print('========================Directed==============================')

print(f"Is undirected: {data.is_undirected()}")

Number of total nodes: 283440
Number of user nodes: 3307
Number of song nodes: 252013
Number of artist nodes: 28120
Number of total edges: 5919678
Number of Listening edges: 2564907
Number of Friends edges: 142919
Number of Artist Makes Songs edges: 252013
Number of total edges: 5919678
Number of Listening edges: 2564907
Number of Friends edges: 142919
Number of Artist Makes Songs edges: 252013
Average node degree: 20.89
Has isolated nodes: False
Has self-loops: False
Is undirected: True


#### Sub Graph Selection

In [None]:
# Attach a constant one-column feature to every node, one marker value per node type
# (1 = users, 2 = songs, 3 = artists). The intent stated by the author is that
# NetworkX can then use them as node attributes.
import torch

# Node counts per type, read from the graph itself.
num_users, num_songs, num_artists = (
    data[node_type].num_nodes for node_type in ('users', 'songs', 'artists')
)

# One (count, 1) tensor per node type, filled with that type's marker value.
user_features, song_features, artist_features = (
    torch.full((count, 1), marker)
    for count, marker in ((num_users, 1), (num_songs, 2), (num_artists, 3))
)

# Store the dummy features on the PyG HeteroData object.
data['users'].x = user_features
data['songs'].x = song_features
data['artists'].x = artist_features

In [None]:
# Convert the PyG HeteroData graph into a NetworkX graph for sampling and plotting.
# Later cells read each node's 'type' attribute from this graph.
nx_graph = to_networkx(data)

In [None]:
# Sanity check: total number of nodes in the converted NetworkX graph.
len(nx_graph.nodes)

283440

In [None]:
import random
import networkx as nx
from torch_geometric.utils import to_networkx

# Uses nx_graph, the NetworkX conversion of the PyG graph built in an earlier cell.
# The sample is unseeded, so every run picks different nodes.

# Distinct values of the 'type' node attribute.
node_types = set(nx.get_node_attributes(nx_graph, 'type').values())

# Draw up to 150 random nodes of each type (all of them if the type has fewer).
max_per_type = 150
selected_nodes = []
for node_type in node_types:
    type_nodes = [
        node for node, attrs in nx_graph.nodes(data=True) if attrs['type'] == node_type
    ]
    sample_size = min(max_per_type, len(type_nodes))
    selected_nodes += random.sample(type_nodes, sample_size)

# Induced subgraph on the sampled nodes, copied into an undirected Graph.
subgraph = nx.Graph(nx_graph.subgraph(selected_nodes))

In [None]:
import random
# NOTE: this rebinds `tqdm` to the plain-text progress bar for all later cells
# (the notebook's first cell imports it from tqdm.auto).
from tqdm import tqdm
import networkx as nx

# Uses nx_graph, the NetworkX conversion of the PyG graph built in an earlier cell.
# The sample is unseeded, so every run picks different nodes.

# Distinct values of the 'type' node attribute.
node_types = set(nx.get_node_attributes(nx_graph, 'type').values())

# Nodes chosen so far (may contain duplicates; subgraph() ignores them).
selected_nodes = []

# Number of seed nodes to draw per node type.
nodes_per_type = {
    'users': 1,
    'songs': 20,  # Increased number of song nodes
    'artists': 10
}

# Draw the seed nodes of every type.
for node_type in tqdm(node_types, desc="Selecting nodes", unit="node_type"):
    type_nodes = [n for n, d in nx_graph.nodes(data=True) if d['type'] == node_type]
    num_to_select = min(nodes_per_type[node_type], len(type_nodes))
    selected_nodes.extend(random.sample(type_nodes, num_to_select))

# Induced subgraph on the seeds, copied into an undirected Graph.
subgraph = nx.Graph(nx_graph.subgraph(selected_nodes))

# Stitch the pieces together: every node outside the largest component pulls in
# the nodes on a shortest path (in the full graph) to a random node of that component.
connected_components = list(nx.connected_components(subgraph))

if len(connected_components) > 1:
    largest_component = max(connected_components, key=len)
    # Materialise once instead of on every inner iteration.
    bridge_targets = list(largest_component)

    for component in connected_components:
        if component != largest_component:
            for node in component:
                try:
                    path = nx.shortest_path(nx_graph, node, random.choice(bridge_targets))
                except nx.NetworkXNoPath:
                    # Unreachable in the full graph: leave the node unconnected
                    # instead of aborting the whole cell.
                    continue
                selected_nodes.extend(path)

# Final subgraph including the bridging nodes.
subgraph = nx.Graph(nx_graph.subgraph(selected_nodes))

# Report the composition of the result.
for node_type in node_types:
    type_count = len([n for n, d in subgraph.nodes(data=True) if d['type'] == node_type])
    print(f"Number of {node_type} nodes: {type_count}")

print(f"Total number of nodes: {subgraph.number_of_nodes()}")
print(f"Total number of edges: {subgraph.number_of_edges()}")

Selecting nodes: 100%|██████████| 3/3 [00:00<00:00, 87.23node_type/s]

Number of songs nodes: 35
Number of artists nodes: 10
Number of users nodes: 36
Total number of nodes: 81
Total number of edges: 177





In [None]:
# Sanity check: number of nodes in the sampled subgraph.
len(subgraph.nodes)

81

In [None]:
%%time
# Restrict the subgraph to its largest connected component.
largest_cc = max(nx.connected_components(subgraph), key=len)
# Copy the component's view into a standalone undirected Graph.
subgraph = nx.Graph(subgraph.subgraph(largest_cc))

CPU times: total: 0 ns
Wall time: 0 ns


In [None]:
# Drop every node that has no 'type' attribute (cause unknown).
# The plotting cells index attr['type'] for each node and get stuck on nodes without it.
nodes_to_remove = [
    node for node, attr in subgraph.nodes(data=True) if 'type' not in attr
]

for node in nodes_to_remove:
    print(f"Node {node} does not have a 'type' attribute.")

subgraph.remove_nodes_from(nodes_to_remove)


In [None]:
# Node count per type (songs, users, artists), then total nodes and total edges.
for counted_type in ('songs', 'users', 'artists'):
    print(sum(1 for _, attr in subgraph.nodes(data=True) if attr['type'] == counted_type))
print(subgraph.number_of_nodes())
print(subgraph.number_of_edges())

35
36
10
81
177


#### Visualization

In [None]:
import plotly.graph_objects as go
import networkx as nx

# Static Plotly rendering of `subgraph`, one colour per node type.
color_map = {'users': 'rgb(255,0,0)', 'songs': 'rgb(255,255,0)', 'artists': 'rgb(0,0,255)'}

# Node -> type lookup.
node_types = nx.get_node_attributes(subgraph, 'type')

# Kamada-Kawai layout; `pos` is reused by the next plotting cell.
pos = nx.kamada_kawai_layout(subgraph)

# Build coordinates, colours and hover labels in ONE pass over `pos` so the
# four lists are guaranteed to be aligned.
node_x, node_y, node_colors, node_text = [], [], [], []
for node, (x, y) in pos.items():
    node_type = node_types[node]
    node_x.append(x)
    node_y.append(y)
    node_colors.append(color_map[node_type])
    node_text.append(f"Node: {node}<br>Type: {node_type}")

node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    text=node_text,  # shown on hover (hoverinfo='text')
    mode='markers',
    hoverinfo='text',
    marker=dict(
        showscale=False,
        color=node_colors,
        size=12,
        line_width=2))

# Edges are drawn as a single trace; None breaks the line between segments.
edge_x, edge_y = [], []
for source, target in subgraph.edges():
    x0, y0 = pos[source]
    x1, y1 = pos[target]
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

edge_trace = go.Scatter(
    x=edge_x,
    y=edge_y,
    line=dict(width=0.5, color='#888'),
    hoverinfo='none',
    mode='lines')

# Creating figure
fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(
                    # Title size is set through title.font; the old `titlefont`
                    # shortcut is deprecated.
                    title=dict(text='<br>Network graph made with Plotly',
                               font=dict(size=10)),
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=30, l=20, r=50, t=40),
                    annotations=[dict(
                        text="Python code: <a href='https://plotly.com/'>Plotly</a>",
                        showarrow=False,
                        xref="paper", yref="paper",
                        x=0.005, y=-0.002)],
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    width=1200,  # Adjust width of the plot
                    height=1000,  # Adjust height of the plot
                    legend=dict(
                        x=0,
                        y=1,
                        traceorder="normal",
                        font=dict(
                            family="sans-serif",
                            size=12,
                            color="black"
                        ),
                        bgcolor="LightSteelBlue",
                        bordercolor="Black",
                        borderwidth=2,
                        title="Node Types",)
                    ))

# Show plot
fig.show()


In [None]:
import plotly.graph_objects as go
import networkx as nx

# Plotly rendering of `subgraph` with one trace per node type and buttons to
# show a single type at a time. Reuses the layout `pos` computed by the previous
# plotting cell.
color_map = {'users': 'rgb(255,0,0)', 'songs': 'rgb(255,255,0)', 'artists': 'rgb(0,0,255)'}

# Node -> type lookup.
node_types = nx.get_node_attributes(subgraph, 'type')

# One marker trace per node type, in color_map order.
# (No loop variable is called `data`, so the notebook-level HeteroData object
# of that name is left intact.)
traces = []
for node_type, color in color_map.items():
    members = [node for node in subgraph.nodes() if node_types[node] == node_type]

    node_trace = go.Scatter(
        x=[pos[node][0] for node in members],
        y=[pos[node][1] for node in members],
        mode='markers',
        name=node_type.capitalize(),
        marker=dict(
            color=color,
            size=12,
            line_width=2
        ),
        text=[f"Node: {node}<br>Type: {node_type}" for node in members],
        hoverinfo='text'
    )
    traces.append(node_trace)

# Edges are drawn as a single trace; None breaks the line between segments.
edge_x = []
edge_y = []
for source, target in subgraph.edges():
    x0, y0 = pos[source]
    x1, y1 = pos[target]
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=0.5, color='#888'),
    hoverinfo='none',
    mode='lines',
    name='Edges'
)

# The edge trace goes first so it is drawn underneath the nodes.
traces.insert(0, edge_trace)

# Creating figure
fig = go.Figure(data=traces,
                layout=go.Layout(
                    # Title size is set through title.font; the old `titlefont`
                    # shortcut is deprecated.
                    title=dict(text='<br>Network graph with selectable node types',
                               font=dict(size=16)),
                    showlegend=True,
                    hovermode='closest',
                    margin=dict(b=20, l=5, r=5, t=40),
                    annotations=[dict(
                        text="Python code: <a href='https://plotly.com/'>Plotly</a>",
                        showarrow=False,
                        xref="paper", yref="paper",
                        x=0.005, y=-0.002
                    )],
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    width=1200,
                    height=1000,
                    legend=dict(
                        yanchor="top",
                        y=0.99,
                        xanchor="left",
                        x=0.01,
                        bgcolor="rgba(255, 255, 255, 0.5)",
                        bordercolor="Black",
                        borderwidth=1
                    )
                ))

# Toggle buttons: "All" plus one button per node type. Each visibility mask is
# derived from the trace order (edges first, then color_map order), so it stays
# correct if node types are added or reordered.
toggle_buttons = [
    dict(label="All",
         method="update",
         args=[{"visible": [True] * len(traces)},
               {"title.text": "All node types"}]),
]
for shown_type in color_map:
    label = shown_type.capitalize()
    visible = [True] + [other_type == shown_type for other_type in color_map]
    toggle_buttons.append(
        dict(label=label,
             method="update",
             args=[{"visible": visible},
                   {"title.text": f"{label} only"}]))

fig.update_layout(
    updatemenus=[
        dict(
            type="buttons",
            direction="right",
            active=0,
            x=0.57,
            y=1.2,
            buttons=toggle_buttons,
        )
    ]
)

# Show plot
fig.show()

In [None]:
from pyvis.network import Network
import networkx as nx
import random
from IPython.display import IFrame, display, HTML

# Interactive PyVis export of `subgraph` to an HTML file, with a colour legend
# injected into the generated page.
html_path = "data/beautiful_graph.html"

# Create a Pyvis network
net = Network(height="750px", width="100%", bgcolor="#222222", font_color="white", notebook=True)

# Set some display options
net.toggle_hide_edges_on_drag(True)
net.barnes_hut()
net.show_buttons(filter_=['physics'])

# Define node colors and sizes
color_map = {'users': '#ff6b6b', 'songs': '#feca57', 'artists': '#48dbfb'}
size_map = {'users': 30, 'songs': 25, 'artists': 35}  # Increased sizes

# Define edge colors (keys are written in display order for the legend)
edge_color_map = {
    ('users', 'songs'): '#ff9ff3',  # User-Song connections
    ('artists', 'songs'): '#54a0ff',  # Artist-Song connections
    ('users', 'users'): '#5f27cd'  # User-User connections
}

# Lookup keyed by the SORTED type pair, so an edge matches whichever endpoint
# comes first. (Looking up the sorted pair directly in edge_color_map missed
# ('users', 'songs'), because that key is not in sorted order.)
edge_color_lookup = {tuple(sorted(pair)): color for pair, color in edge_color_map.items()}

# Add nodes to the network.
# (No loop variable is called `data`, so the notebook-level HeteroData object
# of that name is left intact.)
for node, attrs in subgraph.nodes(data=True):
    node_type = attrs['type']
    net.add_node(node,
                 label=f"{node_type}: {node}",
                 color=color_map[node_type],
                 size=size_map[node_type],
                 title=f"Node: {node}<br>Type: {node_type}")

# Add edges to the network, coloured by the types of their two endpoints.
for source, target in subgraph.edges():
    type_pair = tuple(sorted((subgraph.nodes[source]['type'], subgraph.nodes[target]['type'])))
    edge_color = edge_color_lookup.get(type_pair, '#888888')  # Default color if not in map
    net.add_edge(source, target, color=edge_color)

# Generate the HTML file
net.save_graph(html_path)

# Create a legend
legend_html = """
<div style="position: absolute; top: 10px; left: 10px; background-color: rgba(255,255,255,0.7); padding: 10px; border-radius: 5px;">
    <h3 style="margin-top: 0;">Legend</h3>
    <h4>Node Types:</h4>
    <ul style="list-style-type: none; padding-left: 0;">
"""

for node_type, color in color_map.items():
    legend_html += f'<li><span style="display: inline-block; width: 20px; height: 20px; background-color: {color}; margin-right: 5px;"></span>{node_type.capitalize()}</li>'

legend_html += """
    </ul>
    <h4>Edge Types:</h4>
    <ul style="list-style-type: none; padding-left: 0;">
"""

for edge_type, color in edge_color_map.items():
    legend_html += f'<li><span style="display: inline-block; width: 20px; height: 20px; background-color: {color}; margin-right: 5px;"></span>{edge_type[0].capitalize()} - {edge_type[1].capitalize()}</li>'

legend_html += """
    </ul>
</div>
"""

# Add the legend to the HTML file: insert it right before the closing body tag.
with open(html_path, "r") as file:
    content = file.read()
content = content.replace("</body>", legend_html + "</body>")

with open(html_path, "w") as file:
    file.write(content)


