# Train a model on the data

In [None]:

from collections import Counter
import jsonlines
import pandas as pd
import numpy as np
from numpy.typing import ArrayLike
import os

from ignite.engine import Engine, Events
from ignite.metrics import Recall, Precision, Accuracy, Loss

import torch
from torch import nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected
from torch_geometric.nn import GraphConv, SAGEConv, to_hetero, HeteroConv
from torch_geometric import transforms as T
from torch_geometric import seed_everything

from jazz_graph.data.utils import inspect_degrees
from jazz_graph.logging import JSONRunLogger, run_evaluator, log_experiment, binary_output_transform
from jazz_graph.pyg_data.pyg_data import CreateTensors
from jazz_graph.model import JazzModel, LinkPredictionModel, NodeClassifier


In [None]:
# Directory holding the exported graph tables (presumably parquet, per the
# directory name -- TODO confirm). CreateTensors reads it and hands out the
# node, edge, label and split-mask tensors used to build the graph below.
models_dir = '/workspace/local_data/graph_parquet_proto'
create = CreateTensors(models_dir)

In [None]:
# TODO: report on the data a little more concretely.
# E.g., who are the hub nodes? How many nodes have > 50 edges?
# How many nodes have < 6 edges? All these, by type.
# Get really fancy and visualize a sub-graph.

def frequency_of_n_labels(data: "HeteroData") -> dict:
    """Print (and return) how often performances carry 0, 1, 2, ... style labels.

    ``data['performance'].y`` is read as a multi-hot label matrix with one row
    per performance; the number of labels on a performance is its row sum.

    Args:
        data: Graph whose ``'performance'`` store has a 2-D ``y`` tensor.

    Returns:
        Mapping from label count to the fraction (0-1) of performances with
        exactly that many labels, for every count from 0 up to the largest
        observed count (absent counts map to 0.0). Empty when there are no
        performances.
    """
    labels = data['performance'].y
    n_samples = labels.shape[0]
    counter = Counter(int(x) for x in labels.sum(dim=1))
    if not counter:
        # No rows: nothing to report (and avoids dividing by zero).
        return {}

    frequencies = {}
    # Walk up to the largest observed count, not len(counter): if some count
    # never occurs (e.g. nobody has exactly 2 labels), len(counter) would stop
    # early and silently drop the highest counts.
    for n_labels in range(max(counter) + 1):
        freq = counter[n_labels] / n_samples
        frequencies[n_labels] = freq
        print(f"Num samples with {n_labels} labels: {freq:.3f}")
    return frequencies

In [None]:
# Empty heterogeneous graph; node ids, edges, labels and split masks are
# attached to it in the next cell.
data = HeteroData()

def index_tensor(tensor):
    """Build the column of row positions ``[[0], [1], ..., [n - 1]]`` for *tensor*.

    Only the size of *tensor*'s first dimension matters; its contents are
    ignored. The result is an ``int64`` tensor of shape ``(n, 1)``, so each
    node's feature is simply its own id, which gives graph samplers a direct
    lookup of the node ids.
    """
    num_rows = tensor.size(0)
    positions = torch.arange(num_rows, dtype=torch.int64)
    return positions.unsqueeze(1)

# This is a little clunky. The nodes are not expected to provide
# substantial feature information--the information is the graph.
# Each node's only "feature" is its own integer id (see index_tensor); the
# model presumably maps these ids to learned embeddings -- TODO confirm in JazzModel.
data['performance'].x = index_tensor(create.performances())
data['song'].x = index_tensor(create.songs())
data['artist'].x = index_tensor(create.artists())

# Edge types: artist -performs-> performance, performance -performing-> song,
# artist -composed-> song.
data['artist', 'performs', 'performance'].edge_index = create.artist_performance_edges()
data['performance', 'performing', 'song'].edge_index = create.performance_song_edges()
data['artist', 'composed', 'song'].edge_index = create.artist_song_edges()

# Style labels and the train/dev/test split masks all live on performance nodes.
data['performance'].y = create.labels()
data['performance'].train_mask = create.train_mask()
data['performance'].dev_mask = create.dev_mask()
data['performance'].test_mask = create.test_mask()

# Placeholder for a possible future edge attribute (the performer's instrument).
# data['artist', 'performs', 'performance'].edge_attr = <instrument>
# Add a reverse edge type for every relation (e.g. ('performance', 'rev_performs',
# 'artist')) so messages can flow in both directions.
data = ToUndirected()(data)


