# NeurIPS CIFAR-10 Study

In [1]:
# The GPU is disabled here; comment out these lines if your JAX installation
# runs fine on your GPU. Keep them above the imports so the setting is in
# place before JAX initialises.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


import znrnd

from neural_tangents import stax
import tensorflow_datasets as tfds

import numpy as np
import optax
from plotly.subplots import make_subplots
import plotly.graph_objects as go



In [2]:
# CIFAR-10 data source that is handed to the RND agent as ``data_generator``;
# presumably the pool from which the agent selects points (confirm in znrnd).
data_generator = znrnd.data.CIFAR10Generator()

2022-04-21 11:23:17.912733: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


In [3]:
def build_cnn_module(conv_features=(32, 64), out_features=256):
    """Build the convolutional ``stax`` network used for both RND networks.

    For every entry of ``conv_features`` the network stacks a 3x3 ``Conv`` with
    that many output channels, a ``Relu`` and a 2x2 / stride-2 ``AvgPool``. It
    then flattens and finishes with a ``Dense`` read-out of ``out_features``
    units. The defaults reproduce the original Conv(32) -> Conv(64) ->
    Dense(256) architecture exactly.

    Args:
        conv_features: Output channel count of each convolutional stage.
        out_features: Width of the final dense layer.

    Returns:
        Whatever ``stax.serial`` returns for these layers (per the
        neural-tangents docs, an ``(init_fn, apply_fn, kernel_fn)`` triple).
    """
    layers = []
    for features in conv_features:
        layers += [
            stax.Conv(features, (3, 3)),
            stax.Relu(),
            stax.AvgPool(window_shape=(2, 2), strides=(2, 2)),
        ]
    layers += [stax.Flatten(), stax.Dense(out_features)]
    return stax.serial(*layers)


# Target and predictor share one architecture. Building both from the same
# factory keeps them from silently drifting apart when the architecture is edited.
model = build_cnn_module()
model1 = build_cnn_module()

In [6]:
def make_nt_model(nt_module, learning_rate=0.001, training_threshold=0.001):
    """Wrap a ``stax`` module in a ``znrnd.models.NTModel`` with shared settings.

    Both RND networks use the same training configuration: ``optax.sgd`` with
    ``learning_rate``, an order-2 ``MeanPowerLoss`` and an input shape of
    ``(1, 32, 32, 3)`` (presumably one NHWC 32x32 RGB CIFAR-10 image -- confirm).
    A fresh optimizer and loss object is created on every call, so the two
    networks never share optimizer state.

    Args:
        nt_module: The neural-tangents ``stax`` module to wrap.
        learning_rate: SGD step size (original value: 0.001).
        training_threshold: Forwarded unchanged to ``NTModel`` (original value:
            0.001).

    Returns:
        The configured ``znrnd.models.NTModel`` instance.
    """
    return znrnd.models.NTModel(
        nt_module=nt_module,
        optimizer=optax.sgd(learning_rate),
        loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
        input_shape=(1, 32, 32, 3),
        training_threshold=training_threshold,
    )


target = make_nt_model(model)
predictor = make_nt_model(model1)

In [7]:
# Assemble the RND agent from the pieces built above. The keyword arguments
# are kept in their original order so the constructors run in the same sequence.
agent = znrnd.agents.RND(
    # Greedy point selection; the meaning of threshold=0.01 is defined by znrnd.
    point_selector=znrnd.point_selection.GreedySelection(threshold=0.01),
    # Order-2 difference, presumably between target and predictor outputs.
    distance_metric=znrnd.distance_metrics.OrderNDifference(order=2),
    data_generator=data_generator,
    target_network=target,
    predictor_network=predictor,
    # NOTE(review): tolerance=15 semantics are internal to znrnd -- confirm.
    tolerance=15,
)

In [None]:
# Build the RND-selected dataset. target_size=15 presumably caps how many
# CIFAR-10 points are selected (confirm against RND.build_dataset).
# NOTE: slow on CPU -- in the saved run below, single training passes took
# several minutes, and the cell had not finished when the notebook was saved.
target_set = agent.build_dataset(target_size=15, visualize=True)

