# Train a model on the data

In [None]:

from collections import Counter
import jsonlines
import pandas as pd
import numpy as np
from numpy.typing import ArrayLike
import os

import torch
from torch import nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected
from torch_geometric.nn import GraphConv, to_hetero, HeteroConv

from jazz_graph.pyg_data.pyg_data import CreateTensors


In [None]:
# Directory handed to CreateTensors. Going by the path name it presumably
# holds the graph's parquet tables -- TODO confirm.
models_dir = '/workspace/local_data/graph_parquet_proto'
# `create` supplies every tensor used to assemble the graph in this file:
# node sets, edge indices, labels and the train/dev/test masks.
create = CreateTensors(models_dir)

In [None]:
# TODO: report on the data a little more concretely.
# E.g., who are the hub nodes? How many nodes have > 50 edges?
# How many nodes have < 6 edges? Report all of these by node type.
# Get really fancy and visualize a sub-graph.

def frequency_of_n_labels(data: "HeteroData"):
    """Print what fraction of performance rows carry 0, 1, 2, ... labels.

    ``data['performance'].y`` is summed over dim 1 to get one label count
    per row. One line is printed for every count from 0 up to the largest
    count observed; counts that never occur are reported as 0.000.

    Args:
        data: Object whose ``data['performance'].y`` is a 2-D tensor with
            one row per sample (it is summed over dim 1 and its
            ``shape[0]`` is used as the sample total).

    Returns:
        None. The frequencies are printed, not returned. Nothing is
        printed when ``y`` has no rows.
    """
    labels = data['performance'].y
    n_samples = labels.shape[0]
    counter = Counter(int(x) for x in labels.sum(dim=1).tolist())
    if not counter:
        return
    # Iterate up to the largest observed count, not the number of distinct
    # counts: with counts {0, 1, 3}, len(counter) == 3 would stop at 2 and
    # silently drop the 3-label bucket.
    for i in range(max(counter) + 1):
        freq = counter[i] / n_samples
        print(f"Num samples with {i} labels: {freq:.3f}")

In [None]:
# Empty heterogeneous graph container; node features, edge indices, labels
# and split masks are attached to it further down.
data = HeteroData()

def index_tensor(tensor):
    """Build a column of consecutive row ids for ``tensor``.

    The values in ``tensor`` are ignored; only its length along dim 0
    matters. The ids give each node a direct lookup key when graph nodes
    are sampled.

    Returns an int64 tensor of shape ``(tensor.size(0), 1)`` holding
    ``0, 1, 2, ...``.
    """
    num_rows = tensor.size(0)
    row_ids = torch.arange(num_rows, dtype=torch.long)
    return row_ids.unsqueeze(1)

# The nodes are not expected to contribute real feature information -- the
# signal is the graph itself. Each node type therefore just gets its own row
# index as `x` (a little clunky, but it gives a direct id lookup).
for node_type, load_nodes in (
    ('performance', create.performances),
    ('song', create.songs),
    ('artist', create.artists),
):
    data[node_type].x = index_tensor(load_nodes())

# (source, relation, destination) -> loader for that relation's edge index.
edge_loaders = {
    ('artist', 'performs', 'performance'): create.artist_performance_edges,
    ('performance', 'performing', 'song'): create.performance_song_edges,
    ('artist', 'composed', 'song'): create.artist_song_edges,
}
for edge_type, load_edges in edge_loaders.items():
    data[edge_type].edge_index = load_edges()

# Targets and split masks live on the performance nodes.
performance_store = data['performance']
performance_store.y = create.labels()
performance_store.train_mask = create.train_mask()
performance_store.dev_mask = create.dev_mask()
performance_store.test_mask = create.test_mask()

# Not wired up yet:
# data['artist', 'performs', 'performance'].edge_attr = <instrument>
to_undirected = ToUndirected()
data = to_undirected(data)


In [None]:
# Summarise the assembled graph: its structure, whether it has isolated
# nodes, and whether it is directed.
print(data)
summary_parts = (
    f"The graph contains {'' if data.has_isolated_nodes() else 'no '}isolated nodes and",
    f"is {'directed' if data.is_directed() else 'undirected'}.",
)
print(*summary_parts)
frequency_of_n_labels(data)