In [None]:
# Display the style label names; they name the columns of data['performance'].y.
create.label_names()

In [None]:
print(data)
print(
    f"The graph contains {'' if data.has_isolated_nodes() else 'no '}isolated nodes and",
    f"is {'directed' if data.is_directed() else 'undirected'}."
)
frequency_of_n_labels(data)
# Share of performances carrying each style label. Normalize by the rows of the
# label tensor already on the graph instead of CreateTensors' private `_labels`
# attribute, which is an implementation detail of that class.
performance_labels = data['performance'].y
num_performances = performance_labels.shape[0]
for style, count in zip(create.label_names(), performance_labels.sum(dim=0)):
    print(f"  {style}: {int(count) / num_performances:.1%}")
    # Easy Listening is probably a mislabel by modern standards.
    # Easy Listening is probably a mislabel by modern standards.


In [None]:
# Degree summary of the graph via the project helper; the output format is
# defined by jazz_graph.data.utils.inspect_degrees.
inspect_degrees(data)

## Model

## Train Style Classifier

In [None]:
from torch_geometric.loader import NeighborLoader

def train_indicies(mask):
    """Return the node ids (row positions) at which *mask* is set.

    Used to turn a split mask (train or dev) into the ``input_nodes`` ids for a
    neighbor loader. NOTE: assumes *mask* is a 1-D boolean tensor; indexing
    with an integer tensor would gather positions instead of filtering.
    """
    candidate_ids = torch.arange(mask.shape[0])
    selected_ids = candidate_ids[mask]
    return selected_ids

# Mini-batches of 128 seed performances, each expanded by sampling up to 15
# neighbors per hop for 3 hops. Shuffle the training seeds so every epoch sees
# a different batch order (NeighborLoader does not shuffle by default); the dev
# loader keeps a fixed order since it is only used for evaluation.
train_loader = NeighborLoader(
    data,
    [15, 15, 15],
    batch_size=128,
    input_nodes=('performance', train_indicies(data['performance'].train_mask)),
    shuffle=True,
)
dev_loader = NeighborLoader(
    data,
    [15, 15, 15],
    batch_size=128,
    input_nodes=('performance', train_indicies(data['performance'].dev_mask)),
)

### Train function

In [None]:
class GNNTrainingLogic:
    """Training and evaluation steps for the performance style classifier.

    Both steps follow the pytorch-ignite process-function signature
    ``(engine, batch) -> dict`` so they can be wrapped in ``Engine(...)``.
    """
    def __init__(self, model, optimizer, criterion):
        # Batches are moved to whichever device the model's parameters are on.
        self.device = next(model.parameters()).device
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion

    def _predict(self, batch: "HeteroData"):
        """Return ``(logits, float targets)`` for the seed performances of *batch*.

        Only the first ``batch_size`` 'performance' rows are scored: those are
        the loader's seed nodes; the remaining rows are sampled neighborhood
        context.
        """
        batch_size = batch['performance'].batch_size
        y_pred = self.model(batch.x_dict, batch.edge_index_dict)[:batch_size]
        # TODO: the labels are stored with a non-float dtype; BCE-style
        # criteria need float targets, so cast once here.
        y_true = batch['performance'].y[:batch_size].to(torch.float)
        return y_pred, y_true

    def train_step(self, engine, batch: "HeteroData") -> dict:
        """Complete one step of gradient descent.

        Returns:
            dict with the scalar ``loss`` and detached ``y_pred`` (logits) and
            ``y_true`` (float targets) for the seed nodes.
        """
        self.model.train()
        self.optimizer.zero_grad()
        batch = batch.to(self.device)
        y_pred, y_true = self._predict(batch)
        loss = self.criterion(y_pred, y_true)
        loss.backward()
        self.optimizer.step()
        return {'loss': loss.item(), 'y_pred': y_pred.detach(), 'y_true': y_true.detach()}

    def eval_step(self, engine, batch: "HeteroData") -> dict:
        """Score a batch with gradients disabled and return logits and targets.

        No loss is returned; an attached ignite ``Loss`` metric can compute it
        from ``y_pred`` / ``y_true``.
        """
        self.model.eval()
        batch = batch.to(self.device)
        with torch.no_grad():
            y_pred, y_true = self._predict(batch)
        return {'y_pred': y_pred, 'y_true': y_true}

# Hyperparameters for the shared graph encoder (JazzModel).
model_config = {
    'hidden_dim': 128,
    'embed_dim': 64,
    'dropout': 0.2
}

