# MNIST meets RND

In this tutorial, we go over how to apply random network distillation to non-standard network architectures, specifically, the convolutional neural networks commonly used to classify the MNIST dataset.

In [44]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


import znrnd as rnd

import tensorflow_datasets as tfds

import numpy as np

import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

### Making a data generator

The first thing we need to do is create a data generator for the problem. It loads a subset of MNIST, flattens and standardizes the images, and one-hot encodes the labels.

In [100]:
class MNISTGenerator(rnd.data.DataGenerator):
    """
    Data generator for a subset of the MNIST dataset.

    The first ``ds_size`` samples of the train and test splits are loaded via
    ``tensorflow_datasets``. The training images are flattened, standardized
    and stored in ``data_pool``.

    Attributes
    ----------
    ds_train : dict
        Raw training split as returned by ``tfds.as_numpy`` (read here through
        the keys 'image' and 'label').
    ds_test : dict
        Raw test split, same layout as ``ds_train``.
    data_pool : np.ndarray
        Flattened, standardized training images, one row per sample.
    """

    def __init__(self, ds_size: int = 500):
        """
        Constructor for the MNIST generator class.

        Parameters
        ----------
        ds_size : int (default=500)
            Number of samples taken from the start of both the train and the
            test split.
        """
        self.ds_train, self.ds_test = tfds.as_numpy(
            tfds.load(
                'mnist:3.*.*',
                split=['train[:%d]' % ds_size, 'test[:%d]' % ds_size],
                batch_size=-1,  # load each split as a single full batch
            )
        )
        self.data_pool = self._process_data(self.ds_train)['image']

    def _process_data(self, data_chunk):
        """
        Flatten and standardize the images and one-hot encode the labels.

        Parameters
        ----------
        data_chunk : dict
            Mapping with an 'image' array whose first axis indexes samples and
            an integer 'label' array.

        Returns
        -------
        dict
            'image': float32 array of shape (samples, pixels_per_sample),
            shifted and scaled by the mean / standard deviation of the whole
            chunk.
            'label': one-hot array of shape (samples, 10).
        """
        image, label = data_chunk['image'], data_chunk['label']

        samples = image.shape[0]
        image = np.reshape(image, (samples, -1)).astype(np.float32)
        # Statistics are pooled over all pixels of all samples in the chunk.
        image = (image - np.mean(image)) / np.std(image)
        label = np.eye(10)[label]

        return {'image': image, 'label': label}

    def plot_image(self, indices: list, max_columns: int = 4):
        """
        Plot the requested training images in a grid of heatmaps.

        Parameters
        ----------
        indices : list
            Indices into the training split; one subplot is drawn per index,
            in the given order.
        max_columns : int (default=4)
            Maximum number of subplots per row.

        Raises
        ------
        ValueError
            If ``indices`` is empty or ``max_columns`` is not positive.
        """
        if len(indices) == 0:
            raise ValueError("At least one index is required to plot.")
        if max_columns < 1:
            raise ValueError("max_columns must be a positive integer.")

        columns = min(len(indices), max_columns)
        rows = int(np.ceil(len(indices) / columns))

        fig = make_subplots(rows=rows, cols=columns)

        # Fill the grid row-major with exactly the requested images.
        for position, index in enumerate(indices):
            data = self.ds_train["image"][index].reshape(28, 28)
            fig.add_trace(
                # One shared-looking grid; per-trace colorbars would overlap.
                go.Heatmap(z=data, showscale=False),
                row=position // columns + 1,
                col=position % columns + 1,
            )

        # Heatmaps put row 0 at the bottom; reverse y so digits are upright.
        fig.update_yaxes(autorange="reversed")
        fig.show()


In [101]:
# Build the generator: loads the MNIST train/test subsets and prepares the
# normalized training data pool.
data_generator = MNISTGenerator()

In [104]:
# Request the training images with indices 1 through 13.
data_generator.plot_image(list(range(1, 14)))