# Custom Data Generator

Here we will demonstrate how to create a custom data generator.

In [3]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import pandas as pd
import numpy as np
import znrnd

import optax
from neural_tangents import stax

from jax.lib import xla_bridge

print(f"Using: {xla_bridge.get_backend().platform}")

Using: cpu


### Download the dataset

In [4]:
url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'
column_names = ['MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',
                'Acceleration', 'Model Year', 'Origin']

raw_dataset = pd.read_csv(url, names=column_names,
                          na_values='?', comment='\t',
                          sep=' ', skipinitialspace=True)

### Post-process the data here

In [5]:
dataset = raw_dataset.copy()
dataset = dataset.dropna()
dataset['Origin'] = dataset['Origin'].map({1: 'USA', 2: 'Europe', 3: 'Japan'})
dataset = pd.get_dummies(dataset, columns=['Origin'], prefix='', prefix_sep='')


dataset = (dataset-dataset.mean())/dataset.std()

### Create the data generator

In [6]:
class MPGDataGenerator(znrnd.data.DataGenerator):
    """
    Data generator for the MPG dataset.
    """
    def __init__(self, dataset: pd.DataFrame):
        """
        Constructor for the data generator.
        
        Parameters
        ----------
        dataset
        """        
        train_ds = dataset.sample(frac=0.8, random_state=0)
        train_labels = train_ds.pop("MPG")
        test_ds = dataset.drop(train_ds.index)
        test_labels = test_ds.pop("MPG")
        
        self.train_ds = {"inputs": train_ds.to_numpy(), "targets": train_labels.to_numpy().reshape(-1, 1)}
        self.test_ds = {"inputs": test_ds.to_numpy(), "targets": test_labels.to_numpy().reshape(-1, 1)}
        
        self.data_pool = self.train_ds["inputs"]
        

In [7]:
data_generator = MPGDataGenerator(dataset)

### Create a model

In [8]:
model = stax.serial(
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(1),
)

In [9]:
ntk_network = znrnd.models.NTModel(
            nt_module=model,
            optimizer=optax.adam(learning_rate=0.001),
            input_shape=(9,),
        )

### Create a training strategy

In [10]:
training_strategy = znrnd.training_strategies.SimpleTraining(
    model=ntk_network, 
    loss_fn=znrnd.loss_functions.LPNormLoss(order=2),
    accuracy_fn=znrnd.accuracy_functions.LabelAccuracy(),
)

### Train the model

In [11]:
metrics = training_strategy.train_model(
    train_ds=data_generator.train_ds, test_ds=data_generator.test_ds, epochs=100, batch_size=32
)

Epoch: 100: 100%|██████████████████████████████████| 100/100 [00:03<00:00, 28.68batch/s, accuracy=1]


### Perform some data selection

In [12]:
rnd_stack = stax.serial(
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
)

In [13]:
target = znrnd.models.NTModel(
        nt_module=rnd_stack,
        optimizer=optax.adam(learning_rate=0.001),
        input_shape=(9,),
    )

predictor = znrnd.models.NTModel(
        nt_module=rnd_stack,
        optimizer=optax.adam(learning_rate=0.001),
        input_shape=(9,),
    )

In [14]:
training_strategy = znrnd.training_strategies.SimpleTraining(
    model=predictor, 
    loss_fn=znrnd.loss_functions.LPNormLoss(order=2),
)

In [15]:
agent = znrnd.agents.RND(
    training_strategy=training_strategy, 
    point_selector=znrnd.point_selection.GreedySelection(threshold=0.01),
    distance_metric=znrnd.distance_metrics.OrderNDifference(order=2),
    data_generator=data_generator,
    target_network=target,
    predictor_network=predictor,
    tolerance=8,
    )

In [16]:
ds = agent.build_dataset(10)

Epoch: 50: 100%|████████████████████████████████| 50/50 [00:00<00:00, 161.81batch/s, test_loss=6.28]
Epoch: 50: 100%|█████████████████████████████████| 50/50 [00:00<00:00, 101.65batch/s, test_loss=4.7]
Epoch: 50: 100%|█████████████████████████████████| 50/50 [00:00<00:00, 81.74batch/s, test_loss=3.15]
Epoch: 50: 100%|██████████████████████████████████| 50/50 [00:00<00:00, 65.09batch/s, test_loss=1.7]
Epoch: 50: 100%|█████████████████████████████████| 50/50 [00:00<00:00, 50.88batch/s, test_loss=0.96]
Epoch: 50: 100%|████████████████████████████████| 50/50 [00:01<00:00, 46.68batch/s, test_loss=0.668]
Epoch: 50: 100%|█████████████████████████████████| 50/50 [00:01<00:00, 40.07batch/s, test_loss=0.68]
Epoch: 50: 100%|████████████████████████████████| 50/50 [00:01<00:00, 36.39batch/s, test_loss=0.699]
Epoch: 50: 100%|████████████████████████████████| 50/50 [00:01<00:00, 30.94batch/s, test_loss=0.717]
Epoch: 50: 100%|████████████████████████████████| 50/50 [00:01<00:00, 29.63batch/s, test_lo

In [17]:
train_ds = {
    "inputs": np.take(data_generator.train_ds["inputs"], agent.target_indices, axis=0),
    "targets": np.take(data_generator.train_ds["targets"], agent.target_indices, axis=0)
}

In [19]:
production_model = znrnd.models.NTModel(
            nt_module=model,
            optimizer=optax.adam(learning_rate=0.1),
            input_shape=(9,),
        )

production_training = znrnd.training_strategies.SimpleTraining(
    model=production_model, 
    loss_fn=znrnd.loss_functions.LPNormLoss(order=2),
)

In [20]:
rnd_metrics = production_training.train_model(train_ds, test_ds=data_generator.test_ds)

Epoch: 50: 100%|████████████████████████████████| 50/50 [00:01<00:00, 33.09batch/s, test_loss=0.526]