# Multi-label style classifier on top of the graph encoder: one output per style
# label. Per-type node counts are passed in (presumably to size id embeddings --
# TODO confirm in JazzModel), plus the graph's node/edge type metadata.
model = NodeClassifier(
    JazzModel(
        data['performance'].num_nodes,
        data['artist'].num_nodes,
        data['song'].num_nodes,
        hidden_dim=model_config['hidden_dim'],
        embed_dim=model_config['embed_dim'],
        dropout=model_config['dropout'],
        metadata=data.metadata()
    ),
    hidden_dim=model_config['hidden_dim'],
    num_classes=len(create.label_names())
)


# Run configuration recorded by the experiment logger.
experiment_config = {
    'model': model_config,
    'lr': .001,
    'batch_size': train_loader.batch_size,
    'epochs': 15
}


# BCEWithLogitsLoss applies an independent sigmoid per label, matching the
# multi-label (multi-hot) targets.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=experiment_config['lr'])
trainer_logic = GNNTrainingLogic(model, optimizer, criterion)
# assert trainer_logic.train_step(None, batch) is not None, "This is really just checking that the forward pass works."

experiment_logger = JSONRunLogger(run_name='gnn_classifier_slim_data', config=experiment_config)

# NOTE(review): `accuracy` is never attached to an engine; the metrics dict
# below creates its own Accuracy instance.
accuracy = Accuracy(output_transform=binary_output_transform)

# One engine trains; two separate evaluator engines re-score the train and dev
# loaders so their metric states stay independent.
trainer = Engine(trainer_logic.train_step)
train_evaluator = Engine(trainer_logic.eval_step)
dev_evaluator = Engine(trainer_logic.eval_step)

# binary_output_transform presumably turns logits into 0/1 predictions for the
# classification metrics -- see jazz_graph.logging. Loss is recomputed from the
# evaluator outputs because eval_step does not return a loss.
metrics = {
    'accuracy': Accuracy(output_transform=binary_output_transform),
    'recall': Recall(output_transform=binary_output_transform),
    'precision': Precision(output_transform=binary_output_transform),
    'loss': Loss(criterion, output_transform=lambda out: (out['y_pred'], out['y_true']))
}

for name, metric in metrics.items():
    metric.attach(train_evaluator, name)
    metric.attach(dev_evaluator, name)

# After every training epoch, run both evaluators over their loaders; each
# evaluator then logs its metrics under 'train' / 'dev' via log_experiment.
trainer.add_event_handler(Events.EPOCH_COMPLETED, run_evaluator, train_evaluator, train_loader, "Training")
trainer.add_event_handler(Events.EPOCH_COMPLETED, run_evaluator, dev_evaluator, dev_loader, "Validation")
train_evaluator.add_event_handler(Events.EPOCH_COMPLETED, log_experiment, experiment_logger, 'train', trainer)
dev_evaluator.add_event_handler(Events.EPOCH_COMPLETED, log_experiment, experiment_logger, 'dev', trainer)

In [None]:
# Train for the epoch count recorded in experiment_config so the logged
# configuration matches what actually ran.
trainer.run(train_loader, max_epochs=experiment_config['epochs'])

## Edge Prediction

A quick first pass at an edge (link) prediction model.
Conceptually, it is a recommender system built on predicting "who worked with whom, on what?"
In that sense, it's a lot like asking Thelonious Monk, "What would you recommend?"

In [None]:
# Fix all RNG seeds so the random edge split below is reproducible.
seed_everything(42)
performs_edge_count = data[('artist', 'performs', 'performance')].num_edges

# Split the artist->performance edges into train/dev/test link-prediction sets,
# holding out 10% of the edges each for dev and test.
# - disjoint_train_ratio=.3: 30% of the training edges are used only as
#   supervision targets, not for message passing.
# - Dev/test get 2 sampled negative edges per positive (neg_sampling_ratio=2.0).
# - The training split gets no pre-sampled negatives
#   (add_negative_train_samples=False); the training loader samples them instead.
# rev_edge_types keeps the reverse relation added by ToUndirected consistent
# with the split.
split_graph = T.RandomLinkSplit(
    num_val=int(performs_edge_count * .1),
    num_test=int(performs_edge_count * .1),
    disjoint_train_ratio=.3,
    neg_sampling_ratio=2.0,
    add_negative_train_samples=False,
    edge_types=('artist', 'performs', 'performance'),
    rev_edge_types=('performance', 'rev_performs', 'artist')
)
train_data, dev_data, test_data = split_graph(data)

