# Computing Collective Variables of the NTK for each Class

This notebook is a continuation of the notebook `Computing-NTK-Collective-Variables.ipynb`. 
In this notebook, we will compute the collective variables of the NTK for each class individually. 
This allows one to study the neural network training in a class-specific manner. 

Moreover, this notebook also illustrates how to compute the entropy contributions of each class, and of combinations of classes, via the NTK. 

These aspects will be shown using the MNIST dataset.


In [None]:
import znnl as nl

import tensorflow_datasets as tfds

import numpy as np
from flax import linen as nn
import optax

from jax.lib import xla_bridge

import matplotlib.pyplot as plt

# Report the platform of the backend JAX is currently using.
backend_platform = xla_bridge.get_backend().platform
print(f"Using: {backend_platform}")

### Define the Datasets

In [None]:
# Load MNIST through ZnNL. The argument 500 is forwarded unchanged to the
# generator (presumably the number of samples -- confirm against the ZnNL API).
data_generator = nl.data.MNISTGenerator(500)

# Classes that make up the NTK data set, the number of samples kept from them,
# and the seed that makes the random sub-selection reproducible.
NTK_CLASSES = (0, 1, 2, 7)
NTK_SUBSET_SIZE = 50
NTK_SEED = 0

inputs = data_generator.train_ds["inputs"]
targets = data_generator.train_ds["targets"]

# The class of a sample is the position of the largest entry of its target row.
train_labels = np.argmax(targets, axis=1)

# Indices of all training samples that belong to one of the selected classes,
# shuffled with a seeded generator so every run picks the same subset.
selection_idx = np.where(np.isin(train_labels, NTK_CLASSES))[0]
selection_idx = np.random.default_rng(NTK_SEED).permutation(selection_idx)
ntk_inputs = inputs[selection_idx]
ntk_targets = targets[selection_idx]

# Quick check: the selection holds exactly the requested classes.
print(f"Original data set shape: {inputs.shape}")
selected_classes = set(np.argmax(ntk_targets, axis=1).tolist())
assert selected_classes == set(NTK_CLASSES), (
    f"Expected classes {sorted(NTK_CLASSES)}, found {sorted(selected_classes)}"
)

# Attach the (truncated) selection to the generator as a new attribute ntk_ds.
ntk_ds = {
    "inputs": ntk_inputs[:NTK_SUBSET_SIZE],
    "targets": ntk_targets[:NTK_SUBSET_SIZE],
}
data_generator.ntk_ds = ntk_ds
print(f"NTK data set shape: {data_generator.ntk_ds['inputs'].shape}")

### Setup the Model

In [None]:
class Architecture(nn.Module):
    """
    Small fully connected network.

    Each sample is flattened, passed through two dense layers of width 32
    with ReLU activations, and mapped to 10 outputs by a final dense layer.
    """

    @nn.compact
    def __call__(self, x):
        # Keep the leading axis and collapse all remaining axes into one.
        hidden = x.reshape((x.shape[0], -1))
        for width in (32, 32):
            hidden = nn.relu(nn.Dense(features=width)(hidden))
        # The output layer has no activation.
        return nn.Dense(features=10)(hidden)

In [None]:
# SGD with momentum, handed to the ZnNL model wrapper below.
sgd_optimizer = optax.sgd(learning_rate=0.01, momentum=0.9)

# Wrap the flax module defined above in a ZnNL model.
model = nl.models.FlaxModel(
    flax_module=Architecture(),
    optimizer=sgd_optimizer,
    input_shape=(1, 28, 28, 1),
    trace_axes=(),
)

In [None]:
# Settings that are identical for all three recorders.
shared_recorder_settings = {"update_rate": 1, "chunk_size": 1e5}

# Recorder for the NTK data set: class-specific recording plus the NTK-based
# observables enabled below.
ntk_recorder = nl.training_recording.JaxRecorder(
    name="ntk_recorder",
    class_specific=True,
    entropy_class_correlations=True,
    ntk=True,
    loss=True,
    covariance_entropy=True,
    magnitude_variance=True,
    trace=True,
    eigenvalues=True,
    loss_derivative=True,
    **shared_recorder_settings,
)
ntk_recorder.instantiate_recorder(data_set=data_generator.ntk_ds)

# Recorder for the training data: loss only.
train_recorder = nl.training_recording.JaxRecorder(
    name="train_recorder",
    loss=True,
    **shared_recorder_settings,
)
train_recorder.instantiate_recorder(data_set=data_generator.train_ds)

# Recorder for the test data: accuracy and loss.
test_recorder = nl.training_recording.JaxRecorder(
    name="test_recorder",
    accuracy=True,
    loss=True,
    **shared_recorder_settings,
)
test_recorder.instantiate_recorder(data_set=data_generator.test_ds)

In [None]:
# All three recorders are attached to the same training run.
attached_recorders = [train_recorder, test_recorder, ntk_recorder]

training_strategy = nl.training_strategies.SimpleTraining(
    model=model,
    loss_fn=nl.loss_functions.CrossEntropyLoss(),
    accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
    recorders=attached_recorders,
)

In [None]:
# Show the label combinations used to compute the inter-class entropy
# contributions, together with the recorder's class index information.
# NOTE(review): both attributes are private (leading underscore) members of the
# recorder, so their layout may change between ZnNL versions.
ntk_recorder._class_combinations, ntk_recorder._class_idx

In [None]:
# Hyperparameters of the training run.
training_settings = {"batch_size": 128, "epochs": 150}

# Train on the full training set; the attached recorders collect their data
# while the training runs.
batched_training_metrics = training_strategy.train_model(
    train_ds=data_generator.train_ds,
    test_ds=data_generator.test_ds,
    **training_settings,
)

### Checking the Results

In [None]:
# Collect the recorded data of every recorder, in the order train, test, NTK.
train_report, test_report, ntk_report = (
    recorder.gather_recording()
    for recorder in (train_recorder, test_recorder, ntk_recorder)
)

In [None]:
# Loss (left panel, logarithmic y-axis) and test accuracy (right panel).

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, accuracy_ax = ax

for report, curve_label in ((train_report, "Train"), (test_report, "Test")):
    loss_ax.plot(report.loss, label=curve_label)
loss_ax.set_title("Loss")
loss_ax.legend()
loss_ax.set_yscale("log")

accuracy_ax.plot(test_report.accuracy, label="Test")
accuracy_ax.set_title("Accuracy")
accuracy_ax.legend()




In [None]:
# Read class specific (cs) data from the ntk_recorder
trace_cs = ntk_recorder.read_class_specific_data(ntk_report.trace)
covariance_entropy_cs = ntk_recorder.read_class_specific_data(ntk_report.covariance_entropy)

# Reading out the eigenvalues works the same way
eigenvalues_cs = ntk_recorder.read_class_specific_data(ntk_report.eigenvalues)


# Plot the class specific entropy (left) and trace (right).

# Classes to plot, in plotting order. They are used as keys/indices into the
# class specific containers returned above.
plotted_classes = [0, 1, 2, 7]

# Evenly spaced colormap positions in [0, 1], one per plotted class. The same
# positions are used for the colorbar ticks so colours and ticks always agree.
color_steps = max(len(plotted_classes) - 1, 1)
color_positions = [i / color_steps for i in range(len(plotted_classes))]

cmap = plt.get_cmap('rainbow')

fig, axs = plt.subplots(1, 2, figsize=(12, 4))

# The lines are drawn without markers, so no marker face colour is passed
# (the former `mfc='Entropy'` / `mfc='Trace'` values were not colours).
for position, l in zip(color_positions, plotted_classes):
    axs[0].plot(covariance_entropy_cs[l], '-', label=f"Train {l}", color=cmap(position))
    axs[1].plot(trace_cs[l], '--', label=f"Train {l}", color=cmap(position))

axs[0].set_xlabel("Epoch")
axs[1].set_xlabel("Epoch")
axs[0].set_ylabel("Entropy")
axs[1].set_ylabel("Trace")

# Colorbar whose ticks sit at the colormap positions of the plotted classes
# and are labelled with the class labels.
cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=axs, orientation='vertical')
cbar.set_ticks(color_positions)
cbar.set_ticklabels(plotted_classes)

plt.show()


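# TODO: The curve for class 7 is missing from the plot above; the cause has not been identified yet (possibly the class specific data is stored by class position rather than by class label, which needs to be verified).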

In [None]:
# Plot the averaged single-class entropy, the entropy of the full system and
# their difference.

# Recorded entropies as an array with one row per recording step.
# NOTE(review): this reads a private attribute of the recorder. The indexing
# below takes the first `n_single_classes` columns to be the single-class
# entropies and the last column to be the entropy of all classes combined --
# confirm against `ntk_recorder._class_combinations`.
entropies = np.array(ntk_recorder._entropy_class_correlations_array)

n_single_classes = 4

# Mean of the single-class entropies and the entropy of the combined system.
entropy_sub = entropies[:, :n_single_classes].sum(axis=1) / n_single_classes
entropy_sys = entropies[:, -1]

fig, axs = plt.subplots(1, 2, figsize=(12, 4), tight_layout=True)

# Left panel: difference between the averaged sub-system entropy and the
# entropy of the full system.
axs[0].plot(entropy_sub - entropy_sys, '--', label=r'$S_{sub} - S_{sys}$')
axs[0].set_xlabel("Epoch")
axs[0].set_ylabel("Entropy")
axs[0].legend()

# Right panel: both quantities on their own. The labels state the averaging
# explicitly and use the class labels of the NTK data set (0, 1, 2, 7).
axs[1].plot(entropy_sub, '-', label=r'$S_{sub} = \frac{1}{4}\,[S(0) + S(1) + S(2) + S(7)]$')
axs[1].plot(entropy_sys, '--', label=r'$S_{sys} = S(0 + 1 + 2 + 7)$')
axs[1].set_xlabel("Epoch")
axs[1].legend()

In [None]:
# Plot the mutual information between pairs of classes, three pairs per panel.

cmap = plt.get_cmap('rainbow')

fig, axs = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True, sharey=True)

# Class label shown in the legend for each single-class column of `entropies`.
label_of_position = (0, 1, 2, 7)

# One colour per line within a panel, in plotting order.
line_colors = (cmap(0), cmap(1/3), cmap(2/3))

# For every panel: (column of first class, column of second class,
# column of the entropy of the combined pair).
# NOTE(review): the last entry of the third panel is the pair (1, 7), which
# does not involve class 2 like the other two entries of that panel. It may be
# a copy-paste slip for (1, 2) -> (1, 2, 7); kept as in the original, confirm.
pairs_per_panel = (
    ((0, 1, 4), (0, 2, 5), (0, 3, 6)),
    ((1, 2, 7), (1, 3, 8), (0, 1, 4)),
    ((2, 3, 9), (0, 2, 5), (1, 3, 8)),
)

for panel, pairs in zip(axs, pairs_per_panel):
    for line_color, (first, second, pair_column) in zip(line_colors, pairs):
        # Mean of the two single-class entropies minus the entropy of the pair.
        mutual_information = (
            (entropies[:, first] + entropies[:, second])/2 - entropies[:, pair_column]
        )
        pair_label = f"Mutual Info ({label_of_position[first]}, {label_of_position[second]})"
        panel.plot(mutual_information, '-', label=pair_label, color=line_color)
    panel.legend()
    panel.set_xlabel("Epoch")

axs[0].set_ylabel("Mutual Information")
