# Custom Data Generator

Here we will demonstrate how to create a custom data generator.

In [1]:
import pandas as pd
import numpy as np
import znrnd

import optax
from neural_tangents import stax



### Download the dataset

In [2]:
url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'
column_names = ['MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',
                'Acceleration', 'Model Year', 'Origin']

raw_dataset = pd.read_csv(url, names=column_names,
                          na_values='?', comment='\t',
                          sep=' ', skipinitialspace=True)

### Post-process the data here

In [3]:
dataset = raw_dataset.copy()
dataset = dataset.dropna()
dataset['Origin'] = dataset['Origin'].map({1: 'USA', 2: 'Europe', 3: 'Japan'})
dataset = pd.get_dummies(dataset, columns=['Origin'], prefix='', prefix_sep='')


dataset = (dataset-dataset.mean())/dataset.std()

### Create the data generator

In [4]:
class MPGDataGenerator(znrnd.data.DataGenerator):
    """
    Data generator for the MPG dataset.
    """
    def __init__(self, dataset: pd.DataFrame):
        """
        Constructor for the data generator.
        
        Parameters
        ----------
        dataset
        """        
        train_ds = dataset.sample(frac=0.8, random_state=0)
        train_labels = train_ds.pop("MPG")
        test_ds = dataset.drop(train_ds.index)
        test_labels = test_ds.pop("MPG")
        
        self.train_ds = {"inputs": train_ds.to_numpy(), "targets": train_labels.to_numpy().reshape(-1, 1)}
        self.test_ds = {"inputs": test_ds.to_numpy(), "targets": test_labels.to_numpy().reshape(-1, 1)}
        
        self.data_pool = self.train_ds["inputs"]
        

In [5]:
data_generator = MPGDataGenerator(dataset)

### Create a model

In [28]:
model = stax.serial(
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(1),
)

In [29]:
ntk_network = znrnd.models.NTModel(
            nt_module=model,
            optimizer=optax.adam(learning_rate=0.001),
            loss_fn=znrnd.loss_functions.LPNormLoss(order=2),
            input_shape=(9,),
            training_threshold=0.001
        )

In [31]:
ntk_network.train_model(train_ds=data_generator.train_ds, test_ds=data_generator.test_ds, epochs=100)

Epoch: 100: 100%|█████████████████████████████| 100/100 [12:14<00:00,  7.35s/batch, test_loss=0.242]