In [None]:
from torch_geometric.loader import LinkNeighborLoader

def edge_training_data_factory(data: HeteroData) -> LinkNeighborLoader:
    """Build a mini-batch loader over one split's artist->performance supervision edges.

    *data* is one output of RandomLinkSplit (train, dev or test). Its own
    ``edge_label_index`` is used as the set of edges to score. How labels are
    handled depends on whether the split already contains negative edges:

    * No negatives: every supervision edge is positive (e.g. a training split
      built with ``add_negative_train_samples=False``). Fresh negatives are
      sampled per batch at a 2:1 ratio, and the loader creates the 0/1
      ``edge_label``; batches are shuffled.
    * Negatives present: this is the case for splits where RandomLinkSplit
      already sampled negatives (dev/test). The split's own 0/1 ``edge_label``
      is used as-is, and no extra negatives are sampled, so evaluation always
      scores the same fixed edge set, in a fixed order.

    Args:
        data: A split graph carrying ``edge_label_index`` (and optionally
            ``edge_label``) on the ('artist', 'performs', 'performance') edges.

    Returns:
        A LinkNeighborLoader sampling 2-hop neighborhoods (15 neighbors per hop)
        around batches of 128 supervision edges.
    """
    edge_type = ('artist', 'performs', 'performance')
    edge_store = data[edge_type]
    edge_label = getattr(edge_store, 'edge_label', None)
    # A zero label means RandomLinkSplit already placed negatives in this split.
    has_negatives = edge_label is not None and bool((edge_label == 0).any())

    if has_negatives:
        # Keep the split's labels; sampling more negatives would mislabel them.
        label_kwargs = {'edge_label': edge_label, 'neg_sampling_ratio': None}
    else:
        # Positives only: with edge_label=None the loader builds binary labels
        # (1 = positive, 0 = sampled negative).
        label_kwargs = {'edge_label': None, 'neg_sampling_ratio': 2.0}

    edge_loader = LinkNeighborLoader(
        data=data,
        num_neighbors=[15, 15],
        edge_label_index=(edge_type, edge_store.edge_label_index),
        batch_size=128,
        shuffle=not has_negatives,
        **label_kwargs,
    )
    return edge_loader

# Build edge loaders for the train and dev split graphs (test_data is not
# loaded in this notebook section).
edge_loader_train = edge_training_data_factory(train_data)
edge_loader_dev = edge_training_data_factory(dev_data)

In [None]:
# Encoder hyperparameters for the link-prediction run (dropout disabled).
model_config = {
    'hidden_dim': 128,
    'embed_dim': 64,
    'dropout': 0.
}

# Link-prediction head on a fresh JazzModel encoder. Node counts and type
# metadata come from the full graph; RandomLinkSplit keeps the same node and
# edge types in each split.
model = LinkPredictionModel(JazzModel(
    data['performance'].num_nodes,
    data['artist'].num_nodes,
    data['song'].num_nodes,
    hidden_dim=model_config['hidden_dim'],
    embed_dim=model_config['embed_dim'],
    dropout=model_config['dropout'],
    metadata=data.metadata()
))

In [None]:
# Sanity-check one training batch. The BCEWithLogitsLoss criterion used for
# training compares logits and labels element-wise, so the model must emit
# exactly one logit per supervision edge. The forward pass is run without
# building an autograd graph, since nothing is trained here.
batch = next(iter(edge_loader_train))
with torch.no_grad():
    edge_logits = model(batch.x_dict, batch.edge_index_dict, batch['performs'].edge_label_index)
assert edge_logits.shape == batch['performs'].edge_label.shape, (
    f"logits {tuple(edge_logits.shape)} vs labels {tuple(batch['performs'].edge_label.shape)}"
)
edge_logits.shape

