# MNIST meets RND

In this tutorial, we go over how to apply random network distillation to non-standard network architectures, specifically, the convoltutional neural networks required to classify the MNIST dataset.

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


import znrnd as rnd

import tensorflow_datasets as tfds

import numpy as np
from flax import linen as nn
import optax
from plotly.subplots import make_subplots
import plotly.graph_objects as go



### Making a data generator

The first thing we need to do is create a data generator for the problem.

In [2]:
data_generator = rnd.data.MNISTGenerator()

### Define the agent

In [3]:
class CustomModule(nn.Module):
    """
    Simple CNN module.
    """
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        return x

In [4]:
class ProductionModule(nn.Module):
    """
    Simple CNN module.
    """

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)
        x = nn.log_softmax(x)

        return x

In [5]:
production_model = rnd.models.FlaxModel(
            flax_module=ProductionModule(),
            optimizer=optax.adam(learning_rate=0.1),
            loss_fn=rnd.loss_functions.CrossEntropyLoss(classes=10),
            input_shape=(1, 28, 28, 1),
            training_threshold=0.001,
            compute_accuracy=True
        )

In [6]:
production_model.train_model(
    train_ds={"inputs": data_generator.ds_train["image"], "targets": data_generator.ds_train["label"]},
    test_ds={"inputs": data_generator.ds_test["image"], "targets": data_generator.ds_test["label"]},
    batch_size=32,
)

Epoch: 50: 100%|███████████████████████████████████| 50/50 [02:00<00:00,  2.40s/batch, accuracy=0.1]


([2.3014981746673584,
  2.3084380626678467,
  2.309455633163452,
  2.3112237453460693,
  2.3098599910736084,
  2.3113555908203125,
  2.31083607673645,
  2.311269760131836,
  2.3115267753601074,
  2.3116278648376465,
  2.3117856979370117,
  2.311898946762085,
  2.312002420425415,
  2.3120880126953125,
  2.3121700286865234,
  2.3123295307159424,
  2.312422513961792,
  2.3124547004699707,
  2.3125,
  2.312549352645874,
  2.3125903606414795,
  2.312685251235962,
  2.313063144683838,
  2.3130977153778076,
  2.3130006790161133,
  2.3130128383636475,
  2.3130147457122803,
  2.313001871109009,
  2.312995433807373,
  2.312990427017212,
  2.312983989715576,
  2.312978744506836,
  2.3129732608795166,
  2.3129677772521973,
  2.3129630088806152,
  2.3129587173461914,
  2.3129541873931885,
  2.3129498958587646,
  2.31294584274292,
  2.3129420280456543,
  2.312938690185547,
  2.3129348754882812,
  2.312931776046753,
  2.3129284381866455,
  2.312925100326538,
  2.3129220008850098,
  2.3129191398620605

In [7]:
target = rnd.models.FlaxModel(
        flax_module=CustomModule(),
        optimizer=optax.adam(learning_rate=0.001),
        loss_fn=rnd.loss_functions.MeanPowerLoss(order=2),
        input_shape=(1, 28, 28, 1),
        training_threshold=0.001
    )

predictor = rnd.models.FlaxModel(
        flax_module=CustomModule(),
        optimizer=optax.adam(learning_rate=0.001),
        loss_fn=rnd.loss_functions.MeanPowerLoss(order=2),
        input_shape=(1, 28, 28, 1),
        training_threshold=0.001
    )

In [8]:
agent = rnd.agents.RND(
        point_selector=rnd.point_selection.GreedySelection(threshold=0.01),
        distance_metric=rnd.distance_metrics.OrderNDifference(order=2),
        data_generator=data_generator,
        target_network=target,
        predictor_network=predictor,
        tolerance=8,
    )

In [None]:
agent.build_dataset(visualize=True)

Epoch: 100: 100%|████| 100/100 [00:14<00:00,  6.95batch/s, test_loss={'loss': 6.73146729468499e-08}]
Epoch: 100: 100%|███| 100/100 [00:28<00:00,  3.48batch/s, test_loss={'loss': 7.361788334492303e-08}]
Epoch: 100: 100%|██| 100/100 [00:43<00:00,  2.30batch/s, test_loss={'loss': 5.7011699743471974e-15}]
Epoch: 100: 100%|███| 100/100 [00:54<00:00,  1.85batch/s, test_loss={'loss': 5.047224540166619e-15}]
Epoch: 100: 100%|██| 100/100 [00:52<00:00,  1.90batch/s, test_loss={'loss': 3.7239574827139556e-15}]
Epoch: 100: 100%|██| 100/100 [01:16<00:00,  1.31batch/s, test_loss={'loss': 3.1359655849882984e-07}]
Epoch: 86:  85%|████▎| 85/100 [01:12<00:12,  1.17batch/s, test_loss={'loss': 5.104852096593504e-09}]

In [None]:
data_generator.plot_image(data_list=np.array(agent.target_set))