# Example notebook using Contrastive Learning for image classification

## Imports

In [None]:
import znnl as znnl
import jax.numpy as np
import numpy as onp
from flax import linen as nn
import optax
from neural_tangents import stax
import matplotlib.pyplot as plt
import pandas as pd

from papyrus.measurements import (
    Loss, Accuracy, NTKTrace, NTKEntropy, NTK, NTKSelfEntropy, NTKEigenvalues
)

## Data

In [None]:
# Create the MNIST data generator. The argument 100 is presumably the number
# of samples to load -- TODO confirm against znnl.data.MNISTGenerator.
data_generator = znnl.data.MNISTGenerator(100)
# Shape of a one-sample slice of the training inputs (leading axis has size 1);
# used as `input_shape` when the models are built.
input_shape = data_generator.train_ds['inputs'][:1, ...].shape

## Isolated Potential Loss

### Create Model

In [None]:
class Architecture(nn.Module):
    """
    Fully connected Flax network with a two-feature output.

    Each sample of the input batch is flattened, passed through two
    Dense(64) + ReLU stages and finally projected onto 2 output features.
    """

    @nn.compact
    def __call__(self, x):
        # Collapse every axis except the leading (batch) axis.
        hidden = x.reshape((x.shape[0], -1))

        # Two identical hidden stages: Dense(64) followed by ReLU.
        for _ in range(2):
            hidden = nn.relu(nn.Dense(features=64)(hidden))

        # Linear read-out into the two-dimensional output space.
        return nn.Dense(2)(hidden)
    
# Wrap the Flax module defined above in a ZnNL FlaxModel, using an Adam
# optimizer (learning rate 0.005), the one-sample input shape computed from
# the training data, and a fixed seed.
model = znnl.models.FlaxModel(
    flax_module=Architecture(),
    optimizer=optax.adam(learning_rate=0.005),
    input_shape=input_shape,
    seed=0,
)

### Set up Recorders

Here, four recorders are initialized: \
one for each potential and one to track the NTK and the corresponding observables.

In [None]:
# Contrastive loss built from three potentials: an attractive one
# (MeanPowerLoss with order=2), a repulsive one (ExponentialRepulsionLoss)
# and an external one (ExternalPotential). All three `turn_off_*` flags are
# False, i.e. no potential is switched off.
loss_fn = znnl.loss_functions.ContrastiveIsolatedPotentialLoss(
        attractive_pot_fn=znnl.loss_functions.MeanPowerLoss(order=2), 
        repulsive_pot_fn=znnl.loss_functions.ExponentialRepulsionLoss(), 
        external_pot_fn=znnl.loss_functions.ExternalPotential(), 
        turn_off_attractive_potential=False,
        turn_off_repulsive_potential=False,
        turn_off_external_potential=False,
    )

def attractive_loss(point1, point2):
    """Return entry 0 of ``loss_fn.compute_losses(point1, point2)``.

    Presumably the attractive contribution (assumes the entries follow the
    attractive / repulsive / external order) -- TODO confirm.
    """
    contributions = loss_fn.compute_losses(point1, point2)
    return contributions[0]


def repulsive_loss(point1, point2):
    """Return entry 1 of ``loss_fn.compute_losses(point1, point2)``.

    Presumably the repulsive contribution (assumes the entries follow the
    attractive / repulsive / external order) -- TODO confirm.
    """
    contributions = loss_fn.compute_losses(point1, point2)
    return contributions[1]


def external_loss(point1, point2):
    """Return entry 2 of ``loss_fn.compute_losses(point1, point2)``.

    Presumably the external contribution (assumes the entries follow the
    attractive / repulsive / external order) -- TODO confirm.
    """
    contributions = loss_fn.compute_losses(point1, point2)
    return contributions[2]

In [None]:
def _make_loss_recorder(recorder_name, loss_apply_fn):
    """Build a JaxRecorder for a single loss and instantiate it.

    The recorder measures ``Loss(apply_fn=loss_apply_fn)`` with an update
    rate of 1, stores under the current directory, and is instantiated on
    the training data set with the current ``model``.
    """
    recorder = znnl.training_recording.JaxRecorder(
        name=recorder_name,
        storage_path='.',
        measurements=[Loss(apply_fn=loss_apply_fn)],
        update_rate=1,
    )
    recorder.instantiate_recorder(
        data_set=data_generator.train_ds,
        model=model,
    )
    return recorder


# One recorder per potential contribution of the loss.
attractive_recorder = _make_loss_recorder("attractive_recorder", attractive_loss)
repulsive_recorder = _make_loss_recorder("repulsive_recorder", repulsive_loss)
external_recorder = _make_loss_recorder("external_recorder", external_loss)

