# S(t)

Measure the epoch dependence of the entropy of the NTK.

### Experiment
Perform RND with a model and select N points. Train the model on these N points and use the parameters of the trained model to perform NTK again and see if the points change. Along the way, check the entropy of the updated NTK and see how it has been effected by the training.

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


import znrnd as rnd

import tensorflow_datasets as tfds

import numpy as np
import optax
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from neural_tangents import stax

import matplotlib.pyplot as plt

from scipy import integrate



In [2]:
data_generator = rnd.data.MNISTGenerator(ds_size=1000)

Metal device set to: Apple M1


2022-05-15 17:22:27.580929: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [3]:
model = stax.serial(
    stax.Conv(32, (3, 3)),
    stax.Relu(),
    stax.AvgPool(window_shape=(2, 2), strides=(2, 2)),
    stax.Conv(64, (3, 3)),
    stax.Relu(),
    stax.AvgPool(window_shape=(2, 2), strides=(2, 2)),
    stax.Flatten(),
    stax.Dense(256),
    stax.Relu(),
    stax.Dense(10)
)

In [4]:
production_model = rnd.models.NTModel(
        nt_module=model,
        optimizer=optax.adam(learning_rate=0.001),
        loss_fn=rnd.loss_functions.CrossEntropyLoss(classes=10, apply_softmax=False),
        input_shape=(1, 28, 28, 1),
        training_threshold=0.001
    )

In [5]:
train_ds = {
    "inputs": data_generator.ds_train["image"],
    "targets": data_generator.ds_train["label"]
}
test_ds = {
    "inputs": data_generator.ds_test["image"],
    "targets": data_generator.ds_test["label"]
}

In [None]:
entropy_data = {
    "5": {"infinite": [], "empirical": []}, 
    "10": {"infinite": [], "empirical": []},
    "15": {"infinite": [], "empirical": []},
    "20": {"infinite": [], "empirical": []},
    "25": {"infinite": [], "empirical": []},
    "30": {"infinite": [], "empirical": []},
    "35": {"infinite": [], "empirical": []},
    "40": {"infinite": [], "empirical": []},
    "50": {"infinite": [], "empirical": []},
    "100": {"infinite": [], "empirical": []},
}
subsets = [50, 100] # [5, 10, 15, 20, 25, 30, 35, 40, 50, 100]
metrics_array = []

for _ in range(100):
    for item in subsets:
        data_subset = data_generator.ds_train["image"][:item]
        ntk = production_model.compute_ntk(
            data_subset,
            normalize=True
        )
        entropy_inifinite = rnd.analysis.EntropyAnalysis(
            ntk["infinite"]
        ).compute_von_neumann_entropy(
            normalize=False
        )
        entropy_empirical = rnd.analysis.EntropyAnalysis(
            ntk["empirical"]
        ).compute_von_neumann_entropy(
            normalize=False
        )
        entropy_data[str(item)]["infinite"].append(entropy_inifinite)
        entropy_data[str(item)]["empirical"].append(entropy_empirical)
        
    metrics_array.append(production_model.train_model(
        train_ds=train_ds, test_ds=test_ds, batch_size=32, epochs=1
    ))


