# Train a model on the data

In [None]:

from collections import Counter
import jsonlines
import pandas as pd
import numpy as np
from numpy.typing import ArrayLike
import os

import torch
from torch import nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected
from torch_geometric.nn import GraphConv, SAGEConv, to_hetero, HeteroConv
from torch_geometric import transforms as T
from torch_geometric import seed_everything

from jazz_graph.data.utils import inspect_degrees
from jazz_graph.pyg_data.pyg_data import CreateTensors
from jazz_graph.model import JazzModel, LinkPredictionModel, NodeClassifier


In [None]:
# Path handed to CreateTensors; `create` is the object every tensor, edge list,
# label matrix and split mask below is requested from.
# NOTE(review): the variable is called `models_dir`, but the path itself reads
# like a graph/parquet data directory — confirm and consider renaming.
models_dir = '/workspace/local_data/graph_parquet_proto'
create = CreateTensors(models_dir)

In [None]:
# TODO: report on the data a little more concretely.
# E.g., who are the hub nodes? How many nodes have > 50 edges?
# How many nodes have < 6 edges? All of these, by type.
# Get really fancy and visualize a sub-graph.

def frequency_of_n_labels(data: HeteroData):
    """Print the share of performances that carry 0, 1, 2, ... labels.

    Reads the label matrix ``data['performance'].y`` (one row per performance),
    sums every row to get that row's label count, and prints one line for each
    count from 0 up to the largest count observed. Counts that never occur are
    printed with a frequency of 0.000. Nothing is printed for an empty matrix.
    """
    labels = data['performance'].y
    n_samples = labels.shape[0]
    if n_samples == 0:
        # No rows: nothing to report, and the frequency would divide by zero.
        return

    counter = Counter(int(x) for x in labels.sum(dim=1))
    # Walk up to the largest count actually seen. Iterating range(len(counter))
    # stops early -- silently dropping the larger counts -- whenever some
    # intermediate count never occurs (e.g. rows with 0 and 2 labels but none
    # with exactly 1).
    for i in range(max(counter) + 1):
        freq = counter[i] / n_samples
        print(f"Num samples with {i} labels: {freq:.3f}")

In [None]:
# Start from an empty heterogeneous graph; node features, edges, labels and
# split masks are attached to it further down in this cell.
data = HeteroData()

def index_tensor(tensor):
    """Build a column vector of row positions ``0, 1, ..., n - 1`` for ``tensor``.

    Only the size of the first dimension of ``tensor`` is used. The result has
    shape ``(n, 1)`` and dtype ``int64``, so that after graph sampling each
    node's feature row is simply its own node id.
    """
    n_rows = tensor.size(0)
    positions = torch.arange(n_rows, dtype=torch.int64)
    return positions.unsqueeze(-1)

# This is a little clunky. The nodes are not expected to provide
# substantial feature information--the information is the graph.
# Each node type's `x` is therefore just its own index column (see
# index_tensor); the tensors returned by `create` only supply the row count.
data['performance'].x = index_tensor(create.performances())
data['song'].x = index_tensor(create.songs())
data['artist'].x = index_tensor(create.artists())

# Three relations, keyed as (source type, relation name, destination type).
data['artist', 'performs', 'performance'].edge_index = create.artist_performance_edges()
data['performance', 'performing', 'song'].edge_index = create.performance_song_edges()
data['artist', 'composed', 'song'].edge_index = create.artist_song_edges()

# Supervision lives on the performance nodes only: labels plus the
# train/dev/test masks used to pick seed nodes for the loaders.
data['performance'].y = create.labels()
data['performance'].train_mask = create.train_mask()
data['performance'].dev_mask = create.dev_mask()
data['performance'].test_mask = create.test_mask()

# data['artist', 'performs', 'performance'].edge_attr = <instrument>
# Must run after all edges are attached: ToUndirected adds the reverse
# relations (the rev_performs / rev_performing / rev_composed edge types
# referenced later in this notebook).
data = ToUndirected()(data)


In [None]:
# Display the label names (presumably one per column of the label matrix —
# they are zipped against the per-column label sums further down).
create.label_names()

