# NeurIPS CIFAR10 Study

In [10]:
# The GPU is disabled here. Comment out the two lines below if your
# JAX installation runs fine on your GPU.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


import znrnd as rnd

from neural_tangents import stax
import tensorflow_datasets as tfds

import numpy as np
import optax
from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [11]:
# Shared CIFAR10 data source; later cells hand it to the agents and read its
# `ds_train` / `ds_test` entries.
data_generator = rnd.data.CIFAR10Generator()

In [12]:
def _build_cnn():
    """
    Assemble the convolutional network used for both the target and the predictor.

    The network is two (Conv -> Relu -> AvgPool) stages with 32 and 64 channels,
    followed by Flatten and a Dense layer with 256 outputs.

    Returns
    -------
    The object returned by ``stax.serial`` for this layer stack.
    """
    conv_kernel = (3, 3)
    pool = (2, 2)

    layers = []
    for channels in (32, 64):
        layers.append(stax.Conv(channels, conv_kernel))
        layers.append(stax.Relu())
        layers.append(stax.AvgPool(window_shape=pool, strides=pool))
    layers.append(stax.Flatten())
    layers.append(stax.Dense(256))

    return stax.serial(*layers)


# Two separately constructed networks with the same architecture:
# `model` is used for the target network, `model1` for the predictor.
model = _build_cnn()
model1 = _build_cnn()

In [13]:
# Target and predictor networks for the RND agent. Both use the same optimizer,
# loss and input shape; they differ only in the stax module they wrap.
# NOTE: the package is imported as `rnd` (`import znrnd as rnd`), so the bare
# name `znrnd` is not bound in a fresh kernel and must not be used here.
target = rnd.models.NTModel(
    nt_module=model,
    optimizer=optax.sgd(0.001),
    loss_fn=rnd.loss_functions.MeanPowerLoss(order=2),
    input_shape=(1, 32, 32, 3),
    training_threshold=0.001,
)

predictor = rnd.models.NTModel(
    nt_module=model1,
    optimizer=optax.sgd(0.001),
    loss_fn=rnd.loss_functions.MeanPowerLoss(order=2),
    input_shape=(1, 32, 32, 3),
    training_threshold=0.001,
)

In [14]:
# RND agent built from the module-level target / predictor networks.
# All library access goes through the `rnd` alias (`import znrnd as rnd`);
# the bare name `znrnd` is not bound in a fresh kernel.
agent = rnd.agents.RND(
    point_selector=rnd.point_selection.GreedySelection(threshold=0.01),
    distance_metric=rnd.distance_metrics.OrderNDifference(order=2),
    data_generator=data_generator,
    target_network=target,
    predictor_network=predictor,
    tolerance=15,
)