# Recorder for NTK-based observables (trace and entropy of the NTK).
ntk_recorder = znnl.training_recording.JaxRecorder(
    name="ntk_recorder",
    storage_path='.',
    measurements=[NTKTrace(), NTKEntropy()],
    update_rate=1,
    chunk_size=1e10,  # Big chunk size to prevent saving the recordings.
)
# NTK computation based on the model's NTK apply function, using batch_size=10.
ntk_computation = znnl.analysis.JAXNTKComputation(
    apply_fn=model.ntk_apply_fn,
    batch_size=10,
)
# Attach training data, the NTK computation and the model to the recorder.
ntk_recorder.instantiate_recorder(
    data_set=data_generator.train_ds,
    ntk_computation=ntk_computation,
    model=model,
)

# Recorders handed to the trainer. Only the three loss recorders are listed;
# the NTK recorder entry is commented out, so it is not passed on.
recorders = [
    attractive_recorder, 
    repulsive_recorder, 
    external_recorder, 
    # ntk_recorder
]

### Initialize the Trainer

In [None]:
# Trainer combining the model, the isolated-potential contrastive loss and
# the recorder list, with a fixed seed.
trainer = znnl.training_strategies.SimpleTraining(
    model=model,
    loss_fn=loss_fn,
    recorders=recorders, 
    seed=0,
)

In [None]:
# Show the three potential on/off flags of the loss attached to the trainer.
(
    trainer.loss_fn.turn_off_attractive_potential,
    trainer.loss_fn.turn_off_repulsive_potential,
    trainer.loss_fn.turn_off_external_potential,
)

### Execute Training

In [None]:
# Train on the generator's train split (the test split is passed as well)
# for 100 epochs with a batch size of 100. The return value of train_model
# is kept in `batched_loss`.
batched_loss = trainer.train_model(
    train_ds=data_generator.train_ds, 
    test_ds=data_generator.test_ds, 
    epochs=100, 
    batch_size=100, 
)

### Evaluate Training

In [None]:
# Collect what each loss recorder has recorded; the plots index the returned
# objects with the key 'loss'.
attractive_results = attractive_recorder.gather()
repulsive_results = repulsive_recorder.gather()
external_results = external_recorder.gather()

# NOTE(review): `ntk_recorder` is commented out of the `recorders` list given
# to the trainer, so it was presumably never updated during training --
# confirm that gather() returns something usable in that case.
ntk_results = ntk_recorder.gather()

### Plot the losses over epochs

In [None]:
# Draw the three recorded loss curves in one log-scaled figure.
plt.figure(figsize=(8, 5))

loss_curves = (
    ("attractive", attractive_results),
    ("repulsive", repulsive_results),
    ("external", external_results),
)
for curve_label, curve_results in loss_curves:
    plt.plot(curve_results['loss'], label=curve_label)

# Axis decoration; the y-axis is logarithmic.
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.yscale('log')
plt.title('Loss contributions during training')
plt.legend()
plt.show()

### Plot the numbers in the output space

In [None]:
# Get labels as ints
labels = np.argmax(data_generator.train_ds['targets'], axis=1)

# Calculate representations of the final model; transposing the output lets
# the two output features be unpacked into separate coordinate arrays.
out_x, out_y = model(data_generator.train_ds["inputs"]).T

# Combine labels and representations into a dataframe. The arrays are
# converted to NumPy first so the columns are built from whole arrays rather
# than by iterating over the JAX arrays element by element.
df = pd.DataFrame(
    {
        "x": onp.asarray(out_x),
        "y": onp.asarray(out_y),
        "label": onp.asarray(labels),
    }
)

# Plot figure: one point per sample, coloured by its label. The legend is
# built from the colour mapping via legend_elements, so no `label=` is passed
# to scatter.
fig, ax = plt.subplots(figsize=(8, 5))
scatter = ax.scatter(df['x'], df['y'], c=df['label'], cmap='tab10')
legend1 = ax.legend(*scatter.legend_elements(num=10), loc="best", title="Number", framealpha=0)
ax.add_artist(legend1)

plt.xlabel('x')
plt.ylabel('y')
plt.title('Output Space')
plt.show()

## InfoNCE Loss

### Create Model

In [None]:
class Architecture(nn.Module):
    """
    Simple fully connected Flax model with L2-normalised outputs.

    Each sample is flattened, passed through two Dense(64) + ReLU stages,
    projected onto two features, and every output row is rescaled to unit
    L2 norm.
    """

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=64)(x)
        x = nn.relu(x)
        x = nn.Dense(features=64)(x)
        x = nn.relu(x)
        x = nn.Dense(2)(x)
        # Use L2 normalization to ensure that the output is on the unit sphere.
        # The norm is clamped from below so that an all-zero output row does
        # not trigger a division by zero (which would yield NaN values).
        norm = np.linalg.norm(x, axis=1, keepdims=True)
        x = x / np.maximum(norm, 1e-12)
        return x
    
# Wrap the normalised-output Flax module in a ZnNL FlaxModel, using an Adam
# optimizer (learning rate 0.001), the one-sample input shape computed from
# the training data, and a fixed seed. This rebinds the name `model`.
model = znnl.models.FlaxModel(
    flax_module=Architecture(),
    optimizer=optax.adam(learning_rate=0.001),
    input_shape=input_shape,
    seed=0,
)

### Set up Recorders