In [None]:
class GNNTrainingLogic:
    """Training and evaluation steps for the 'performs' link-prediction model.

    Both steps follow the pytorch-ignite process-function signature
    ``(engine, batch) -> dict`` and report logits (``y_pred``) and float 0/1
    targets (``y_true``) for the batch's supervision edges.
    """
    def __init__(self, model, optimizer, criterion):
        # Batches are moved to whichever device the model's parameters are on.
        self.device = next(model.parameters()).device
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion

    def _extract_model_args(self, batch):
        """Return the model inputs: node ids, message-passing edges, edges to score."""
        return batch.x_dict, batch.edge_index_dict, batch['performs'].edge_label_index

    def _targets(self, batch):
        """Return the supervision labels as float.

        BCE-style criteria (this notebook uses nn.BCEWithLogitsLoss) require
        float targets. The cast is a no-op for float labels and makes integer
        labels work too.
        """
        return batch['performs'].edge_label.to(torch.float)

    def train_step(self, engine, batch: "HeteroData") -> dict:
        """Complete one step of gradient descent.

        Returns:
            dict with the scalar ``loss`` and detached ``y_pred`` / ``y_true``.
        """
        self.model.train()
        self.optimizer.zero_grad()
        batch = batch.to(self.device)

        y_pred = self.model(*self._extract_model_args(batch))
        y_true = self._targets(batch)
        loss = self.criterion(y_pred, y_true)
        loss.backward()
        self.optimizer.step()
        return {'loss': loss.item(), 'y_pred': y_pred.detach(), 'y_true': y_true.detach()}

    def eval_step(self, engine, batch: "HeteroData") -> dict:
        """Score a batch with gradients disabled and return logits and targets.

        No loss is returned; an attached ignite ``Loss`` metric can compute it
        from ``y_pred`` / ``y_true``.
        """
        self.model.eval()
        batch = batch.to(self.device)
        with torch.no_grad():
            y_pred = self.model(*self._extract_model_args(batch))
            y_true = self._targets(batch)
        return {'y_pred': y_pred, 'y_true': y_true}

experiment_config = {
    'model': model_config,
    'lr': .001,
    'batch_size': edge_loader_train.batch_size,
    'epochs': 15
}


criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=experiment_config['lr'])
trainer_logic = GNNTrainingLogic(model, optimizer, criterion)
# Smoke-test the forward pass with eval_step. Calling train_step here would take
# a real, unlogged optimizer step (updating the weights and Adam's state) before
# the recorded run begins.
assert trainer_logic.eval_step(None, batch) is not None, "Forward pass failed on a sample batch."


In [None]:
# Separate run name so link-prediction results are logged apart from the
# classifier run.
experiment_logger = JSONRunLogger(run_name='gnn_slim_data', config=experiment_config)

# NOTE(review): `accuracy` is never attached to an engine; the metrics dict
# below creates its own Accuracy instance.
accuracy = Accuracy(output_transform=binary_output_transform)

# One engine trains; two separate evaluator engines re-score the train and dev
# edge loaders so their metric states stay independent.
trainer = Engine(trainer_logic.train_step)
train_evaluator = Engine(trainer_logic.eval_step)
dev_evaluator = Engine(trainer_logic.eval_step)

# binary_output_transform presumably turns logits into 0/1 predictions for the
# classification metrics -- see jazz_graph.logging. Loss is recomputed from the
# evaluator outputs because eval_step does not return a loss.
metrics = {
    'accuracy': Accuracy(output_transform=binary_output_transform),
    'recall': Recall(output_transform=binary_output_transform),
    'precision': Precision(output_transform=binary_output_transform),
    'loss': Loss(criterion, output_transform=lambda out: (out['y_pred'], out['y_true']))
}

for name, metric in metrics.items():
    metric.attach(train_evaluator, name)
    metric.attach(dev_evaluator, name)

# After every training epoch, run both evaluators over their edge loaders; each
# evaluator then logs its metrics under 'train' / 'dev' via log_experiment.
trainer.add_event_handler(Events.EPOCH_COMPLETED, run_evaluator, train_evaluator, edge_loader_train, "Training")
trainer.add_event_handler(Events.EPOCH_COMPLETED, run_evaluator, dev_evaluator, edge_loader_dev, "Validation")
train_evaluator.add_event_handler(Events.EPOCH_COMPLETED, log_experiment, experiment_logger, 'train', trainer)
dev_evaluator.add_event_handler(Events.EPOCH_COMPLETED, log_experiment, experiment_logger, 'dev', trainer)


In [None]:
# Train the link predictor for the configured number of epochs.
trainer.run(edge_loader_train, max_epochs=experiment_config['epochs'])

In [None]:
# NOTE(review): if the previous run has already finished, ignite's Engine.run
# starts a fresh state here, so epochs are counted from 1 again. The model
# keeps its trained weights, so logged epoch numbers may repeat -- confirm intent.
trainer.run(edge_loader_train, 4)