# Simple Random Network Distillation

### Imports
The first step to using ZnRND is to import the package as you would any other Python library.

We will also import some visualization libraries to use along the way. Note that we set `CUDA_VISIBLE_DEVICES` to `-1`, which hides any GPUs so that the example runs on the CPU; the last line of the cell reports which backend JAX is using.

In [None]:
import os
# Hide all GPUs so that JAX runs on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import znrnd

from flax import linen as nn
import optax
import matplotlib.pyplot as plt
import numpy as np
import jax

# Report which backend (e.g. cpu or gpu) JAX is using.
print(f"Using: {jax.default_backend()}")

### Data pools

In ZnRND we use a data generator to store the data and handle passing it around. This makes it easy to operate on the data element-wise and, in principle, to make the whole process compilable.

For this example we use a built-in generator that places points on a lattice.

In [None]:
data_generator = znrnd.data.PointsOnLattice()
data_generator.build_pool(x_points=10, y_points=10)
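The generator's internals are not shown here. As a rough picture of what `build_pool` produces, the sketch below builds a 10-by-10 lattice pool with NumPy. The `LatticePool` class is hypothetical, and the integer coordinates are an assumption; ZnRND's actual spacing and origin may differ.

```python
import numpy as np


class LatticePool:
    """Hypothetical stand-in for a lattice data generator (not the ZnRND class)."""

    def __init__(self):
        self.data_pool = None

    def build_pool(self, x_points: int, y_points: int) -> None:
        # Integer grid coordinates (assumed); the real generator may scale or shift them.
        xs = np.arange(x_points, dtype=float)
        ys = np.arange(y_points, dtype=float)
        grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij")
        # Flatten into an (x_points * y_points, 2) array of (x, y) pairs.
        self.data_pool = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)


pool = LatticePool()
pool.build_pool(x_points=10, y_points=10)
print(pool.data_pool.shape)  # (100, 2)
```

Each row is one candidate point, which is why the plotting cell below can read the x and y coordinates as `data_pool[:, 0]` and `data_pool[:, 1]`.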

Just for fun, let's plot the data to see what we are working with.

In [None]:
x = data_generator.data_pool[:, 0]
y = data_generator.data_pool[:, 1]

plt.plot(x, y, '.')
plt.show()

So we have a square lattice of points on which to perform distillation. What should distillation achieve on a grid like this? It should pick points separated by N lattice spacings, where N depends on the distance metric we use.

### Defining the models

Now we build the target and predictor models using ZnRND's built-in `FlaxModel` class.

In [None]:
# Dense(12) -> sigmoid -> Dense(12) -> sigmoid -> Dense(12), shared by both models below.
layer_stack = [
    nn.Dense(features=12),
    nn.sigmoid,
    nn.Dense(features=12),
    nn.sigmoid,
    nn.Dense(features=12),
]
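To make the layer stack concrete, here is a NumPy sketch of the forward pass it describes: Dense, sigmoid, Dense, sigmoid, Dense, mapping 2-D points to 12 outputs. The initialisation (normal kernels scaled by `1/sqrt(fan_in)`, zero biases) only loosely mimics Flax's LeCun-normal default, and `init_stack` and `apply_stack` are hypothetical helpers, not ZnRND or Flax APIs.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def init_stack(seed, sizes=(2, 12, 12, 12)):
    """Random (kernel, bias) pairs for each Dense layer; loosely Flax-like init."""
    rng = np.random.default_rng(seed)
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        kernel = rng.normal(scale=1.0 / np.sqrt(fan_in), size=(fan_in, fan_out))
        bias = np.zeros(fan_out)
        params.append((kernel, bias))
    return params


def apply_stack(params, x):
    """Dense -> sigmoid -> Dense -> sigmoid -> Dense, mirroring layer_stack."""
    for i, (kernel, bias) in enumerate(params):
        x = x @ kernel + bias
        if i < len(params) - 1:  # no activation after the final Dense layer
            x = sigmoid(x)
    return x


params = init_stack(seed=0)
points = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]])
out = apply_stack(params, points)
print(out.shape)  # (3, 12)
```

Two different seeds give two different networks; RND relies on a target and a predictor that start out disagreeing, although how ZnRND seeds its two models is not shown here.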

In [None]:
target = znrnd.models.FlaxModel(
    layer_stack=layer_stack,
    optimizer=optax.adam(learning_rate=0.001),
    loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
    input_shape=(2,),
    training_threshold=0.001,
)
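Both models are given `MeanPowerLoss(order=2)`. Assuming it computes the mean of `|prediction - target| ** order` (for `order=2` that is the mean squared error), a minimal NumPy version looks like this. The name `mean_power_loss` is hypothetical and ZnRND's exact reduction over points and features may differ; presumably `training_threshold=0.001` is the loss value below which training counts as converged, but that is also an assumption.

```python
import numpy as np


def mean_power_loss(prediction, target, order=2):
    """Assumed form: mean over all entries of |prediction - target| ** order."""
    diff = np.abs(np.asarray(prediction) - np.asarray(target))
    return float(np.mean(diff ** order))


pred = np.array([[1.0, 2.0], [3.0, 4.0]])
true = np.array([[1.0, 0.0], [0.0, 4.0]])
print(mean_power_loss(pred, true, order=2))  # 3.25
```

Here the absolute differences are 0, 2, 3 and 0, so the squared errors average to 13 / 4 = 3.25.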

In [None]:
predictor = znrnd.models.FlaxModel(
    layer_stack=layer_stack,
    optimizer=optax.adam(learning_rate=0.001),
    loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
    input_shape=(2,),
    training_threshold=0.001,
)

### Defining the agent

Finally, we can define the actual RND agent, which performs the data selection procedure.

In [None]:
agent = znrnd.agents.RND(
    point_selector=znrnd.point_selection.GreedySelection(threshold=0.01),
    distance_metric=znrnd.distance_metrics.OrderNDifference(order=2),
    data_generator=data_generator,
    target_network=target,
    predictor_network=predictor,
    tolerance=8,
)
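The agent combines two pieces: a distance metric that scores how badly the predictor reproduces the target on each point, and a point selector that chooses from those scores. The sketch below assumes that `OrderNDifference(order=2)` sums squared differences over the output features and that `GreedySelection(threshold=0.01)` takes the highest-scoring point unless its score falls below the threshold. Both helpers are hypothetical stand-ins, not ZnRND's code.

```python
import numpy as np


def order_n_difference(a, b, order=2):
    """Assumed per-point distance: sum over output features of |a - b| ** order."""
    return np.sum(np.abs(np.asarray(a) - np.asarray(b)) ** order, axis=-1)


def greedy_select(scores, threshold):
    """Index of the highest-scoring point, or None if no score reaches the threshold."""
    best = int(np.argmax(scores))
    if scores[best] < threshold:
        return None
    return best


# Target and predictor outputs for three hypothetical points (2 features each).
target_out = np.array([[0.0, 1.0], [2.0, 2.0], [0.5, 0.5]])
predictor_out = np.array([[0.0, 1.1], [0.0, 0.0], [0.5, 0.4]])

scores = order_n_difference(target_out, predictor_out, order=2)
print(greedy_select(scores, threshold=0.01))  # 1
```

Point 1 is where the predictor disagrees most with the target (score 8), so it is the "most novel" point and gets selected.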

### Running RND

With everything set up, running the RND procedure is now very simple.

In [None]:
agent.build_dataset(target_size=5, visualize=True)
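Conceptually, RND-based selection repeats a score, select, retrain cycle until the target set is full. The following self-contained NumPy sketch shows that cycle under simplifying assumptions: the target is a fixed random MLP, the "predictor" is a ridge regression on fixed random sigmoid features rather than a trained Flax network, and already-selected points are masked out. `rnd_select` and all of its parameters are hypothetical; this illustrates the idea rather than reproducing ZnRND's `build_dataset` or its `tolerance` logic.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def make_target(seed, sizes=(2, 12, 12, 12)):
    """Fixed, randomly initialised MLP standing in for the RND target network."""
    rng = np.random.default_rng(seed)
    layers = [
        (rng.normal(scale=1.0 / np.sqrt(n_in), size=(n_in, n_out)), rng.normal(size=n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]

    def apply(x):
        for i, (kernel, bias) in enumerate(layers):
            x = x @ kernel + bias
            if i < len(layers) - 1:  # sigmoid between layers, none after the last
                x = sigmoid(x)
        return x

    return apply


def rnd_select(pool, target_size, threshold=0.01, n_features=64, ridge=1e-3, seed=0):
    """Greedily add the pool point that the predictor reproduces worst."""
    rng = np.random.default_rng(seed + 1)
    target_out = make_target(seed)(pool)  # (n_points, 12); never trained
    # Fixed random sigmoid features; the "predictor" is a linear map on top of them.
    proj = rng.normal(size=(pool.shape[1], n_features))
    feats = sigmoid(pool @ proj + rng.normal(size=n_features))

    selected = []
    while len(selected) < target_size:
        if selected:
            # "Retrain" the predictor on the selected points via ridge regression.
            f, y = feats[selected], target_out[selected]
            weights = np.linalg.solve(f.T @ f + ridge * np.eye(n_features), f.T @ y)
            pred_out = feats @ weights
        else:
            pred_out = np.zeros_like(target_out)  # untrained predictor outputs zeros
        # Novelty: squared (order-2) difference between target and predictor outputs.
        scores = np.sum((target_out - pred_out) ** 2, axis=1)
        if selected:
            scores[selected] = -np.inf  # never select the same point twice
        best = int(np.argmax(scores))
        if scores[best] < threshold:
            break  # nothing left that is novel enough
        selected.append(best)
    return selected


# Usage on a 10 x 10 integer lattice (coordinates are an assumption).
xs, ys = np.meshgrid(np.arange(10.0), np.arange(10.0), indexing="ij")
pool = np.stack([xs.ravel(), ys.ravel()], axis=1)
chosen = rnd_select(pool, target_size=5)
print(pool[chosen])
```

Because the predictor is fitted only on the selected points, its error tends to be smallest near them, so each new pick tends to land away from the points already chosen.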

### Visualizing the results

So let's see how well the method performed.

In [None]:
target_set = np.array(agent.target_set)

x = data_generator.data_pool[:, 0]
y = data_generator.data_pool[:, 1]

x_selected = target_set[:, 0]
y_selected = target_set[:, 1]

plt.plot(x, y, '.', label='Input data')
plt.plot(x_selected, y_selected, 'x', label='Selected Points')
plt.legend(bbox_to_anchor=(1, 1))
plt.show()

So we can see that, with these parameters, the RND method selects points that are far apart from one another on the lattice.
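One way to put a number on how well separated the selected points are is the smallest pairwise distance within the selected set. A minimal sketch (the helper name is hypothetical):

```python
import numpy as np


def min_pairwise_distance(points):
    """Smallest Euclidean distance between any two distinct points."""
    points = np.asarray(points, dtype=float)
    diffs = points[:, None, :] - points[None, :, :]
    dists = np.sqrt(np.sum(diffs ** 2, axis=-1))
    # Only look above the diagonal, skipping the zero self-distances.
    upper = np.triu_indices(len(points), k=1)
    return float(dists[upper].min())


corners = np.array([[0.0, 0.0], [0.0, 9.0], [9.0, 0.0], [9.0, 9.0]])
print(min_pairwise_distance(corners))  # 9.0
```

Applied to `target_set` from the visualization cell, `min_pairwise_distance(target_set)` gives the closest spacing among the selected points, in lattice spacings if the spacing is 1. It needs at least two selected points.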