# Computing the Neural Mutual Information (NMI)

In this notebook we will show how to compute the Neural Mutual Information (NMI) between classes of data during training. 

In [None]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import znnl as nl
from flax import linen as nn
import optax

from papyrus.measurements import (
    Loss, Accuracy, NTKEntropy
)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import jax
jax.default_backend()

For this demo of ZnNL, we reduce the number of data points used for training and for computing the Mutual Information.
To scale up the computation, increase the selected number of data points.

In [None]:
num_train = 100
num_nmi_per_class = 2

### Data generators

For the sake of coverage, we will look at the NTK properties of the MNIST data set for a small model.

In [None]:
data_generator = nl.data.MNISTGenerator(num_train)

###  Networks and Models

Now we can define the network architectures for which we will compute the NTK of the data.

The batch size defined in the model class refers to the batching in the NTK calculation. When calculating the NTK, the number of data points used in that calculation must be an integer multiple of the batch size.
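
The divisibility requirement can be checked before any NTK is computed. The following is a minimal sketch; `check_ntk_batching` is a hypothetical helper written for this illustration and is not part of ZnNL.

```python
def check_ntk_batching(n_points: int, batch_size: int) -> bool:
    """Return True if n_points is a positive integer multiple of batch_size."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return n_points > 0 and n_points % batch_size == 0


# 100 points split evenly into batches of 10, 25 points do not.
print(check_ntk_batching(100, 10))
print(check_ntk_batching(25, 10))
```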

In [None]:
class DenseModule(nn.Module):
    """
    Simple dense (fully connected) module.
    """
    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=32)(x)
        x = nn.relu(x)
        x = nn.Dense(features=32)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

In [None]:
model = nl.models.FlaxModel(
    flax_module=DenseModule(),
    optimizer=optax.sgd(learning_rate=0.005, momentum=0.9),
    input_shape=(1, 28, 28, 1),
)

### Recording 

We will record the loss and accuracy of the train and test data sets during training to see how well the model is learning.

In [None]:
train_recorder = nl.training_recording.JaxRecorder(
    name="train_recorder",
    measurements=[
        Loss(name="loss", apply_fn=nl.loss_functions.CrossEntropyLoss()),
        Accuracy(name="accuracy", apply_fn=nl.accuracy_functions.LabelAccuracy()),
    ],
    storage_path=".",
    update_rate=1, 
    chunk_size=1e5
)
train_recorder.instantiate_recorder(
    data_set=data_generator.train_ds, 
    model=model
)


test_recorder = nl.training_recording.JaxRecorder(
    name="test_recorder",
    measurements=[
        Loss(name="loss", apply_fn=nl.loss_functions.CrossEntropyLoss()),
        Accuracy(name="accuracy", apply_fn=nl.accuracy_functions.LabelAccuracy()),
    ],
    storage_path=".",
    update_rate=1, 
    chunk_size=1e5
)
test_recorder.instantiate_recorder(
    data_set=data_generator.test_ds, 
    model=model
)

### Computing the Neural MI

In order to compute the Neural MI, we will need to compute the von Neumann Entropy of the NTK. 
We will create a subset of the training data to compute the NTK.
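
To make the entropy concrete, here is a minimal NumPy sketch of a von Neumann entropy for a symmetric, positive semi-definite kernel matrix. It assumes the eigenvalues are rescaled to sum to one and uses the natural logarithm. The function `von_neumann_entropy` and the toy kernel are illustrative assumptions, not the `NTKEntropy` implementation from papyrus.

```python
import numpy as np


def von_neumann_entropy(kernel: np.ndarray) -> float:
    """Entropy of a symmetric PSD matrix, treated like a density matrix."""
    # Eigenvalues of a symmetric matrix; clip tiny negatives from round-off.
    eigvals = np.clip(np.linalg.eigvalsh(kernel), 0.0, None)
    # Rescale so the eigenvalues sum to one, like a probability distribution.
    probs = eigvals / eigvals.sum()
    # Drop exact zeros, since 0 * log(0) is taken to be 0.
    probs = probs[probs > 0]
    return float(-np.sum(probs * np.log(probs)))


# Toy kernel built from random features, standing in for an NTK matrix.
rng = np.random.default_rng(0)
features = rng.normal(size=(6, 20))
toy_kernel = features @ features.T

print(von_neumann_entropy(toy_kernel))
print(von_neumann_entropy(np.eye(4)))  # identity matrix: log(4)
```

An identity matrix gives the maximal value $\log(n)$, while a rank-one matrix gives an entropy of zero.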

In [None]:
n_points = num_nmi_per_class

mni_ds = {"inputs": [], "targets": []}

for i in [0, 1, 8]:
    idx = np.argmax(data_generator.train_ds['targets'], axis=1) == i
    num_points = np.sum(idx)
    print(f"Number of data points for class {i}: {num_points} of which {n_points} will be selected")

    # Select the first n_points of class i
    mni_ds["inputs"].extend(data_generator.train_ds['inputs'][idx][:n_points])
    mni_ds["targets"].extend(data_generator.train_ds['targets'][idx][:n_points])

mni_ds = {k: np.array(v) for k, v in mni_ds.items()}
print(f"Total number of data points to record the Mutual Information: {len(mni_ds['inputs'])}")


The Mutual Information is a quantity that measures the amount of information that one distribution has about another.
In our case, we are interested in the amount of information that one class of data has about another.

Since comparing all classes to all other classes is overly complicated, we will compare classes [0, 1, 8].

In [None]:
ntk_combintaion_computation = nl.analysis.JAXNTKCombinations(
    apply_fn=model.ntk_apply_fn, 
    class_labels=[0, 1, 8], # Selecting the classes to compute the Neural MI for
    batch_size=10,
)
mni_recorder = nl.training_recording.JaxRecorder(
    name="mni_recorder",
    measurements=[
        NTKEntropy(name="ntk_entropy", effective=False, normalize_eigenvalues=True),
    ],
    storage_path=".",
    update_rate=1,
    chunk_size=1e5
)
mni_recorder.instantiate_recorder(
    data_set=mni_ds, 
    ntk_computation=ntk_combintaion_computation
)


In [None]:
trainer = nl.training_strategies.SimpleTraining(
    model=model, 
    loss_fn=nl.loss_functions.CrossEntropyLoss(),
    recorders=[
        train_recorder, 
        test_recorder, 
        mni_recorder
    ],
)

In [None]:
batched_training_metrics = trainer.train_model(
    train_ds=data_generator.train_ds, 
    test_ds=data_generator.test_ds,
    batch_size=10,
    epochs=100,
)

### Checking the results

In [None]:
train_report = train_recorder.gather()
test_report = test_recorder.gather()
mni_report = mni_recorder.gather()

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 4))

axs[0].plot(train_report["loss"], label="train")
axs[0].plot(test_report["loss"], label="test")
axs[0].set_yscale("log")
axs[0].set_xlabel("Epoch")
axs[0].set_ylabel("Loss")
axs[0].legend()

axs[1].plot(train_report["accuracy"], label="train")
axs[1].plot(test_report["accuracy"], label="test")
axs[1].set_xlabel("Epoch")
axs[1].set_ylabel("Accuracy")
axs[1].legend()

plt.show()


### Compute the Neural MI

*To obtain results in the following part, the `mni_recorder` must be included in the `recorders` list when defining the `trainer` object above.*


Using the `JAXNTKCombinations` module, we obtain a set of entropy values for all the combinations of the classes. We can then compute the Neural MI using the entropy values.
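
The column indices used further down assume that the combinations are ordered by size: first the single classes, then the pairs, then the triple. The sketch below reproduces that assumed ordering with `itertools.combinations`. It is an illustration only, so the actual ordering should be checked against the `label_combinations` attribute printed in the next code cell.

```python
from itertools import combinations

class_labels = [0, 1, 8]

# Assumed ordering: all single classes, then all pairs, then the triple.
label_combinations = [
    combo
    for size in range(1, len(class_labels) + 1)
    for combo in combinations(class_labels, size)
]

# Print the column index next to the class combination it would refer to.
for index, combo in enumerate(label_combinations):
    print(index, combo)
```

Under this assumption, columns 0 to 2 hold the single-class entropies and columns 3 to 5 hold the joint entropies of the pairs (0, 1), (0, 8), and (1, 8).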

The mutual information of two correlated subsystems is given by:

$$I(X;Y) = S(X) + S(Y) - S(X,Y)$$

where $S(X)$ is the entropy of the first subsystem, $S(Y)$ is the entropy of the second subsystem, and $S(X,Y)$ is the joint entropy of the two subsystems.
Using this formula, we can compute the Mutual Information of the classes of the data.
The value of $I(X;Y)$ will, however, depend on the magnitude of the entropies. For that reason, we normalize the Mutual Information by the sum of the entropies of the two classes:

$$\tilde{I}(X;Y) = \frac{2 \cdot I(X;Y)}{S(X) + S(Y)} \in [0, 1]$$
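
As a worked example of the normalization above, the following sketch evaluates the formula element-wise on arrays of entropies, one entry per epoch. The helper `normalized_mutual_information` is hypothetical and the entropy values are made up for illustration; they are not results from this notebook.

```python
import numpy as np


def normalized_mutual_information(s_x, s_y, s_xy):
    """Return 2 * (S(X) + S(Y) - S(X,Y)) / (S(X) + S(Y)), element-wise."""
    s_x, s_y, s_xy = (np.asarray(v, dtype=float) for v in (s_x, s_y, s_xy))
    mutual_information = s_x + s_y - s_xy
    return 2.0 * mutual_information / (s_x + s_y)


# Made-up entropy values for three epochs, for illustration only.
s_x = np.array([1.0, 0.9, 0.8])
s_y = np.array([1.0, 0.7, 0.8])
s_xy = np.array([2.0, 1.4, 1.2])

# Approximately [0, 0.25, 0.5]: the joint entropy shrinks relative to
# the sum of the single entropies, so the normalized value grows.
print(normalized_mutual_information(s_x, s_y, s_xy))
```

When the joint entropy equals the sum of the two single entropies, as in the first entry, the normalized value is zero.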


In [None]:
entropies = np.array(mni_report['ntk_entropy'])

print(f"We obtain one entropy for each label combination: {entropies.shape}")
print(f"The label combinations are: {ntk_combintaion_computation.label_combinations}")

In [None]:
# Columns 0-2: single-class entropies; columns 3-5: pairwise joint entropies.
mni_norm = {
    "I(0, 1)": 2 * (entropies[:, 0] + entropies[:, 1] - entropies[:, 3]) / (entropies[:, 0] + entropies[:, 1]),
    "I(0, 8)": 2 * (entropies[:, 0] + entropies[:, 2] - entropies[:, 4]) / (entropies[:, 0] + entropies[:, 2]),
    "I(1, 8)": 2 * (entropies[:, 1] + entropies[:, 2] - entropies[:, 5]) / (entropies[:, 1] + entropies[:, 2]),
}

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 4), tight_layout=True)

# Plot the Entropies 
axs[0].plot(entropies[:, 0], label="S(0)")
axs[0].plot(entropies[:, 1], label="S(1)")
axs[0].plot(entropies[:, 2], label="S(8)")
axs[0].set_xlabel("Epoch")
axs[0].set_ylabel("Entropy")
axs[0].legend()

for key, value in mni_norm.items():
    axs[1].plot(value, label=key)
axs[1].set_xlabel("Epoch")
axs[1].set_ylabel("Normalized MI")
axs[1].legend()