Epoch: 100: 100%|███████████████████████████████| 100/100 [00:04<00:00, 21.85batch/s, test_loss=161]
Epoch: 110: 100%|█████████████████████████████| 110/110 [00:03<00:00, 28.90batch/s, test_loss=0.304]
Epoch: 121: 100%|██████████████████████████| 121/121 [00:04<00:00, 29.59batch/s, test_loss=0.000346]
Epoch: 100: 100%|███████████████████████████████| 100/100 [00:07<00:00, 14.07batch/s, test_loss=105]
Epoch: 110: 100%|██████████████████████████████| 110/110 [00:07<00:00, 15.22batch/s, test_loss=10.6]
Epoch: 121: 100%|█████████████████████████████| 121/121 [00:08<00:00, 14.82batch/s, test_loss=0.791]
Epoch: 133: 100%|████████████████████████████| 133/133 [00:09<00:00, 14.72batch/s, test_loss=0.0489]
Epoch: 146: 100%|███████████████████████████| 146/146 [00:09<00:00, 15.13batch/s, test_loss=0.00248]
Epoch: 160: 100%|███████████████████████████| 160/160 [00:10<00:00, 15.03batch/s, test_loss=9.84e-5]
Epoch: 100: 100%|██████████████████████████████| 100/100 [00:09<00:00, 10.09batch/s, test_l

Epoch: 449: 100%|███████████████████████████| 449/449 [01:37<00:00,  4.63batch/s, test_loss=0.00114]
Epoch: 493: 100%|██████████████████████████| 493/493 [01:44<00:00,  4.73batch/s, test_loss=0.000156]
Epoch: 100: 100%|██████████████████████████████| 100/100 [00:24<00:00,  4.05batch/s, test_loss=42.3]
Epoch: 110: 100%|██████████████████████████████| 110/110 [00:26<00:00,  4.18batch/s, test_loss=12.5]
Epoch: 121: 100%|██████████████████████████████| 121/121 [00:28<00:00,  4.19batch/s, test_loss=3.41]
Epoch: 133: 100%|██████████████████████████████| 133/133 [00:31<00:00,  4.19batch/s, test_loss=1.01]
Epoch: 146: 100%|█████████████████████████████| 146/146 [00:34<00:00,  4.18batch/s, test_loss=0.327]
Epoch: 160: 100%|█████████████████████████████| 160/160 [00:38<00:00,  4.17batch/s, test_loss=0.145]
Epoch: 176: 100%|████████████████████████████| 176/176 [00:42<00:00,  4.18batch/s, test_loss=0.0499]
Epoch: 193: 100%|████████████████████████████| 193/193 [00:45<00:00,  4.24batch/s, test_los

Epoch: 339: 100%|██████████████████████████████| 339/339 [01:42<00:00,  3.31batch/s, test_loss=1.18]
Epoch: 372: 100%|█████████████████████████████| 372/372 [01:52<00:00,  3.30batch/s, test_loss=0.249]
Epoch: 409: 100%|████████████████████████████| 409/409 [02:07<00:00,  3.22batch/s, test_loss=0.0536]
Epoch: 449: 100%|███████████████████████████| 449/449 [02:19<00:00,  3.22batch/s, test_loss=0.00975]
Epoch: 493: 100%|███████████████████████████| 493/493 [02:34<00:00,  3.20batch/s, test_loss=0.00231]
Epoch: 542: 100%|██████████████████████████| 542/542 [02:47<00:00,  3.23batch/s, test_loss=0.000498]
Epoch: 100: 100%|██████████████████████████████| 100/100 [00:33<00:00,  2.98batch/s, test_loss=29.2]
Epoch: 110: 100%|██████████████████████████████| 110/110 [00:36<00:00,  3.00batch/s, test_loss=15.2]
Epoch: 121: 100%|███████████████████████████████| 121/121 [00:40<00:00,  3.02batch/s, test_loss=7.6]
Epoch: 133: 100%|██████████████████████████████| 133/133 [00:44<00:00,  3.01batch/s, test_l

Epoch: 309: 100%|██████████████████████████████| 309/309 [01:59<00:00,  2.59batch/s, test_loss=15.1]
Epoch: 339: 100%|██████████████████████████████| 339/339 [02:10<00:00,  2.60batch/s, test_loss=4.88]
Epoch: 372: 100%|██████████████████████████████| 372/372 [02:23<00:00,  2.60batch/s, test_loss=1.94]
Epoch: 409: 100%|█████████████████████████████| 409/409 [02:37<00:00,  2.60batch/s, test_loss=0.867]
Epoch: 449: 100%|██████████████████████████████| 449/449 [02:52<00:00,  2.60batch/s, test_loss=0.43]
Epoch: 493: 100%|█████████████████████████████| 493/493 [03:09<00:00,  2.60batch/s, test_loss=0.144]
Epoch: 202:  37%|██████████▍                 | 201/542 [01:19<02:17,  2.48batch/s, test_loss=0.0929]