# NeurIPS MNIST Study

In this notebook, we use the Neural Tangents network in ZnRND to perform the same MNIST learning task, but with a deeper look into the entropy arguments surrounding the NTK.

In [None]:
import os
# Set CUDA_VISIBLE_DEVICES to "-1" so that no CUDA device is exposed to the
# process. NOTE(review): presumably this forces the libraries imported below
# onto the CPU, which is why it is set before they are imported -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


import znrnd as rnd

from neural_tangents import stax
import tensorflow_datasets as tfds

import numpy as np
import optax
from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [None]:
class MNISTGenerator(rnd.data.DataGenerator):
    """
    Data generator for MNIST datasets.

    Attributes
    ----------
    ds_train : dict
            First 500 samples of the MNIST training split, as numpy arrays
            under the keys 'image' and 'label'.
    ds_test : dict
            First 500 samples of the MNIST test split, same layout.
    data_pool : np.ndarray
            The training images cast to float.
    """
    def __init__(self):
        """
        Constructor for the MNIST generator class.

        Loads the first 500 training and the first 500 test samples in a
        single batch and stores the training images as the data pool.
        """
        self.ds_train, self.ds_test = tfds.as_numpy(
            tfds.load(
                'mnist:3.*.*',
                split=['train[:%d]' % 500, 'test[:%d]' % 500],
                batch_size=-1
            )
        )
        self.data_pool = self.ds_train['image'].astype(float)

    def _process_data(self, data_chunk):
        """
        Flatten the images and one-hot encode the labels.

        Parameters
        ----------
        data_chunk : dict
                Dictionary with the keys 'image' and 'label'.

        Returns
        -------
        processed : dict
                Dictionary with the keys 'image' (float32 array with one
                flattened row per sample, standardised with the mean and
                standard deviation of the whole chunk) and 'label' (one-hot
                rows of length 10).
        """
        image, label = data_chunk['image'], data_chunk['label']

        samples = image.shape[0]
        image = np.array(np.reshape(image, (samples, -1)), dtype=np.float32)
        # A single scalar mean / std over the full chunk, not per image.
        image = (image - np.mean(image)) / np.std(image)
        label = np.eye(10)[label]

        return {'image': image, 'label': label}

    def plot_image(self, indices: list = None, data_list: list = None):
        """
        Plot one or more images as heatmaps on a grid of at most 4 columns.

        Parameters
        ----------
        indices : list (None)
                Indices into the training images to plot. Takes precedence
                over data_list if both are given.
        data_list : list (None)
                Images to plot directly; each must be reshapeable to (28, 28).

        Raises
        ------
        TypeError
                If neither indices nor data_list is provided.
        ValueError
                If the selection contains no images.
        """
        if indices is not None:
            data_source = self.ds_train["image"][indices]
        elif data_list is not None:
            data_source = data_list
        else:
            raise TypeError("No valid data provided")

        data_length = len(data_source)
        if data_length == 0:
            raise ValueError("No images to plot")

        # Up to four images share one row; more wrap onto additional rows.
        columns = min(data_length, 4)
        rows = int(np.ceil(data_length / columns))

        fig = make_subplots(rows=rows, cols=columns)

        # Iterate over the selected images themselves so that the requested
        # indices are honoured (and not simply the first N training images).
        for position, image in enumerate(data_source):
            row, column = divmod(position, columns)
            fig.add_trace(
                go.Heatmap(z=np.reshape(image, (28, 28))),
                row=row + 1,
                col=column + 1,
            )

        fig.show()

In [None]:
# Instantiate the generator; its constructor loads the MNIST subsets.
data_generator = MNISTGenerator()

In [None]:
# Visual sanity check: request a grid of thirteen training images (indices 1-13).
data_generator.plot_image([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])