In [None]:
# Summary of the assembled graph: structure, isolated nodes, directedness.
print(data)
print(
    f"The graph contains {'' if data.has_isolated_nodes() else 'no '}isolated nodes and",
    f"is {'directed' if data.is_directed() else 'undirected'}."
)
# Distribution of labels-per-performance, then the share of performances
# carrying each individual label.
frequency_of_n_labels(data)
for style, count in (zip(create.label_names(), data['performance'].y.sum(dim=0))):
    # NOTE(review): the denominator reads the private `create._labels`;
    # presumably it is the same tensor as data['performance'].y — confirm and
    # prefer the public tensor.
    print(f"  {style}: {int(count) / create._labels.shape[0]:.1%}")
    # Easy Listening is probably a mislabel by modern standards.


In [None]:
# Per-relation degree summary; the table pasted below this cell is its output.
inspect_degrees(data)

Unnamed: 0,performs,performing,composed,rev_performs,rev_performing,rev_composed
count,2583.0,10886.0,2583.0,10886.0,2129.0,2129.0
mean,12.656601,0.402535,1.696477,3.003123,2.058243,2.058243
std,22.430807,0.634782,7.352685,4.485468,2.449948,2.449948
min,0.0,0.0,0.0,0.0,1.0,1.0
25%,0.0,0.0,0.0,0.0,1.0,1.0
50%,6.0,0.0,0.0,1.0,1.0,1.0
75%,13.0,1.0,1.0,5.0,2.0,2.0
max,326.0,8.0,153.0,51.0,30.0,30.0


## Model

In [None]:
# Node classifier = JazzModel backbone wrapped in a classification head.
# The backbone receives the node counts of the three node types (presumably to
# size per-node embedding tables, since node features are only indices —
# confirm in jazz_graph.model) plus the graph metadata.
model = NodeClassifier(
    JazzModel(
        data['performance'].num_nodes,
        data['artist'].num_nodes,
        data['song'].num_nodes,
        hidden_dim=128,
        embed_dim=64,
        metadata=data.metadata()
    ),
    hidden_dim=128,
    # One output per label name.
    num_classes=len(create.label_names())
)

# Notebook display only: the first expression's value is discarded, the last
# one (artist node count) is what the cell shows.
data['performance'].num_nodes
data['artist'].num_nodes


## Train

In [None]:
from torch_geometric.loader import NeighborLoader

def train_indicies(mask):
    """Return the node positions selected by ``mask``.

    ``mask`` indexes the sequence ``0 .. len(mask) - 1``, so a boolean mask
    yields the positions where it is True, as an ``int64`` tensor.
    """
    positions = torch.arange(mask.shape[0])
    return positions[mask]

# Mini-batch loaders over 'performance' seed nodes. Each batch samples up to 15
# neighbours per relation for three hops around 128 seed nodes.
train_loader = NeighborLoader(
    data,
    num_neighbors=[15, 15, 15],
    batch_size=128,
    input_nodes=('performance', train_indicies(data['performance'].train_mask)),
    # Visit the training seeds in a fresh random order every epoch; without
    # this every epoch replays the exact same batches in the same order.
    shuffle=True,
)
# Evaluation order does not matter, so the dev loader stays unshuffled.
dev_loader = NeighborLoader(
    data,
    num_neighbors=[15, 15, 15],
    batch_size=128,
    input_nodes=('performance', train_indicies(data['performance'].dev_mask)),
)

### Metrics for training.

In [None]:
# TODO: set up some kind of proper training logs.
# You know yourself and you WILL end up running dozens
# of runs with small configuration tweaks. Without logs
# you have no way to recover what you did.


