# MNIST meets RND

In this tutorial, we go over how to apply random network distillation (RND) to non-standard network architectures, specifically the convolutional neural networks required to classify the MNIST dataset.

In [None]:
import os
# Set CUDA_VISIBLE_DEVICES before the numerical libraries below are imported.
# NOTE(review): "-1" presumably hides all CUDA devices so the notebook runs on
# CPU; the backend actually in use is printed after the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


import znrnd as rnd

import tensorflow_datasets as tfds

import numpy as np
from flax import linen as nn
import optax

import jax

# Report which platform JAX will run on. ``jax.default_backend()`` is the
# public API for this; it returns the platform name of the default backend as
# a string, replacing the deprecated internal
# ``jax.lib.xla_bridge.get_backend().platform`` access path.
print(f"Using: {jax.default_backend()}")

### Making a data generator

The first thing we need to do is create a data generator for the problem.

In [None]:
# MNIST data source for the rest of the notebook: its ``train_ds`` / ``test_ds``
# entries ("inputs" and "targets") feed the classifier training below, and the
# same generator object is handed to the RND agent and used for plotting.
data_generator = rnd.data.MNISTGenerator()

### Define the agent

In [None]:
class CustomModule(nn.Module):
    """
    Small convolutional network used for the RND target and predictor.

    The input passes through two stages of 3x3 convolution (32, then 64
    features), ReLU and 2x2 average pooling with stride 2. The result is
    flattened per sample (axis 0 is kept as the batch axis) and projected by
    a dense layer to 256 outputs, with no activation on the output.
    """

    @nn.compact
    def __call__(self, x):
        # Both stages are identical apart from the number of conv features.
        for n_features in (32, 64):
            x = nn.Conv(features=n_features, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Collapse everything except the leading (batch) axis.
        flat = x.reshape((x.shape[0], -1))
        return nn.Dense(features=256)(flat)

In [None]:
class ProductionModule(nn.Module):
    """
    CNN classifier module with a 10-way output.

    The input passes through two stages of 3x3 convolution (128 features
    each), ReLU and 3x3 max pooling with stride 2. The result is flattened
    per sample (axis 0 is kept as the batch axis), passed through a
    300-unit dense layer with ReLU, and finally projected to 10 outputs.
    No activation is applied to the final layer.
    """

    @nn.compact
    def __call__(self, x):
        # Two identical convolution / activation / pooling stages.
        for _ in range(2):
            x = nn.Conv(features=128, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))

        # Collapse everything except the leading (batch) axis.
        x = x.reshape((x.shape[0], -1))

        hidden = nn.relu(nn.Dense(features=300)(x))
        return nn.Dense(10)(hidden)

In [None]:
# Wrap the classifier module in a trainable model: Adam optimizer,
# cross-entropy loss, and label accuracy as the reported metric.
production_model = rnd.models.FlaxModel(
    flax_module=ProductionModule(),
    optimizer=optax.adam(learning_rate=0.01),
    loss_fn=rnd.loss_functions.CrossEntropyLoss(),
    # A single 28x28 single-channel image with a leading batch axis.
    input_shape=(1, 28, 28, 1),
    training_threshold=0.001,
    accuracy_fn=rnd.accuracy_functions.LabelAccuracy(),
)

In [None]:
# Train the classifier on the MNIST splits. Only the "inputs" and "targets"
# entries of each split are passed on; the return value is discarded.
_ = production_model.train_model(
    train_ds={key: data_generator.train_ds[key] for key in ("inputs", "targets")},
    test_ds={key: data_generator.test_ds[key] for key in ("inputs", "targets")},
    batch_size=32,
)

In [None]:
# The RND target and predictor use the same architecture and hyperparameters,
# so both are produced by the same constructor call. The generator builds two
# separate model objects (each with its own module and optimizer instance),
# first the target and then the predictor.
target, predictor = (
    rnd.models.FlaxModel(
        flax_module=CustomModule(),
        optimizer=optax.adam(learning_rate=0.001),
        loss_fn=rnd.loss_functions.MeanPowerLoss(order=2),
        input_shape=(1, 28, 28, 1),
        training_threshold=0.001,
    )
    for _ in range(2)
)

In [None]:
# Assemble the RND agent from the pieces defined above: the MNIST data
# generator, the target / predictor pair, an order-2 difference as the
# distance metric, and a greedy point-selection strategy.
agent = rnd.agents.RND(
    point_selector=rnd.point_selection.GreedySelection(threshold=0.01),
    distance_metric=rnd.distance_metrics.OrderNDifference(order=2),
    data_generator=data_generator,
    target_network=target,
    predictor_network=predictor,
    tolerance=8,
)

In [None]:
# Run the agent to build a dataset with a requested size of 5 and keep the
# result in ``ds``; visualisation is switched on for this run.
ds = agent.build_dataset(target_size=5, visualize=True)

In [None]:
# Plot the entries of the agent's target set as images, converting the set
# to a NumPy array first.
data_generator.plot_image(data_list=np.array(agent.target_set))