In [None]:
def _build_conv_module():
    """
    Build the convolutional Neural Tangents module used in this study.

    Two Conv -> Relu -> AvgPool stages (32 and 64 channels, 3x3 filters,
    2x2 pooling with stride 2) followed by a flatten and a 256-unit dense layer.
    """
    pooling = {"window_shape": (2, 2), "strides": (2, 2)}
    layers = [
        stax.Conv(32, (3, 3)),
        stax.Relu(),
        stax.AvgPool(**pooling),
        stax.Conv(64, (3, 3)),
        stax.Relu(),
        stax.AvgPool(**pooling),
        stax.Flatten(),
        stax.Dense(256),
    ]
    return stax.serial(*layers)


# Two independently constructed modules with the same architecture.
model = _build_conv_module()
model1 = _build_conv_module()

In [None]:
def _make_nt_model(nt_module):
    """
    Wrap a Neural Tangents module in an NTModel with the study's settings.

    Every call creates its own SGD optimizer (learning rate 0.001) and its
    own order-2 mean power loss, so the returned models share no objects.
    """
    return rnd.models.NTModel(
        nt_module=nt_module,
        optimizer=optax.sgd(0.001),
        loss_fn=rnd.loss_functions.MeanPowerLoss(order=2),
        input_shape=(1, 28, 28, 1),
        training_threshold=0.001,
    )


# Target and predictor networks, configured identically.
target = _make_nt_model(model)

predictor = _make_nt_model(model1)

In [None]:
def run_experiment(data_set_size: int, ensembling: bool = False):
    """
    Run an experiment for a specific dataset size.

    Three datasets of the requested size are built (RND, random and
    approximate maximum entropy). For each of them the empirical NTK of the
    target network is computed, followed by its von Neumann entropy and its
    eigenvalues.

    The module-level ``data_generator``, ``target`` and ``predictor`` objects
    are used; only the agents are created fresh on every call.

    Parameters
    ----------
    data_set_size : int
            Size of the dataset to produce
    ensembling : bool (default=False)
            If true, the experiment is run several times to produce an error estimate.
            NOTE: currently not implemented; the argument is accepted but ignored.

    Returns
    -------
    entropy : dict
            A dictionary of the computed entropy:
            e.g {"rnd": 0.68, "random": 0.41, "approximate_maximum": 0.84}
    eigenvalues : dict
            Dictionary of eigenvalues
            e.g {"rnd": np.array(), "random": np.array(), "approximate_maximum": np.array()}

    """
    # Define the agents for a fresh run.
    agents = {
        "rnd": rnd.agents.RND(
            point_selector=rnd.point_selection.GreedySelection(threshold=0.01),
            distance_metric=rnd.distance_metrics.OrderNDifference(order=2),
            data_generator=data_generator,
            target_network=target,
            predictor_network=predictor,
            tolerance=15
        ),
        "random": rnd.agents.RandomAgent(data_generator=data_generator),
        "approximate_maximum": rnd.agents.ApproximateMaximumEntropy(
            target_network=target,
            data_generator=data_generator,
            samples=10,  # How many sets it produces in the test. Takes the one with max entropy.
        ),
    }

    # Compute the sets (dict order keeps the rnd -> random -> approximate order).
    data_sets = {
        name: agent.build_dataset(target_size=data_set_size, visualize=False)
        for name, agent in agents.items()
    }

    # Compute NTK for each set. Every method must use its OWN dataset,
    # otherwise the comparison between the methods is meaningless.
    ntks = {
        name: target.compute_ntk(x_i=data_set)["empirical"]
        for name, data_set in data_sets.items()
    }

    # Compute the entropy of each set
    entropy = {
        name: rnd.analysis.EntropyAnalysis(matrix=ntk).compute_von_neumann_entropy()
        for name, ntk in ntks.items()
    }

    # Compute eigenvalues
    eigenvalues = {
        name: rnd.analysis.EigenSpaceAnalysis(matrix=ntk).compute_eigenvalues()
        for name, ntk in ntks.items()
    }

    # Train production model (placeholder in the original; not implemented yet).

    return entropy, eigenvalues
    

## Analysis

In [None]:
# Trial run with a dataset size of 5; holds the (entropy, eigenvalues) tuple.
test_exp = run_experiment(5)

In [None]:
# Select five entries of the data pool by index along the first axis.
np.take(data_generator.data_pool, [1, 5, 3, 9, 10], axis=0)