# I pulled this from some other course work that I did.
class ConfusionMatrix:
    """Accumulate one-versus-rest confusion counts for multi-label predictions.

    Predictions are binarised with ``pred > threshold`` and compared against
    0/1 labels; counts are kept per label (per column) and summed over every
    ``update`` call until ``reset``.

    Args:
        threshold: cut-off applied to ``pred``. The default of 0.5 suits
            probabilities; pass 0 when feeding raw logits.
    """
    def __init__(self, threshold: float = .5):
        self.threshold = threshold
        # set TP, TN, FP, FN
        self.reset()

    def update(self, pred: torch.Tensor, labels: torch.Tensor):
        """Add one batch of predictions; ``pred`` and ``labels`` share a shape."""
        predicted = pred > self.threshold
        positive = labels == 1
        negative = labels == 0
        # A false positive is a positive *prediction* on a negative label and a
        # false negative is a negative prediction on a positive label. The two
        # were previously exchanged, which swapped reported precision and recall.
        self.true_positives = (predicted & positive).sum(dim=0) + self.true_positives
        self.false_positives = (predicted & negative).sum(dim=0) + self.false_positives
        self.true_negatives = (~predicted & negative).sum(dim=0) + self.true_negatives
        self.false_negatives = (~predicted & positive).sum(dim=0) + self.false_negatives

    def compute(self) -> np.ndarray:
        """Return the accumulated counts as a float array of shape (2, 2, n_labels).

        The first axis is the predicted class and the second the true class:
        ``[0, 0]`` true negatives, ``[0, 1]`` false negatives, ``[1, 0]`` false
        positives, ``[1, 1]`` true positives. Before any ``update`` the result
        is all zeros with a single label slot.
        """
        cells = [
            [self.true_negatives, self.false_negatives],
            [self.false_positives, self.true_positives],
        ]
        # atleast_1d keeps the freshly reset scalar counters (and 1-D inputs)
        # from breaking the (2, 2, n_labels) layout.
        return np.array(
            [[np.atleast_1d(cell.cpu().numpy()) for cell in row] for row in cells],
            dtype=float,
        )

    def reset(self):
        """Reset the matrix to zeros, for example, after each complete epoch."""
        self.true_positives = torch.tensor(0)
        self.false_positives = torch.tensor(0)
        self.true_negatives = torch.tensor(0)
        self.false_negatives = torch.tensor(0)


def _confusion_counts(confusion: ConfusionMatrix):
    """Return per-label (TP, FP, TN, FN) arrays from ``confusion.compute()``."""
    matrix = confusion.compute()
    return matrix[1, 1], matrix[1, 0], matrix[0, 0], matrix[0, 1]


def _ratio(numer, denom):
    """Divide element-wise, yielding 0 (not NaN or a warning) where ``denom`` is 0."""
    numer = np.asarray(numer, dtype=float)
    denom = np.asarray(denom, dtype=float)
    empty = denom == 0
    return np.where(empty, 0.0, numer / np.where(empty, 1.0, denom))


def per_label_accuracy(confusion: ConfusionMatrix):
    """Return the accuracy of each label as an array of length n_labels."""
    true_positives, false_positives, true_negatives, false_negatives = _confusion_counts(confusion)
    correct = true_positives + true_negatives
    return _ratio(correct, correct + false_positives + false_negatives)


def per_label_precision(confusion: ConfusionMatrix):
    """Return ``(micro, macro)`` precision over all labels.

    Micro pools the counts of every label before dividing; macro averages the
    per-label precisions. A label (or a whole matrix) without any positive
    prediction contributes a precision of 0.
    """
    true_positives, false_positives, _, _ = _confusion_counts(confusion)
    macro = _ratio(true_positives, true_positives + false_positives).mean()
    micro = float(_ratio(true_positives.sum(), true_positives.sum() + false_positives.sum()))
    return micro, macro


def per_label_recall(confusion: ConfusionMatrix):
    """Return ``(micro, macro)`` recall over all labels.

    Micro pools the counts of every label before dividing; macro averages the
    per-label recalls. A label (or a whole matrix) without any positive example
    contributes a recall of 0.
    """
    true_positives, _, _, false_negatives = _confusion_counts(confusion)
    macro = _ratio(true_positives, true_positives + false_negatives).mean()
    micro = float(_ratio(true_positives.sum(), true_positives.sum() + false_negatives.sum()))
    return micro, macro


def format_float_arr(arr):
    """Render an iterable of numbers as a comma-separated string, 3 decimals each."""
    return ', '.join(f'{x:.3f}' for x in arr)


