In [1]:
import znrnd

import numpy as np
import optax
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from neural_tangents import stax

import matplotlib.pyplot as plt
import copy



In [2]:
# MNIST data source for the whole notebook; its ds_train / ds_test members are
# read by later cells. NOTE(review): ds_size=1000 presumably caps the number of
# samples that are loaded -- confirm against the znrnd documentation.
data_generator = znrnd.data.MNISTGenerator(ds_size=1000)

Metal device set to: Apple M1


2022-05-15 17:18:54.196677: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [3]:
# Network architecture shared by every training run below: two
# convolution / ReLU / average-pooling stages, then a flattened dense head
# with 10 outputs.
pool_kwargs = {"window_shape": (2, 2), "strides": (2, 2)}

cnn_layers = (
    # First convolutional stage.
    stax.Conv(32, (3, 3)),
    stax.Relu(),
    stax.AvgPool(**pool_kwargs),
    # Second convolutional stage.
    stax.Conv(64, (3, 3)),
    stax.Relu(),
    stax.AvgPool(**pool_kwargs),
    # Dense classification head.
    stax.Flatten(),
    stax.Dense(256),
    stax.Relu(),
    stax.Dense(10),
)

model = stax.serial(*cnn_layers)

In [4]:
# Held-out evaluation set handed to train_model: images as "inputs" and
# their labels as "targets".
test_images = data_generator.ds_test["image"]
test_labels = data_generator.ds_test["label"]
test_ds = dict(inputs=test_images, targets=test_labels)

In [5]:
# Agent used by the experiment loop to build training subsets from
# data_generator (via build_dataset) and to expose the chosen indices
# (via target_indices). NOTE(review): presumably the points are selected
# uniformly at random -- confirm against the znrnd documentation.
random_agent = znrnd.agents.RandomAgent(
    data_generator=data_generator
)

In [None]:
ds_sizes = [20, 50, 100, 150, 200, 300, 500]

# Number of independent repetitions (fresh model, freshly built data set)
# that are averaged for every entry of ds_sizes.
n_repeats = 5

# Start entropy data sets.
start_entropy = []
start_entropy_err = []

# Final entropy data sets.
final_entropy = []
final_entropy_err = []

# Final min loss
min_loss = []
min_loss_err = []

# Final train metrics
min_train_loss = []
min_train_loss_err = []

# Max acc
max_acc = []
max_acc_err = []


def _mean_and_sem(samples):
    """
    Summarise repeated measurements by their mean and standard error.

    Parameters
    ----------
    samples : list
            One scalar measurement per repetition.

    Returns
    -------
    mean, sem : tuple
            The sample mean and np.std(samples) / sqrt(len(samples)). np.std
            is called with its default arguments, as in the original cell.
    """
    return np.mean(samples), np.std(samples) / np.sqrt(len(samples))


for ds_size in ds_sizes:
    # Per-repetition measurements for the current data set size.
    entropy_start = []
    entropy_end = []
    loss = []
    acc = []
    train_loss = []

    for _ in range(n_repeats):
        # Define a new model so that no trained state carries over between
        # repetitions.
        random_model = znrnd.models.NTModel(
            nt_module=model,
            optimizer=optax.adam(learning_rate=0.001),
            loss_fn=znrnd.loss_functions.CrossEntropyLoss(classes=10, apply_softmax=False),
            input_shape=(1, 28, 28, 1),
            training_threshold=0.001
        )

        # Build the dataset.
        random_ds = random_agent.build_dataset(target_size=ds_size)

        # Compute the start entropy from the empirical NTK of the untrained
        # model.
        ntk = random_model.compute_ntk(x_i=random_ds, normalize=False)
        entropy_start.append(
            znrnd.analysis.EntropyAnalysis(
                matrix=ntk["empirical"]
            ).compute_von_neumann_entropy()
        )

        # Assemble the labelled training set from the agent's indices.
        # NOTE(review): presumably target_indices are the points selected by
        # the build_dataset call above -- confirm against znrnd.
        ds_random = {
            "inputs": np.take(data_generator.ds_train["image"], random_agent.target_indices, axis=0),
            "targets": np.take(data_generator.ds_train["label"], random_agent.target_indices, axis=0)
        }

        # Train the model.
        random_loss, random_acc, training_metrics = random_model.train_model(
            train_ds=ds_random, test_ds=test_ds, batch_size=10, epochs=500
        )
        # Distinct comprehension variable: the original reused the name of
        # the outer loop variable here.
        train_metrics = [metric["loss"] for metric in training_metrics]

        # Compute the final entropy on the same inputs, after training.
        ntk = random_model.compute_ntk(x_i=random_ds, normalize=False)
        entropy_end.append(
            znrnd.analysis.EntropyAnalysis(
                matrix=ntk["empirical"]
            ).compute_von_neumann_entropy()
        )

        # Update loss and accuracy arrays with the best value of the run.
        loss.append(np.min(random_loss))
        acc.append(np.max(random_acc))
        train_loss.append(np.min(train_metrics))

    # Update the stored arrays. Each measurement list is paired explicitly
    # with its mean and error list, so an error can only land in the list
    # that belongs to it (the original appended the training-loss error to
    # min_loss_err, leaving min_train_loss_err empty).
    for samples, means, errors in (
        (entropy_start, start_entropy, start_entropy_err),
        (entropy_end, final_entropy, final_entropy_err),
        (loss, min_loss, min_loss_err),
        (train_loss, min_train_loss, min_train_loss_err),
        (acc, max_acc, max_acc_err),
    ):
        mean, sem = _mean_and_sem(samples)
        means.append(mean)
        errors.append(sem)
        

Epoch: 500: 100%|██████████████| 500/500 [02:13<00:00,  3.75batch/s, accuracy=0.419, test_loss=4.67]
Epoch: 500: 100%|██████████████| 500/500 [02:18<00:00,  3.61batch/s, accuracy=0.487, test_loss=3.07]
Epoch: 500: 100%|██████████████| 500/500 [02:19<00:00,  3.58batch/s, accuracy=0.505, test_loss=4.45]
Epoch: 500: 100%|██████████████| 500/500 [02:03<00:00,  4.05batch/s, accuracy=0.455, test_loss=6.07]
Epoch: 500: 100%|██████████████| 500/500 [02:29<00:00,  3.34batch/s, accuracy=0.548, test_loss=4.01]
Epoch: 500: 100%|██████████████| 500/500 [04:00<00:00,  2.08batch/s, accuracy=0.711, test_loss=1.51]
Epoch: 500: 100%|██████████████| 500/500 [03:33<00:00,  2.34batch/s, accuracy=0.686, test_loss=3.04]
Epoch: 500: 100%|██████████████| 500/500 [03:26<00:00,  2.43batch/s, accuracy=0.631, test_loss=2.28]