In [15]:
def run_experiment(data_set_size: int, ensembling: bool = False, ensembles: int = 10):
    """
    Run an experiment for a specific dataset size.

    For every ensemble member, fresh target / predictor networks are created and
    three data sets of the requested size are built: one with the RND agent, one
    with the random agent and one with the approximate-maximum-entropy agent.
    For each set the empirical NTK of the target network is computed, from which
    the von Neumann entropy and the eigenvalues are derived. A production model
    is then trained on each selected set and the result of ``train_model`` is
    recorded.

    Parameters
    ----------
    data_set_size : int
            Size of the dataset to produce
    ensembling : bool (default=False)
            If true, the experiment is run several times to produce an error estimate
    ensembles : int
            Number of ensembles to use in the averaging. Ignored (treated as 1)
            when ``ensembling`` is False.

    Returns
    -------
    entropy : dict
            Dictionary with keys "rnd", "random" and "approximate_maximum". Each
            value is ``np.array([mean, error])`` where ``error`` is
            ``np.std(...) / np.sqrt(ensembles)``.
    eigenvalues : dict
            Same keys as ``entropy``; each value is
            ``np.array([mean, error])`` taken over the ensemble axis (axis 0).
    losses : dict
            Same keys as ``entropy``; mean and error (over axis 0) of the values
            returned by ``train_model`` for the production models.

    """
    # Turn off averaging if required.
    if not ensembling:
        ensembles = 1

    # One result list per selection method. Keying everything by method name
    # (instead of three hand-copied variables per quantity) guarantees that each
    # quantity is computed from the matching data set / NTK.
    method_names = ("rnd", "random", "approximate_maximum")
    entropy_runs = {name: [] for name in method_names}
    eigenvalue_runs = {name: [] for name in method_names}
    loss_runs = {name: [] for name in method_names}

    def _new_production_model():
        """Create a fresh production model to be trained on a selected set."""
        # NOTE(review): `CustomModule` is not defined in the visible part of this
        # notebook; it has to exist in the kernel before this function is called.
        return rnd.models.FlaxModel(
            flax_module=CustomModule(),
            optimizer=optax.adam(learning_rate=0.001),
            loss_fn=rnd.loss_functions.CrossEntropyLoss(classes=10),
            input_shape=(1, 32, 32, 3),
            training_threshold=0.001,
        )

    def _training_subset(indices):
        """Pick the training images and labels at the given indices."""
        return {
            "inputs": np.take(data_generator.ds_train["image"], indices, axis=0),
            "targets": np.take(data_generator.ds_train["label"], indices, axis=0),
        }

    def _mean_and_error(samples, axis=None):
        """Return np.array([mean, std / sqrt(ensembles)]) of the collected samples."""
        samples = np.array(samples)
        return np.array(
            [np.mean(samples, axis=axis), np.std(samples, axis=axis) / np.sqrt(ensembles)]
        )

    for _ in range(ensembles):

        # Define the models
        target = rnd.models.NTModel(
            nt_module=model,
            optimizer=optax.sgd(0.001),
            loss_fn=rnd.loss_functions.MeanPowerLoss(order=2),
            input_shape=(1, 32, 32, 3),
            training_threshold=0.001,
        )

        predictor = rnd.models.NTModel(
            nt_module=model1,
            optimizer=optax.sgd(0.001),
            loss_fn=rnd.loss_functions.MeanPowerLoss(order=2),
            input_shape=(1, 32, 32, 3),
            training_threshold=0.001,
        )

        # Define the agents for a fresh run.
        rnd_agent = rnd.agents.RND(
            point_selector=rnd.point_selection.GreedySelection(threshold=0.01),
            distance_metric=rnd.distance_metrics.OrderNDifference(order=2),
            data_generator=data_generator,
            target_network=target,
            predictor_network=predictor,
            tolerance=15,
        )
        # Explicitly start the RND agent from an empty selection.
        rnd_agent.target_set = []
        rnd_agent.target_indices = []

        random_agent = rnd.agents.RandomAgent(data_generator=data_generator)
        approximate_max_agent = rnd.agents.ApproximateMaximumEntropy(
            target_network=target,
            data_generator=data_generator,
            samples=10,  # How many sets it produces in the test. Takes the one with max entropy.
        )

        # Same order as `method_names`; dicts preserve insertion order, so the
        # agents are always processed as rnd -> random -> approximate_maximum.
        agents = {
            "rnd": rnd_agent,
            "random": random_agent,
            "approximate_maximum": approximate_max_agent,
        }

        # Compute the sets
        data_sets = {
            name: selector.build_dataset(target_size=data_set_size, visualize=False)
            for name, selector in agents.items()
        }

        # Compute NTK for each set
        ntks = {
            name: target.compute_ntk(x_i=data_set)["empirical"]
            for name, data_set in data_sets.items()
        }

        # Compute the entropy of each set
        for name, ntk in ntks.items():
            entropy_runs[name].append(
                rnd.analysis.EntropyAnalysis(matrix=ntk).compute_von_neumann_entropy()
            )

        # Compute eigenvalues. Each method uses its *own* NTK (the approximate
        # maximum entry was previously computed from the RND matrix by mistake).
        for name, ntk in ntks.items():
            eigenvalue_runs[name].append(
                rnd.analysis.EigenSpaceAnalysis(matrix=ntk).compute_eigenvalues()
            )

        # Train production model
        production_models = {name: _new_production_model() for name in method_names}
        training_sets = {
            name: _training_subset(selector.target_indices)
            for name, selector in agents.items()
        }
        test_ds = {
            "inputs": data_generator.ds_test["image"],
            "targets": data_generator.ds_test["label"],
        }

        for name in method_names:
            loss_runs[name].append(
                production_models[name].train_model(
                    train_ds=training_sets[name], test_ds=test_ds
                )
            )

        # Drop every reference to the agents before the next ensemble member.
        del agents
        del rnd_agent
        del random_agent
        del approximate_max_agent

    # Get mean and uncertainty.
    entropy = {name: _mean_and_error(entropy_runs[name]) for name in method_names}
    eigenvalues = {
        name: _mean_and_error(eigenvalue_runs[name], axis=0) for name in method_names
    }
    losses = {name: _mean_and_error(loss_runs[name], axis=0) for name in method_names}

    return entropy, eigenvalues, losses

In [None]:
# Small run: data sets of 3 points, averaged over 2 ensemble members.
run_experiment(data_set_size=3, ensembling=True, ensembles=2)

Epoch: 100: 100%|████████████████████████████████| 100/100 [00:03<00:00, 29.18batch/s, test_loss=60]
Epoch: 110: 100%|████████████████████████████| 110/110 [00:03<00:00, 29.47batch/s, test_loss=0.0248]
Epoch: 121: 100%|███████████████████████████| 121/121 [00:04<00:00, 28.89batch/s, test_loss=5.35e-6]

The PCA initialization in TSNE will change to have the standard deviation of PC1 equal to 1e-4 in 1.2. This will ensure better convergence.


The PCA initialization in TSNE will change to have the standard deviation of PC1 equal to 1e-4 in 1.2. This will ensure better convergence.


The PCA initialization in TSNE will change to have the standard deviation of PC1 equal to 1e-4 in 1.2. This will ensure better convergence.

Epoch: 100: 100%|██████████████████████████████| 100/100 [00:06<00:00, 15.51batch/s, test_loss=94.7]
Epoch: 110: 100%|██████████████████████████████| 110/110 [00:07<00:00, 14.79batch/s, test_loss=6.16]
Epoch: 121: 100%|█████████████████████████████| 121/121 [00:07<00:00,