Epoch: 1: 100%|████████████████████| 1/1 [00:03<00:00,  3.30s/batch, accuracy=0.178, test_loss=2.24]
Epoch: 1: 100%|████████████████████| 1/1 [00:01<00:00,  1.13s/batch, accuracy=0.424, test_loss=2.12]
Epoch: 1: 100%|████████████████████| 1/1 [00:01<00:00,  1.14s/batch, accuracy=0.627, test_loss=1.91]
Epoch: 1: 100%|█████████████████████| 1/1 [00:01<00:00,  1.90s/batch, accuracy=0.73, test_loss=1.59]
Epoch: 1: 100%|█████████████████████| 1/1 [00:01<00:00,  1.17s/batch, accuracy=0.78, test_loss=1.22]
Epoch: 1: 100%|███████████████████| 1/1 [00:01<00:00,  1.94s/batch, accuracy=0.813, test_loss=0.925]
Epoch: 1: 100%|███████████████████| 1/1 [00:01<00:00,  2.00s/batch, accuracy=0.841, test_loss=0.739]
Epoch: 1: 100%|███████████████████| 1/1 [00:01<00:00,  1.99s/batch, accuracy=0.851, test_loss=0.627]
Epoch: 1: 100%|███████████████████| 1/1 [00:03<00:00,  3.28s/batch, accuracy=0.862, test_loss=0.557]
Epoch: 1: 100%|████████████████████| 1/1 [00:02<00:00,  2.15s/batch, accuracy=0.868, test_l

In [None]:
acc_arr = [item["accuracy"] for item in metrics_array]
loss_arr = [item["loss"] for item in metrics_array]

In [None]:
# colours = ["red", "blue", "green"]
scale = [np.log(5), np.log(10), np.log(30)]
fig, ax = plt.subplots()

ax2 = ax.twinx()

for i, item in enumerate(entropy_data):
    ax.plot(entropy_data[item]["empirical"], label=item)
#     ax.plot(entropy_data[item]["infinite"]), '-', c=colours[i])

ax2.plot(acc_arr, '.')
# plt.yscale("log")
ax.legend()
ax.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
ax.set_ylabel("Total Entropy")
plt.savefig("Entropy_vs_Training.pdf")
plt.show()

In [None]:
plt.plot(
    subsets, 
    [entropy_data[item]["empirical"][-1] for item in entropy_data],
    'o',
    label="empirical"
)
plt.plot(
    subsets, 
    [entropy_data[item]["infinite"][-1] for item in entropy_data],
    'o',
    label="inifinite"
)
plt.xlabel("Subset size")
plt.ylabel("Final entropy")
plt.legend()
plt.savefig("Final_Entropy_Subsets.pdf")
plt.show()

In [None]:
plt.plot(acc_arr, entropy_data["100"]["empirical"], '.')
plt.plot(acc_arr, entropy_data["50"]["empirical"], '.')
plt.plot(acc_arr, entropy_data["10"]["empirical"], '.')
plt.xlabel("Accuracy")
plt.ylabel("Empirical Entropy")
plt.savefig("Entropy_vs_Accuracy.pdf")
plt.show()

### Information transfer

In [None]:
plot_array = []
for value in entropy_data.values():
    data = value["empirical"][0] - value["empirical"][-1]
    plot_array.append(data)
    
plt.plot(subsets, plot_array, '.')

### Area under the curve

In [None]:
integral_data = []

for value in entropy_data.values():
    integral = integrate.trapz(value["empirical"])
    integral_data.append(integral)
    
plt.plot(subsets, integral_data, '.')
plt.show()

## Accuracy vs Entropy

In [None]:
subsets = [10, 20, 30, 50, 100, 200, 500, 800]
metrics_dict = {}

for item in subsets:
    data_subset = data_generator.ds_train["image"][:item]
    
    production_model = rnd.models.NTModel(
        nt_module=model,
        optimizer=optax.adam(learning_rate=0.001),
        loss_fn=rnd.loss_functions.CrossEntropyLoss(classes=10, apply_softmax=False),
        input_shape=(1, 28, 28, 1),
        training_threshold=0.001
    )
    
    ntk = production_model.compute_ntk(
            data_subset,
            normalize=False
        )
    entropy_empirical = rnd.analysis.EntropyAnalysis(
        ntk["empirical"]
    ).compute_von_neumann_entropy(
        normalize=False
    )
    
    metrics_dict[str(item)] = {}
    metrics_dict[str(item)]["entropy"] = [entropy_empirical]
    
    metrics = production_model.train_model(
        train_ds = {
            "inputs": data_generator.ds_train["image"][:item],
            "targets": data_generator.ds_train["label"][:item]
        }, 
        test_ds=test_ds, 
        batch_size=10, 
        epochs=200
    )
    ntk = production_model.compute_ntk(
            data_subset,
            normalize=True
        )
    entropy_empirical = rnd.analysis.EntropyAnalysis(
        ntk["empirical"]
    ).compute_von_neumann_entropy(
        normalize=False
    )
    
    metrics_dict[str(item)]["entropy"].append(entropy_empirical)
    metrics_dict[str(item)]["model"] = metrics
    print(metrics_dict)


Epoch: 200: 100%|██████████████| 200/200 [01:20<00:00,  2.48batch/s, accuracy=0.438, test_loss=2.55]


{'10': {'entropy': [DeviceArray(1.4299815-0.j, dtype=complex64), DeviceArray(1.0649279-0.j, dtype=complex64)], 'model': {'accuracy': 0.43800002336502075, 'loss': 2.549513816833496}}}


Epoch: 200: 100%|██████████████| 200/200 [01:29<00:00,  2.23batch/s, accuracy=0.528, test_loss=2.08]


{'10': {'entropy': [DeviceArray(1.4299815-0.j, dtype=complex64), DeviceArray(1.0649279-0.j, dtype=complex64)], 'model': {'accuracy': 0.43800002336502075, 'loss': 2.549513816833496}}, '20': {'entropy': [DeviceArray(1.6265097-0.j, dtype=complex64), DeviceArray(1.2852912-0.j, dtype=complex64)], 'model': {'accuracy': 0.527999997138977, 'loss': 2.0837535858154297}}}


Epoch: 200: 100%|██████████████| 200/200 [01:42<00:00,  1.96batch/s, accuracy=0.574, test_loss=2.21]


{'10': {'entropy': [DeviceArray(1.4299815-0.j, dtype=complex64), DeviceArray(1.0649279-0.j, dtype=complex64)], 'model': {'accuracy': 0.43800002336502075, 'loss': 2.549513816833496}}, '20': {'entropy': [DeviceArray(1.6265097-0.j, dtype=complex64), DeviceArray(1.2852912-0.j, dtype=complex64)], 'model': {'accuracy': 0.527999997138977, 'loss': 2.0837535858154297}}, '30': {'entropy': [DeviceArray(1.7681414-0.j, dtype=complex64), DeviceArray(1.2330437-0.j, dtype=complex64)], 'model': {'accuracy': 0.5740000009536743, 'loss': 2.213181972503662}}}


Epoch: 200: 100%|███████████████| 200/200 [02:13<00:00,  1.49batch/s, accuracy=0.735, test_loss=2.4]


{'10': {'entropy': [DeviceArray(1.4299815-0.j, dtype=complex64), DeviceArray(1.0649279-0.j, dtype=complex64)], 'model': {'accuracy': 0.43800002336502075, 'loss': 2.549513816833496}}, '20': {'entropy': [DeviceArray(1.6265097-0.j, dtype=complex64), DeviceArray(1.2852912-0.j, dtype=complex64)], 'model': {'accuracy': 0.527999997138977, 'loss': 2.0837535858154297}}, '30': {'entropy': [DeviceArray(1.7681414-0.j, dtype=complex64), DeviceArray(1.2330437-0.j, dtype=complex64)], 'model': {'accuracy': 0.5740000009536743, 'loss': 2.213181972503662}}, '50': {'entropy': [DeviceArray(1.9276254-0.j, dtype=complex64), DeviceArray(1.2277381-0.j, dtype=complex64)], 'model': {'accuracy': 0.7350000143051147, 'loss': 2.4009249210357666}}}


Epoch: 200: 100%|██████████████| 200/200 [03:27<00:00,  1.04s/batch, accuracy=0.801, test_loss=1.03]


{'10': {'entropy': [DeviceArray(1.4299815-0.j, dtype=complex64), DeviceArray(1.0649279-0.j, dtype=complex64)], 'model': {'accuracy': 0.43800002336502075, 'loss': 2.549513816833496}}, '20': {'entropy': [DeviceArray(1.6265097-0.j, dtype=complex64), DeviceArray(1.2852912-0.j, dtype=complex64)], 'model': {'accuracy': 0.527999997138977, 'loss': 2.0837535858154297}}, '30': {'entropy': [DeviceArray(1.7681414-0.j, dtype=complex64), DeviceArray(1.2330437-0.j, dtype=complex64)], 'model': {'accuracy': 0.5740000009536743, 'loss': 2.213181972503662}}, '50': {'entropy': [DeviceArray(1.9276254-0.j, dtype=complex64), DeviceArray(1.2277381-0.j, dtype=complex64)], 'model': {'accuracy': 0.7350000143051147, 'loss': 2.4009249210357666}}, '100': {'entropy': [DeviceArray(2.2563639-0.j, dtype=complex64), DeviceArray(1.3267068-0.j, dtype=complex64)], 'model': {'accuracy': 0.8010000586509705, 'loss': 1.0252071619033813}}}


Epoch: 200: 100%|█████████████| 200/200 [05:47<00:00,  1.74s/batch, accuracy=0.858, test_loss=0.953]


{'10': {'entropy': [DeviceArray(1.4299815-0.j, dtype=complex64), DeviceArray(1.0649279-0.j, dtype=complex64)], 'model': {'accuracy': 0.43800002336502075, 'loss': 2.549513816833496}}, '20': {'entropy': [DeviceArray(1.6265097-0.j, dtype=complex64), DeviceArray(1.2852912-0.j, dtype=complex64)], 'model': {'accuracy': 0.527999997138977, 'loss': 2.0837535858154297}}, '30': {'entropy': [DeviceArray(1.7681414-0.j, dtype=complex64), DeviceArray(1.2330437-0.j, dtype=complex64)], 'model': {'accuracy': 0.5740000009536743, 'loss': 2.213181972503662}}, '50': {'entropy': [DeviceArray(1.9276254-0.j, dtype=complex64), DeviceArray(1.2277381-0.j, dtype=complex64)], 'model': {'accuracy': 0.7350000143051147, 'loss': 2.4009249210357666}}, '100': {'entropy': [DeviceArray(2.2563639-0.j, dtype=complex64), DeviceArray(1.3267068-0.j, dtype=complex64)], 'model': {'accuracy': 0.8010000586509705, 'loss': 1.0252071619033813}}, '200': {'entropy': [DeviceArray(2.3037105-0.j, dtype=complex64), DeviceArray(1.3987174-0.j

In [None]:
metrics_dict