def batch_report(confusion: ConfusionMatrix, batch: int):
    """Print per-label accuracies for a finished batch, then reset ``confusion``."""
    print(f"Finished batch {batch}.")
    print(format_float_arr(per_label_accuracy(confusion)))
    confusion.reset()


def epoch_report(confusion: ConfusionMatrix, epoch: int):
    """Print the metrics accumulated over an epoch, then reset ``confusion``.

    Accuracies are per label; the recall and precision lines each show the
    (micro, macro) pair.
    """
    print(f"Finished epoch {epoch}.")
    print("  Accuracies: ", format_float_arr(per_label_accuracy(confusion)))
    print("  Recalls: ", format_float_arr(per_label_recall(confusion)))
    print("  Precisions: ", format_float_arr(per_label_precision(confusion)))
    confusion.reset()


class WeightedF1Score:
    """Support-weighted F1 score accumulated through a ConfusionMatrix."""
    def __init__(self):
        self.confusion = ConfusionMatrix()
        self.name = 'weighted_f1_score'

    def update(self, loss, pred, labels):
        """Accumulate one batch. ``loss`` is accepted but not used by this metric."""
        # ConfusionMatrix.update takes (pred, labels) only; forwarding `loss`
        # as well raised a TypeError on every call.
        self.confusion.update(pred, labels)

    def reset(self):
        """Clear the accumulated counts."""
        self.confusion.reset()

    def compute(self) -> float:
        """Return the per-label F1 averaged with weights ``TP + FN`` (the support).

        Labels with no positives and no positive predictions get an F1 of 0,
        and the score is 0.0 when there are no positive examples at all.
        """
        true_positives, false_positives, _, false_negatives = _confusion_counts(self.confusion)
        f1 = _ratio(2 * true_positives, 2 * true_positives + false_positives + false_negatives)
        weight = true_positives + false_negatives
        total = weight.sum()
        if total == 0:
            return 0.0
        return float(np.sum(f1 * weight) / total)


# Self-test for ConfusionMatrix. It used to sit in a notebook cell of its own;
# it now runs together with the definitions above, so the class and the
# expectations that pin it cannot drift apart.
def test_confusion():
    """Check every counter against a small hand-worked example."""
    confusion = ConfusionMatrix()
    predicitions = torch.tensor([
        [.1, .8, .2], [.2, .7, .1], [.9, .6, .3]
    ])
    labels = torch.tensor([
        [0, 1, 1], [0, 1, 0], [0, 1, 0]
    ])
    confusion.update(predicitions, labels)
    # Label 0: two correct rejections, one false alarm (.9 on a 0 label).
    # Label 1: three hits.
    # Label 2: one miss (.2 on a 1 label), two correct rejections.
    np.testing.assert_array_equal(confusion.true_negatives.numpy(), np.array([2, 0, 2]))
    np.testing.assert_array_equal(confusion.true_positives.numpy(), np.array([0, 3, 0]))
    np.testing.assert_array_equal(confusion.false_negatives.numpy(), np.array([0, 0, 1]))
    np.testing.assert_array_equal(confusion.false_positives.numpy(), np.array([1, 0, 0]))

    matrix = confusion.compute()
    np.testing.assert_array_equal(matrix[0, 0], confusion.true_negatives.numpy())
    np.testing.assert_array_equal(matrix[0, 1], confusion.false_negatives.numpy())
    np.testing.assert_array_equal(matrix[1, 0], confusion.false_positives.numpy())
    np.testing.assert_array_equal(matrix[1, 1], confusion.true_positives.numpy())

test_confusion()

### Train function

