# Custom Data Generator

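This notebook demonstrates how to create a custom data generator for ZnRND. We wrap the Auto MPG dataset in a `DataGenerator` subclass, train a model on the full training split, and then use RND to select a small subset of the training data and train on that subset for comparison.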

In [1]:
import pandas as pd
import numpy as np
import znrnd

import optax
from neural_tangents import stax

from jax.lib import xla_bridge

# Report which platform JAX will execute on. ``jax.default_backend()`` is the
# public API for this query; ``jax.lib.xla_bridge.get_backend()`` is a private
# helper that is deprecated in recent JAX releases.
import jax

print(f"Using: {jax.default_backend()}")

Using: gpu


### Download the dataset

In [2]:
# Fetch the Auto MPG table over HTTPS so the download is integrity-protected
# in transit (the original plain-HTTP URL is not).
url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'
column_names = ['MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',
                'Acceleration', 'Model Year', 'Origin']

# Parsing notes:
#   * names=column_names -> the column headers are supplied here rather than read from the file
#   * na_values='?'      -> '?' entries are read as missing values
#   * comment='\t'       -> everything after a tab on a line is ignored
#   * sep=' ' + skipinitialspace=True -> space-delimited, padding spaces skipped
raw_dataset = pd.read_csv(url, names=column_names,
                          na_values='?', comment='\t',
                          sep=' ', skipinitialspace=True)

### Pre-process the data

In [3]:
# Readable labels for the numeric origin codes.
origin_names = {1: 'USA', 2: 'Europe', 3: 'Japan'}

# Drop rows with missing values, relabel the origin column and one-hot encode
# it (empty prefix, so the new columns are named after the labels directly).
dataset = pd.get_dummies(
    raw_dataset.dropna().assign(
        Origin=lambda frame: frame['Origin'].map(origin_names)
    ),
    columns=['Origin'],
    prefix='',
    prefix_sep='',
)

# Standardise every column (including the MPG target and the one-hot columns)
# to zero mean and unit standard deviation.
column_means = dataset.mean()
column_stds = dataset.std()
dataset = (dataset - column_means) / column_stds

### Create the data generator

In [4]:
class MPGDataGenerator(znrnd.data.DataGenerator):
    """
    Data generator for the MPG dataset.

    Attributes
    ----------
    train_ds : dict
            Training split, {"inputs": array, "targets": array of shape (n, 1)}.
    test_ds : dict
            Test split with the same layout as ``train_ds``.
    data_pool : np.ndarray
            The training inputs (same array as ``train_ds["inputs"]``).
            Presumably the pool from which an RND agent selects points --
            confirm against the znrnd DataGenerator contract.
    """
    def __init__(self, dataset: pd.DataFrame, train_fraction: float = 0.8, seed: int = 0):
        """
        Constructor for the data generator.

        Parameters
        ----------
        dataset : pd.DataFrame
                Table holding an "MPG" column, which is used as the regression
                target; all remaining columns are used as input features.
                The frame itself is not modified.
        train_fraction : float (default = 0.8)
                Fraction of the rows that is randomly sampled into the
                training split. The rows that are not sampled form the test
                split.
        seed : int (default = 0)
                Random state handed to the row sampler so that the split is
                reproducible.
        """
        # Sample the training rows; every row whose index was not drawn goes
        # into the test split, so the two splits are disjoint.
        train_ds = dataset.sample(frac=train_fraction, random_state=seed)
        test_ds = dataset.drop(train_ds.index)

        # ``pop`` removes the target column from the feature frames in place.
        # Both frames are fresh objects, so ``dataset`` stays untouched.
        train_labels = train_ds.pop("MPG")
        test_labels = test_ds.pop("MPG")

        # Targets are reshaped to column vectors, one target value per row.
        self.train_ds = {
            "inputs": train_ds.to_numpy(),
            "targets": train_labels.to_numpy().reshape(-1, 1),
        }
        self.test_ds = {
            "inputs": test_ds.to_numpy(),
            "targets": test_labels.to_numpy().reshape(-1, 1),
        }

        self.data_pool = self.train_ds["inputs"]
        

In [5]:
# Wrap the standardised table in the custom generator, which performs the
# train / test split on construction.
data_generator = MPGDataGenerator(dataset=dataset)

### Create a model

In [6]:
# Three hidden blocks (width-32 dense layer followed by a ReLU), then a single
# linear output unit for the scalar regression target.
hidden_layers = []
for _ in range(3):
    hidden_layers.extend([stax.Dense(32), stax.Relu()])

model = stax.serial(*hidden_layers, stax.Dense(1))

In [7]:
# Optimiser and loss for the model trained on the full training split.
full_data_optimizer = optax.adam(learning_rate=0.001)
full_data_loss = znrnd.loss_functions.LPNormLoss(order=2)

# input_shape=(9,): the eight named columns minus the "MPG" target, with
# "Origin" replaced by three one-hot columns, leave nine input features.
ntk_network = znrnd.models.NTModel(
    nt_module=model,
    optimizer=full_data_optimizer,
    loss_fn=full_data_loss,
    input_shape=(9,),
    training_threshold=0.001,
)

In [8]:
# Baseline run: train on the complete training split and evaluate on the
# held-out test split.
metrics = ntk_network.train_model(
    train_ds=data_generator.train_ds,
    test_ds=data_generator.test_ds,
    epochs=100,
    batch_size=32,
)

Epoch: 100: 100%|█████████████████████████████| 100/100 [00:50<00:00,  1.96batch/s, test_loss=0.241]


### Perform some data selection

In [9]:
# Architecture shared by the RND target and predictor networks: two
# dense + ReLU blocks followed by a final width-32 dense layer.
rnd_layers = [
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
]
rnd_stack = stax.serial(*rnd_layers)

In [10]:
def _build_rnd_network():
    """
    Build one NTModel on the shared RND architecture.

    Every call creates a fresh optimizer and loss object, so the returned
    models do not share those instances.
    """
    return znrnd.models.NTModel(
        nt_module=rnd_stack,
        optimizer=optax.adam(learning_rate=0.001),
        loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
        input_shape=(9,),
        training_threshold=0.001,
    )


# The target and predictor networks are configured identically.
target = _build_rnd_network()
predictor = _build_rnd_network()

In [11]:
# Components used by the agent to compare network outputs and pick points.
selector = znrnd.point_selection.GreedySelection(threshold=0.01)
metric = znrnd.distance_metrics.OrderNDifference(order=2)

# The agent draws its candidate points from the custom data generator.
agent = znrnd.agents.RND(
    point_selector=selector,
    distance_metric=metric,
    data_generator=data_generator,
    target_network=target,
    predictor_network=predictor,
    tolerance=8,
)

In [12]:
# Presumably the number of points the agent should select -- confirm against
# the RND.build_dataset documentation.
requested_size = 10
ds = agent.build_dataset(requested_size)

Epoch: 100: 100%|████████████████████████████| 100/100 [00:01<00:00, 69.20batch/s, test_loss=0.0385]
Epoch: 110: 100%|██████████████████████████| 110/110 [00:00<00:00, 113.86batch/s, test_loss=0.00428]
Epoch: 121: 100%|██████████████████████████| 121/121 [00:00<00:00, 121.07batch/s, test_loss=9.24e-5]
Epoch: 100: 100%|█████████████████████████████| 100/100 [00:02<00:00, 38.10batch/s, test_loss=0.582]
Epoch: 110: 100%|█████████████████████████████| 110/110 [00:02<00:00, 47.04batch/s, test_loss=0.307]
Epoch: 121: 100%|████████████████████████████| 121/121 [00:07<00:00, 17.18batch/s, test_loss=0.0874]
Epoch: 133: 100%|████████████████████████████| 133/133 [00:14<00:00,  8.91batch/s, test_loss=0.0121]
Epoch: 146: 100%|██████████████████████████| 146/146 [00:15<00:00,  9.29batch/s, test_loss=0.000884]
Epoch: 100: 100%|████████████████████████████| 100/100 [00:16<00:00,  5.92batch/s, test_loss=0.0301]
Epoch: 110: 100%|███████████████████████████| 110/110 [00:18<00:00,  6.11batch/s, test_loss

In [13]:
# Restrict the training split to the rows chosen by the agent; inputs and
# targets are indexed with the same row indices so they stay aligned.
selected_rows = agent.target_indices
train_ds = {
    key: np.take(data_generator.train_ds[key], selected_rows, axis=0)
    for key in ("inputs", "targets")
}

In [14]:
# Fresh model for the RND-selected subset (rebinds ``ntk_network``).
# NOTE(review): the learning rate here is 0.1, whereas the model trained on
# the full training split used 0.001 -- confirm the difference is intentional
# before comparing the two test losses.
subset_optimizer = optax.adam(learning_rate=0.1)
subset_loss = znrnd.loss_functions.LPNormLoss(order=2)

ntk_network = znrnd.models.NTModel(
    nt_module=model,
    optimizer=subset_optimizer,
    loss_fn=subset_loss,
    input_shape=(9,),
    training_threshold=0.001,
)

In [15]:
# Train on the selected subset only and evaluate on the same held-out test
# split as before. Epochs and batch size are left at the library defaults.
rnd_metrics = ntk_network.train_model(
    train_ds=train_ds,
    test_ds=data_generator.test_ds,
)

Epoch: 50: 100%|████████████████████████████████| 50/50 [00:11<00:00,  4.30batch/s, test_loss=0.535]
