# MNIST meets RND

In this tutorial, we go over how to apply random network distillation to non-standard network architectures, specifically, the convoltutional neural networks required to classify the MNIST dataset.

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


import znrnd as rnd

import tensorflow_datasets as tfds

import numpy as np
from flax import linen as nn
import optax



### Making a data generator

The first thing we need to do is create a data generator for the problem.

In [2]:
data_generator = rnd.data.MNISTGenerator()

### Define the agent

In [3]:
class CustomModule(nn.Module):
    """
    Simple CNN module.
    """
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        return x

In [4]:
class ProductionModule(nn.Module):
    """
    Simple CNN module.
    """

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)

        return x

In [5]:
production_model = rnd.models.FlaxModel(
            flax_module=ProductionModule(),
            optimizer=optax.adam(learning_rate=0.01),
            loss_fn=rnd.loss_functions.CrossEntropyLoss(classes=10),
            input_shape=(1, 28, 28, 1),
            training_threshold=0.001,
            compute_accuracy=True
        )

In [6]:
production_model.train_model(
    train_ds={"inputs": data_generator.ds_train["image"], "targets": data_generator.ds_train["label"]},
    test_ds={"inputs": data_generator.ds_test["image"], "targets": data_generator.ds_test["label"]},
    batch_size=32,
)

Epoch: 50: 100%|█████████████████████████████████| 50/50 [02:05<00:00,  2.51s/batch, accuracy=0.934]


([1.9081650972366333,
  0.7170539498329163,
  0.5145832896232605,
  0.4604935944080353,
  0.5124824643135071,
  0.3270696699619293,
  0.45221054553985596,
  0.4316408932209015,
  0.3962288200855255,
  0.42191997170448303,
  0.43806907534599304,
  0.43773353099823,
  0.4474071264266968,
  0.44737598299980164,
  0.45669522881507874,
  0.45106184482574463,
  0.45159268379211426,
  0.45254239439964294,
  0.4553583562374115,
  0.4582881033420563,
  0.46119967103004456,
  0.46404439210891724,
  0.4667452573776245,
  0.4694133400917053,
  0.47191354632377625,
  0.47429773211479187,
  0.476602703332901,
  0.47883060574531555,
  0.48097744584083557,
  0.48307761549949646,
  0.48522433638572693,
  0.4872837960720062,
  0.4892633557319641,
  0.4911554157733917,
  0.4930015802383423,
  0.494782030582428,
  0.49652528762817383,
  0.49817031621932983,
  0.49978744983673096,
  0.5013770461082458,
  0.5029512047767639,
  0.5044976472854614,
  0.5060179233551025,
  0.5075149536132812,
  0.5089411139488

In [7]:
target = rnd.models.FlaxModel(
        flax_module=CustomModule(),
        optimizer=optax.adam(learning_rate=0.001),
        loss_fn=rnd.loss_functions.MeanPowerLoss(order=2),
        input_shape=(1, 28, 28, 1),
        training_threshold=0.001
    )

predictor = rnd.models.FlaxModel(
        flax_module=CustomModule(),
        optimizer=optax.adam(learning_rate=0.001),
        loss_fn=rnd.loss_functions.MeanPowerLoss(order=2),
        input_shape=(1, 28, 28, 1),
        training_threshold=0.001
    )

In [8]:
agent = rnd.agents.RND(
        point_selector=rnd.point_selection.GreedySelection(threshold=0.01),
        distance_metric=rnd.distance_metrics.OrderNDifference(order=2),
        data_generator=data_generator,
        target_network=target,
        predictor_network=predictor,
        tolerance=8,
    )

In [None]:
agent.build_dataset(visualize=True)

Epoch: 100: 100%|██████████████████████████████| 100/100 [00:15<00:00,  6.35batch/s, test_loss=1e-7]
Epoch: 100: 100%|██████████████████████████| 100/100 [00:26<00:00,  3.84batch/s, test_loss=2.78e-12]
Epoch: 100: 100%|███████████████████████████| 100/100 [00:32<00:00,  3.12batch/s, test_loss=1.59e-6]
Epoch: 100: 100%|██████████████████████████| 100/100 [00:48<00:00,  2.08batch/s, test_loss=2.23e-15]
Epoch: 100: 100%|██████████████████████████| 100/100 [01:03<00:00,  1.58batch/s, test_loss=9.45e-15]
Epoch: 16:  15%|████▎                        | 15/100 [00:11<00:52,  1.63batch/s, test_loss=4.64e-6]

In [None]:
data_generator.plot_image(data_list=np.array(agent.target_set))