# Share of rows tagged with each style column.
# Easy Listening is probably a mislabel by modern standards.
per_style_counts = data['performance'].y.sum(dim=0)
for style, count in zip(create._labels.columns, per_style_counts):
    print(f"  {style}: {int(count) / create._labels.shape[0]:.1%}")
    # Easy Listening is probably a mislabel by modern standards.

In [None]:
class JazzGNN(nn.Module):
    """Three-layer heterogeneous GNN that emits logits for performance nodes.

    Each node type gets a learned embedding table indexed by node id. The
    embeddings pass through three ``HeteroConv`` layers (one ``GraphConv``
    per edge type in ``metadata[1]``) with a ReLU after the first two, and
    a linear head maps the final ``'performance'`` representations to
    ``output_dims`` logits.

    Args:
        num_performances: Number of rows in the performance embedding table.
        num_artists: Number of rows in the artist embedding table.
        num_songs: Number of rows in the song embedding table.
        hidden_dims: Output width of every conv layer.
        embed_dims: Width of the node embeddings (input to the first conv).
        output_dims: Number of logits produced per performance node.
        metadata: Indexable whose element 1 is the iterable of edge types.
    """

    def __init__(self, num_performances, num_artists, num_songs, hidden_dims, embed_dims, output_dims, metadata):
        super().__init__()

        # One embedding table per node type. The attribute spelling
        # `artist_embded` is kept as-is because code outside the class
        # refers to it by that name.
        self.performance_embed = nn.Embedding(num_performances, embed_dims)
        self.song_embed = nn.Embedding(num_songs, embed_dims)
        self.artist_embded = nn.Embedding(num_artists, embed_dims)

        edge_types = metadata[1]
        self.conv1 = self._hetero_layer(embed_dims, hidden_dims, edge_types)
        self.conv2 = self._hetero_layer(hidden_dims, hidden_dims, edge_types)
        self.conv3 = self._hetero_layer(hidden_dims, hidden_dims, edge_types)

        self.classifier = nn.Linear(hidden_dims, output_dims)

    @staticmethod
    def _hetero_layer(in_dims, out_dims, edge_types):
        """Return a HeteroConv holding a separate GraphConv per edge type."""
        return HeteroConv({
            edge_type: GraphConv(in_dims, out_dims) for edge_type in edge_types
        })

    def forward(self, x_dict, edge_dict) -> torch.Tensor:
        """Compute logits for the performance nodes.

        Args:
            x_dict: Maps ``'performance'``, ``'artist'`` and ``'song'`` to
                tensors of node ids; each is flattened with ``view(-1)``
                before the embedding lookup.
            edge_dict: Per-edge-type edge indices, passed unchanged to
                every conv layer.

        Returns:
            The classifier output for the ``'performance'`` entries.
        """
        def embed(table, node_type):
            return table(x_dict[node_type].view(-1))

        hidden = {
            'performance': embed(self.performance_embed, 'performance'),
            'artist': embed(self.artist_embded, 'artist'),
            'song': embed(self.song_embed, 'song'),
        }

        # conv1 -> ReLU -> conv2 -> ReLU -> conv3 (no activation after the last).
        hidden = self.conv1(hidden, edge_dict)
        for conv in (self.conv2, self.conv3):
            hidden = {node_type: F.relu(feats) for node_type, feats in hidden.items()}
            hidden = conv(hidden, edge_dict)

        return self.classifier(hidden['performance'])

# Instantiate the model. The embedding tables are sized from the graph's
# per-type node counts and the conv layers are built from the edge types in
# data.metadata().
model = JazzGNN(
    data['performance'].num_nodes,
    data['artist'].num_nodes,
    data['song'].num_nodes,
    hidden_dims=128,
    embed_dims=64,
    # NOTE(review): presumably the number of label columns in
    # data['performance'].y -- confirm before changing the label set.
    output_dims=20,
    metadata=data.metadata()
)

# Sanity checks: node counts next to the embedding-table shapes they should
# match. These are bare expressions, so in a notebook cell only the last one
# is displayed.
data['performance'].num_nodes
data['artist'].num_nodes

model.performance_embed.weight.shape
model.artist_embded.weight.shape
model.song_embed.weight.shape

In [None]:
# Smoke test: look at the label tensor, then run one forward pass of the
# model over the whole graph (all node ids and all edge indices at once).
data['performance'].y
model(data.x_dict, data.edge_index_dict)
# data.x_dict['performance'].view(-1)