# Custom Data Generator

Here we will demonstrate how to create a custom data generator.

In [6]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import pandas as pd
import numpy as np
import znrnd

import optax
from neural_tangents import stax

from jax.lib import xla_bridge

print(f"Using: {xla_bridge.get_backend().platform}")

Using: cpu


### Download the dataset

In [7]:
url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'
column_names = ['MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',
                'Acceleration', 'Model Year', 'Origin']

raw_dataset = pd.read_csv(url, names=column_names,
                          na_values='?', comment='\t',
                          sep=' ', skipinitialspace=True)

### Post-process the data here

In [8]:
dataset = raw_dataset.copy()
dataset = dataset.dropna()
dataset['Origin'] = dataset['Origin'].map({1: 'USA', 2: 'Europe', 3: 'Japan'})
dataset = pd.get_dummies(dataset, columns=['Origin'], prefix='', prefix_sep='')


dataset = (dataset-dataset.mean())/dataset.std()

### Create the data generator

In [14]:
class MPGDataGenerator(znrnd.data.DataGenerator):
    """
    Data generator for the MPG dataset.
    """
    def __init__(self, dataset: pd.DataFrame):
        """
        Constructor for the data generator.
        
        Parameters
        ----------
        dataset
        """        
        train_ds = dataset.sample(frac=0.8, random_state=0)
        train_labels = train_ds.pop("MPG")
        test_ds = dataset.drop(train_ds.index)
        test_labels = test_ds.pop("MPG")
        
        self.train_ds = {"inputs": train_ds.to_numpy(), "targets": train_labels.to_numpy().reshape(-1, 1)}
        self.test_ds = {"inputs": test_ds.to_numpy(), "targets": test_labels.to_numpy().reshape(-1, 1)}
        
        self.data_pool = self.train_ds["inputs"]
        

In [15]:
data_generator = MPGDataGenerator(dataset)

### Create a model

In [16]:
model = stax.serial(
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(1),
)

In [17]:
ntk_network = znrnd.models.NTModel(
            nt_module=model,
            optimizer=optax.adam(learning_rate=0.001),
            input_shape=(9,),
        )

### Create a training strategy

In [19]:
training_strategy = znrnd.training_strategies.SimpleTraining(
    model=ntk_network, 
    loss_fn=znrnd.loss_functions.LPNormLoss(order=2),
    accuracy_fn=znrnd.accuracy_functions.LabelAccuracy(),
)

### Train the model

In [20]:
metrics = training_strategy.train_model(
    train_ds=data_generator.train_ds, test_ds=data_generator.test_ds, epochs=100, batch_size=32
)

Epoch: 100: 100%|██████████████████████████████████| 100/100 [00:12<00:00,  8.32batch/s, accuracy=1]


### Perform some data selection

In [21]:
rnd_stack = stax.serial(
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
)

In [23]:
target = znrnd.models.NTModel(
        nt_module=rnd_stack,
        optimizer=optax.adam(learning_rate=0.001),
        input_shape=(9,),
    )

predictor = znrnd.models.NTModel(
        nt_module=rnd_stack,
        optimizer=optax.adam(learning_rate=0.001),
        input_shape=(9,),
    )

In [24]:
training_strategy = znrnd.training_strategies.SimpleTraining(
    model=predictor, 
    loss_fn=znrnd.loss_functions.LPNormLoss(order=2),
)

In [25]:
agent = znrnd.agents.RND(
    training_strategy=training_strategy, 
    point_selector=znrnd.point_selection.GreedySelection(threshold=0.01),
    distance_metric=znrnd.distance_metrics.OrderNDifference(order=2),
    data_generator=data_generator,
    target_network=target,
    predictor_network=predictor,
    tolerance=8,
    )

In [26]:
ds = agent.build_dataset(10)

Epoch: 50: 100%|████████████████████████████████| 50/50 [00:00<00:00, 152.51batch/s, test_loss=7.23]
Epoch: 50: 100%|█████████████████████████████████| 50/50 [00:00<00:00, 96.53batch/s, test_loss=5.01]
Epoch: 50: 100%|█████████████████████████████████| 50/50 [00:00<00:00, 74.37batch/s, test_loss=2.95]
Epoch: 50: 100%|████████████████████████████████| 50/50 [00:00<00:00, 59.54batch/s, test_loss=0.921]
Epoch: 50: 100%|████████████████████████████████| 50/50 [00:01<00:00, 49.27batch/s, test_loss=0.606]
Epoch: 50: 100%|████████████████████████████████| 50/50 [00:02<00:00, 18.87batch/s, test_loss=0.824]
Epoch: 50: 100%|█████████████████████████████████| 50/50 [00:04<00:00, 10.63batch/s, test_loss=0.91]
Epoch: 50: 100%|█████████████████████████████████| 50/50 [00:05<00:00,  8.75batch/s, test_loss=0.81]
Epoch: 50: 100%|████████████████████████████████| 50/50 [00:07<00:00,  7.01batch/s, test_loss=0.758]
Epoch: 50: 100%|████████████████████████████████| 50/50 [00:08<00:00,  6.05batch/s, test_lo

In [13]:
train_ds = {
    "inputs": np.take(data_generator.train_ds["inputs"], agent.target_indices, axis=0),
    "targets": np.take(data_generator.train_ds["targets"], agent.target_indices, axis=0)
}

In [14]:
ntk_network = znrnd.models.NTModel(
            nt_module=model,
            optimizer=optax.adam(learning_rate=0.1),
            loss_fn=znrnd.loss_functions.LPNormLoss(order=2),
            input_shape=(9,),
            training_threshold=0.001
        )

In [15]:
rnd_metrics = ntk_network.train_model(train_ds, test_ds=data_generator.test_ds)

Epoch: 50: 100%|████████████████████████████████| 50/50 [00:08<00:00,  5.66batch/s, test_loss=0.391]