([0.561866283416748,
  0.38587749004364014,
  0.29673609137535095,
  0.27801713347435,
  0.2691114544868469,
  0.2605889141559601,
  0.2593339681625366,
  0.25027135014533997,
  0.2422495186328888,
  0.23746685683727264,
  0.24060063064098358,
  0.24108919501304626,
  0.2361724078655243,
  0.2355002462863922,
  0.23490004241466522,
  0.23093287646770477,
  0.23214618861675262,
  0.23414099216461182,
  0.23743556439876556,
  0.2391740381717682,
  0.23580168187618256,
  0.23742584884166718,
  0.23707884550094604,
  0.23644202947616577,
  0.23477618396282196,
  0.23992504179477692,
  0.2389359176158905,
  0.2409382164478302,
  0.23398426175117493,
  0.23654945194721222,
  0.24029825627803802,
  0.23590072989463806,
  0.23860454559326172,
  0.24135839939117432,
  0.23728728294372559,
  0.2406376302242279,
  0.24267970025539398,
  0.24345408380031586,
  0.2416716367006302,
  0.24193577468395233,
  0.23871028423309326,
  0.2414756715297699,
  0.24544723331928253,
  0.2427956610918045,
  0.24

### Perform some data selection

In [11]:
rnd_stack = stax.serial(
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
)

In [14]:
target = znrnd.models.NTModel(
        nt_module=rnd_stack,
        optimizer=optax.adam(learning_rate=0.001),
        loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
        input_shape=(9,),
        training_threshold=0.001
    )

predictor = znrnd.models.NTModel(
        nt_module=rnd_stack,
        optimizer=optax.adam(learning_rate=0.001),
        loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
        input_shape=(9,),
        training_threshold=0.001
    )

In [15]:
agent = znrnd.agents.RND(
        point_selector=znrnd.point_selection.GreedySelection(threshold=0.01),
        distance_metric=znrnd.distance_metrics.OrderNDifference(order=2),
        data_generator=data_generator,
        target_network=target,
        predictor_network=predictor,
        tolerance=8,
    )

In [16]:
ds = agent.build_dataset(50)

Epoch: 100: 100%|█████████████████████████████| 100/100 [00:02<00:00, 41.98batch/s, test_loss=0.189]
Epoch: 110: 100%|████████████████████████████| 110/110 [00:02<00:00, 47.43batch/s, test_loss=0.0471]
Epoch: 121: 100%|███████████████████████████| 121/121 [00:02<00:00, 44.01batch/s, test_loss=0.00768]
Epoch: 133: 100%|██████████████████████████| 133/133 [00:02<00:00, 48.41batch/s, test_loss=0.000829]
Epoch: 100: 100%|████████████████████████████| 100/100 [00:04<00:00, 23.95batch/s, test_loss=0.0382]
Epoch: 110: 100%|███████████████████████████| 110/110 [00:04<00:00, 25.29batch/s, test_loss=0.00293]
Epoch: 121: 100%|██████████████████████████| 121/121 [00:04<00:00, 25.20batch/s, test_loss=0.000239]
Epoch: 100: 100%|███████████████████████████| 100/100 [00:05<00:00, 16.70batch/s, test_loss=0.00963]
Epoch: 110: 100%|██████████████████████████| 110/110 [00:06<00:00, 17.04batch/s, test_loss=0.000195]
Epoch: 100: 100%|███████████████████████████| 100/100 [00:07<00:00, 12.58batch/s, test_loss

Epoch: 100: 100%|██████████████████████████| 100/100 [00:36<00:00,  2.72batch/s, test_loss=0.000541]
Epoch: 100: 100%|██████████████████████████| 100/100 [00:41<00:00,  2.42batch/s, test_loss=0.000406]
Epoch: 100: 100%|██████████████████████████| 100/100 [00:43<00:00,  2.32batch/s, test_loss=0.000386]
Epoch: 100: 100%|██████████████████████████| 100/100 [00:44<00:00,  2.26batch/s, test_loss=0.000426]
Epoch: 100: 100%|██████████████████████████| 100/100 [00:45<00:00,  2.20batch/s, test_loss=0.000268]
Epoch: 100: 100%|██████████████████████████| 100/100 [00:47<00:00,  2.09batch/s, test_loss=0.000189]
Epoch: 100: 100%|██████████████████████████| 100/100 [00:48<00:00,  2.04batch/s, test_loss=0.000257]
Epoch: 100: 100%|██████████████████████████| 100/100 [00:50<00:00,  1.98batch/s, test_loss=0.000337]
Epoch: 100: 100%|██████████████████████████| 100/100 [00:52<00:00,  1.90batch/s, test_loss=0.000436]
Epoch: 100: 100%|██████████████████████████| 100/100 [00:54<00:00,  1.82batch/s, test_loss=

Epoch: 100: 100%|██████████████████████████| 100/100 [01:29<00:00,  1.12batch/s, test_loss=0.000348]
Epoch: 100: 100%|███████████████████████████| 100/100 [01:25<00:00,  1.18batch/s, test_loss=0.00034]
Epoch: 100: 100%|██████████████████████████| 100/100 [01:29<00:00,  1.12batch/s, test_loss=0.000294]
Epoch: 100: 100%|███████████████████████████| 100/100 [01:32<00:00,  1.09batch/s, test_loss=0.00038]
Epoch: 100: 100%|██████████████████████████| 100/100 [01:33<00:00,  1.07batch/s, test_loss=0.000366]
Epoch: 100: 100%|██████████████████████████| 100/100 [01:35<00:00,  1.05batch/s, test_loss=0.000322]
Epoch: 100: 100%|██████████████████████████| 100/100 [01:37<00:00,  1.03batch/s, test_loss=0.000357]
Epoch: 100: 100%|██████████████████████████| 100/100 [01:39<00:00,  1.01batch/s, test_loss=0.000439]
Epoch: 100: 100%|██████████████████████████| 100/100 [01:40<00:00,  1.01s/batch, test_loss=0.000423]
Epoch: 100: 100%|██████████████████████████| 100/100 [01:43<00:00,  1.04s/batch, test_loss=


RND agent report
----------------
Run time:  76.25 m
Size of point cloud: 314
Number of points chosen: 50
Seed points: None



In [32]:
train_ds = {
    "inputs": np.take(data_generator.train_ds["inputs"], agent.target_indices, axis=0),
    "targets": np.take(data_generator.train_ds["targets"], agent.target_indices, axis=0)
}

In [33]:
ntk_network = znrnd.models.NTModel(
            nt_module=model,
            optimizer=optax.adam(learning_rate=0.1),
            loss_fn=znrnd.loss_functions.LPNormLoss(order=2),
            input_shape=(9,),
            training_threshold=0.001
        )

In [35]:
rnd_metrics = ntk_network.train_model(train_ds, test_ds=data_generator.test_ds)

Epoch: 50: 100%|████████████████████████████████| 50/50 [01:09<00:00,  1.38s/batch, test_loss=0.432]