In [None]:
from collections import defaultdict
def _run_epoch(model, loader, criterion, confusion, optimizer=None):
    """Make one pass over ``loader`` and return the sample-weighted mean loss.

    With an ``optimizer`` the model is put in train mode and stepped after every
    batch; without one it is put in eval mode and gradients are disabled.
    Sigmoid probabilities for each batch are added to ``confusion``. Returns NaN
    when the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    n_samples = 0
    for batch in loader:
        # Only the first `batch_size` performance rows are scored (the seed
        # nodes in NeighborLoader's layout); the remaining rows are sampled
        # neighbours that merely provide context.
        batch_size = batch['performance'].batch_size
        y = batch['performance'].y[:batch_size]
        with torch.set_grad_enabled(training):
            y_hat = model(batch.x_dict, batch.edge_index_dict)[:batch_size]
            loss = criterion(y_hat, y.to(torch.float))
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total_loss += loss.item() * batch_size
        n_samples += batch_size
        # The model emits logits (the loss is BCEWithLogitsLoss) while
        # ConfusionMatrix cuts at 0.5 by default, so convert to probabilities.
        confusion.update(torch.sigmoid(y_hat.detach()), y)
    return total_loss / n_samples if n_samples else float('nan')


def train(model: NodeClassifier, loader: NeighborLoader, dev_loader: NeighborLoader, epochs: int = 1):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``dev_loader`` after each.

    Uses Adam (lr 0.001) with a binary cross-entropy-with-logits loss and prints
    an epoch report for the training pass and for the dev pass.

    Returns:
        A dict of per-epoch mean losses: ``'train'`` and ``'val_loss'``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=.001)
    criterion = nn.BCEWithLogitsLoss()
    # One accumulator is enough: epoch_report resets it after printing.
    confusion = ConfusionMatrix()
    losses = defaultdict(list)
    for epoch in range(epochs):
        losses['train'].append(_run_epoch(model, loader, criterion, confusion, optimizer))
        epoch_report(confusion, epoch)

        # The dev loss is accumulated over the whole loader; resetting the
        # running totals inside the batch loop left only the last batch's loss.
        losses['val_loss'].append(_run_epoch(model, dev_loader, criterion, confusion))
        print("Dev set results.")
        epoch_report(confusion, epoch)
    return losses

# 15 epochs; `losses` holds the per-epoch loss lists plotted in the next cell.
losses = train(model, train_loader, dev_loader, 15)


In [None]:
import matplotlib.pyplot as plt

# One curve per entry of `losses`, indexed by epoch.
fig, ax = plt.subplots(1, 1)
for key, loss in losses.items():
    ax.plot(loss, label=key)
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
# The `label=` given to each curve is only rendered once a legend is drawn.
ax.legend()


In [None]:
# Collect predictions for every dev-set performance and inspect them.
confusion = ConfusionMatrix()
all_preds = []
model.eval()
for batch in dev_loader:
    # Score the seed nodes only, exactly as train() does. Without the slice the
    # sampled neighbour nodes (which can come from any split and repeat across
    # batches) are counted as if they were dev examples.
    batch_size = batch['performance'].batch_size
    with torch.no_grad():
        logits = model(batch.x_dict, batch.edge_index_dict)[:batch_size]
        y_hat = torch.sigmoid(logits)
    y = batch['performance'].y[:batch_size]
    confusion.update(y_hat, y)
    all_preds.append(y_hat)

# Rows: dev performances in loader order; columns: labels.
probs = torch.cat(all_preds).cpu().numpy()
selections = probs > .5

In [None]:
# Show the [1, 0] slot of the matrix — the one compute() fills from the
# `false_positives` counter — one value per label.
confusion.compute()[1, 0]

In [None]:
# Sanity checks: are we making positive predictions? is the number of positive per sample close to 1?
# `selections` is the boolean matrix `probs > .5`: summing along axis 1 counts
# positive labels per row, summing along axis 0 counts rows per label.
print(f"A total of {(selections.sum(axis=1) > 0).sum()} of {selections.shape[0]} samples received some classification.")
print(pd.Series(selections.sum(axis=1)).value_counts())
# Last expression of the cell: displays how often each label was predicted.
selections.sum(axis=0)

## Edge Prediction

Quick shot at writing an edge prediction model.
Conceptually, this is a recommender system built on predicting "who worked with whom, on what?"
Thus, it's a lot like asking Thelonious Monk, "What would you recommend?"

In [None]:
# Fix the RNG seed so the random edge split below is reproducible.
seed_everything(42)
performs_edge_count = data[('artist', 'performs', 'performance')].num_edges

# Split the artist->performance edges for link prediction: 10% of them for
# validation and 10% for test, given as absolute edge counts.
split_graph = T.RandomLinkSplit(
    num_val=int(performs_edge_count * .1),
    num_test=int(performs_edge_count * .1),
    # 30% of the training edges are set aside as supervision targets.
    disjoint_train_ratio=.3,
    neg_sampling_ratio=2.0,
    # No negatives are stored for the training split; the training loader
    # defined later samples them per batch instead.
    add_negative_train_samples=False,
    edge_types=('artist', 'performs', 'performance'),
    # The reverse relation added by ToUndirected must be split consistently.
    rev_edge_types=('performance', 'rev_performs', 'artist')
)
train_data, dev_data, test_data = split_graph(data)

In [None]:
from torch_geometric.loader import LinkNeighborLoader

def edge_training_data_factory(data: HeteroData) -> LinkNeighborLoader:
    """Build a shuffled link loader over the 'performs' supervision edges of ``data``.

    Args:
        data: one split produced by the link split (train, dev or test). Both
            the message-passing graph and the seed edges are taken from it.

    Returns:
        A LinkNeighborLoader yielding batches of 128 seed edges, with two hops
        of up to 15 sampled neighbours and two sampled negatives per positive.
    """
    edge_type = ('artist', 'performs', 'performance')
    edge_loader = LinkNeighborLoader(
        data=data,
        num_neighbors=[15, 15],
        neg_sampling_ratio=2.0,
        # Seed edges come from the split that was passed in. Reading the global
        # `train_data` here made every loader -- including the dev loader --
        # iterate over the training edges.
        edge_label_index=(edge_type, data[edge_type].edge_label_index),
        # NOTE(review): the split's own `edge_label` is not forwarded, so every
        # seed edge is treated as a positive. That matches the training split
        # (it stores no negatives); confirm it is intended for dev/test splits.
        edge_label=None,
        batch_size=128,
        shuffle=True
    )
    return edge_loader

# One loader per split; both are built by the same factory and therefore share
# its batch size, fan-out and negative-sampling settings.
edge_loader_train = edge_training_data_factory(train_data)
edge_loader_dev = edge_training_data_factory(dev_data)

In [None]:
# Rebinds `model`: from here on it is the link predictor, not the node
# classifier trained above. The backbone is a fresh JazzModel with the same
# sizes, so nothing learned by the classifier is reused.
model = LinkPredictionModel(JazzModel(
    data['performance'].num_nodes,
    data['artist'].num_nodes,
    data['song'].num_nodes,
    hidden_dim=128,
    embed_dim=64,
    metadata=data.metadata()
))

In [None]:
# Scratch cell: pull one training batch and poke at it. Only the final
# expression (`batch`) is displayed; the lines above it are evaluated and
# discarded, though the model call does confirm a forward pass runs.
# `batch` is reused by the smoke test in the next cell.
batch = next(iter(edge_loader_train))
# NOTE(review): the next two lines are identical — one can go.
batch['performs'].edge_label
batch['performs'].edge_label
batch['performs'].edge_label_index.shape
model(batch.x_dict, batch.edge_index_dict, batch['performs'].edge_label_index).shape
batch

In [None]:
class GNNTrainingLogic:
    """Training and evaluation steps for the link model.

    Both step methods take ``(engine, batch)`` and ignore ``engine``, which lets
    them be used as process functions of an engine or be called directly.

    Args:
        model: module called as ``model(x_dict, edge_index_dict, edge_label_index)``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(y_pred, y_true)``.
    """
    def __init__(self, model, optimizer, criterion):
        # Batches are moved to wherever the model's parameters live at
        # construction time.
        self.device = next(model.parameters()).device
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion

    def _extract_model_args(self, batch):
        """Return the positional arguments the model expects, taken from ``batch``."""
        return batch.x_dict, batch.edge_index_dict, batch['performs'].edge_label_index

    def train_step(self, engine, batch: HeteroData) -> dict:
        """Complete one step of gradient descent.

        Returns:
            ``{'loss': float, 'y_pred': detached predictions, 'y_true': detached labels}``.
        """
        self.model.train()
        self.optimizer.zero_grad()
        # Keep what `.to` returns instead of relying on an in-place move: a
        # batch type whose `.to` returns a copy would otherwise stay on its
        # original device.
        batch = batch.to(self.device)

        y_pred = self.model(*self._extract_model_args(batch))
        y_true = batch['performs'].edge_label
        loss = self.criterion(y_pred, y_true)
        loss.backward()
        self.optimizer.step()
        return {'loss': loss.item(), 'y_pred': y_pred.detach(), 'y_true': y_true.detach()}

    def eval_step(self, engine, batch: HeteroData) -> dict:
        """Complete one pass over a batch of data with no-grad and return results.

        Returns:
            ``{'y_pred': predictions, 'y_true': labels}``.
        """
        self.model.eval()
        batch = batch.to(self.device)
        with torch.no_grad():
            y_pred = self.model(*self._extract_model_args(batch))
            y_true = batch['performs'].edge_label
        return {'y_pred': y_pred, 'y_true': y_true}


# The model outputs logits, hence the with-logits loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=.001)
trainer_logic = GNNTrainingLogic(model, optimizer, criterion)
# Smoke test on the batch pulled in the previous cell. Note that this is a real
# optimisation step (train_step calls backward() and optimizer.step()), so the
# model's weights have already moved once before the training run starts.
trainer_logic.train_step(None, batch) is not None


In [None]:
from ignite.engine import Engine, Events
from ignite.metrics import Recall, Precision, Accuracy, Loss

def log_training_results(trainer, evaluator, loader, step_name):
    """Evaluate on ``loader`` and print every metric the evaluator produced.

    Runs ``evaluator`` over ``loader``, then prints a header with ``step_name``
    and the trainer's current epoch followed by one line per metric.
    """
    evaluator.run(loader)
    epoch = trainer.state.epoch
    print(f"{step_name} - Epoch[{epoch:03}]")
    results = evaluator.state.metrics
    for metric in results:
        value = results[metric]
        print(f"  Avg. {metric}: {value:.3f}")

def binary_output_transform(output: dict[str, torch.Tensor]) -> tuple:
    """Turn a step's output dict into ``(y_pred, y_true)`` with hard 0/1 predictions.

    ``output["y_pred"]`` is cut at 0 (strictly greater than 0 becomes 1) and
    returned as an ``int64`` tensor; ``output["y_true"]`` is passed through
    untouched.
    """
    scores, targets = output["y_pred"], output["y_true"]
    decisions = scores.gt(0).to(torch.long)
    return decisions, targets

# NOTE(review): `accuracy` is not referenced again in this section; the metrics
# dict below builds its own Accuracy instance.
accuracy = Accuracy(output_transform=binary_output_transform)

# One engine that trains, and two that only evaluate (same eval step, kept
# separate so training-set and dev-set metrics live in separate engine states).
trainer = Engine(trainer_logic.train_step)
train_evaluator = Engine(trainer_logic.eval_step)
dev_evaluator = Engine(trainer_logic.eval_step)

# The classification metrics see hard 0/1 predictions via
# binary_output_transform; the loss is computed on the raw model outputs.
metrics = {
    'accuracy': Accuracy(output_transform=binary_output_transform),
    'recall': Recall(output_transform=binary_output_transform),
    'precision': Precision(output_transform=binary_output_transform),
    'loss': Loss(criterion, output_transform=lambda out: (out['y_pred'], out['y_true']))
}

# The same metric objects are attached to both evaluators.
for name, metric in metrics.items():
    metric.attach(train_evaluator, name)
    metric.attach(dev_evaluator, name)

# After every training epoch: evaluate on the training loader, then on the dev
# loader, printing the metrics each time.
trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_results, train_evaluator, edge_loader_train, "Training")
trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_results, dev_evaluator, edge_loader_dev, "Validation")


In [None]:
# Train on the edge loader; 4 is passed as Engine.run's second positional
# argument (max_epochs in ignite). The handlers registered above print
# training and validation metrics after each epoch.
trainer.run(edge_loader_train, 4)