In [None]:
def _make_info_nce_recorder(recorder_name, data_set, **recorder_kwargs):
    """Build a JaxRecorder measuring the InfoNCE loss and instantiate it.

    The recorder measures ``ContrastiveInfoNCELoss(temperature=0.05)`` with
    an update rate of 1, stores under the current directory, and is
    instantiated on ``data_set`` with the current ``model``. Any extra
    keyword arguments are forwarded to the JaxRecorder constructor.
    """
    recorder = znnl.training_recording.JaxRecorder(
        name=recorder_name,
        storage_path='.',
        measurements=[
            Loss(apply_fn=znnl.loss_functions.ContrastiveInfoNCELoss(temperature=0.05))
        ],
        update_rate=1,
        **recorder_kwargs,
    )
    recorder.instantiate_recorder(data_set=data_set, model=model)
    return recorder


# Loss on the training split.
train_recorder = _make_info_nce_recorder("train_recorder", data_generator.train_ds)

# Loss on the test split.
test_recorder = _make_info_nce_recorder(
    "test_recorder",
    data_generator.test_ds,
    chunk_size=1e10,  # Big Chunk-size to prevent saving the recordings.
)

# Recorder for NTK-based observables (trace and entropy of the NTK).
ntk_recorder = znnl.training_recording.JaxRecorder(
    name="ntk_recorder",
    storage_path='.',
    measurements=[NTKTrace(), NTKEntropy()],
    update_rate=1,
    chunk_size=1e10,  # Big chunk size to prevent saving the recordings.
)
# NTK computation based on the model's NTK apply function, using batch_size=10.
ntk_computation = znnl.analysis.JAXNTKComputation(
    apply_fn=model.ntk_apply_fn,
    batch_size=10,
)
# Attach training data, the model and the NTK computation to the recorder.
ntk_recorder.instantiate_recorder(
    data_set=data_generator.train_ds,
    model=model,
    ntk_computation=ntk_computation,
)

# Recorders handed to the trainer. Only the train and test loss recorders are
# listed; the NTK recorder entry is commented out, so it is not passed on.
recorders = [
    train_recorder, 
    test_recorder, 
    # ntk_recorder
]

### Initialize the Trainer

In [None]:
# Trainer combining the model, an InfoNCE contrastive loss (temperature 0.05,
# the same value used by the loss recorders) and the recorder list, with a
# fixed seed.
trainer = znnl.training_strategies.SimpleTraining(
    model=model,
    loss_fn=znnl.loss_functions.ContrastiveInfoNCELoss(
        temperature=0.05
    ),
    recorders=recorders, 
    seed=0,
)

### Execute Training

In [None]:
# Train on the generator's train split (the test split is passed as well)
# for 50 epochs with a batch size of 20. The return value of train_model
# is kept in `batched_loss`.
batched_loss = trainer.train_model(
    train_ds=data_generator.train_ds, 
    test_ds=data_generator.test_ds, 
    epochs=50, 
    batch_size=20, 
)

### Evaluate Training

In [None]:
# Collect what the train and test recorders have recorded; the plot indexes
# the returned objects with the key 'loss'.
train_results = train_recorder.gather()
test_results = test_recorder.gather()

# NOTE(review): `ntk_recorder` is commented out of the `recorders` list given
# to the trainer, so it was presumably never updated during training --
# confirm that gather() returns something usable in that case.
ntk_results = ntk_recorder.gather()

### Plot the losses over epochs

In [None]:
# Draw the recorded train and test loss curves in one log-scaled figure.
plt.figure(figsize=(8, 5))

loss_curves = (
    ("train", train_results),
    ("test", test_results),
)
for curve_label, curve_results in loss_curves:
    plt.plot(curve_results['loss'], label=curve_label)

# Axis decoration; the y-axis is logarithmic.
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.yscale('log')
plt.title('Loss contributions during training')
plt.legend()
plt.show()

### Plot the numbers in the output space

In [None]:
# Get labels as ints
labels = np.argmax(data_generator.train_ds['targets'], axis=1)

# Calculate representations of the final model; transposing the output lets
# the two output features be unpacked into separate coordinate arrays.
out_x, out_y = model(data_generator.train_ds["inputs"]).T

# Combine labels and representations into a dataframe. The arrays are
# converted to NumPy first so the columns are built from whole arrays rather
# than by iterating over the JAX arrays element by element.
df = pd.DataFrame(
    {
        "x": onp.asarray(out_x),
        "y": onp.asarray(out_y),
        "label": onp.asarray(labels),
    }
)

# Plot figure: one point per sample, coloured by its label. The legend is
# built from the colour mapping via legend_elements, so no `label=` is passed
# to scatter.
fig, ax = plt.subplots(figsize=(8, 5))
scatter = ax.scatter(df['x'], df['y'], c=df['label'], cmap='tab10')
legend1 = ax.legend(*scatter.legend_elements(num=10), loc="best", title="Number", framealpha=0)
ax.add_artist(legend1)

plt.xlabel('x')
plt.ylabel('y')
plt.title('Output Space')
plt.show()