# Custom Data Generator

Here we will demonstrate how to create a custom data generator.

In [1]:
import pandas as pd
import numpy as np
import znrnd

import optax
from neural_tangents import stax



### Download the dataset

In [2]:
url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'
column_names = ['MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',
                'Acceleration', 'Model Year', 'Origin']

raw_dataset = pd.read_csv(url, names=column_names,
                          na_values='?', comment='\t',
                          sep=' ', skipinitialspace=True)

### Post-process the data here

In [3]:
dataset = raw_dataset.copy()
dataset = dataset.dropna()
dataset['Origin'] = dataset['Origin'].map({1: 'USA', 2: 'Europe', 3: 'Japan'})
dataset = pd.get_dummies(dataset, columns=['Origin'], prefix='', prefix_sep='')


dataset = (dataset-dataset.mean())/dataset.std()

### Create the data generator

In [4]:
class MPGDataGenerator(znrnd.data.DataGenerator):
    """
    Data generator for the MPG dataset.
    """
    def __init__(self, dataset: pd.DataFrame):
        """
        Constructor for the data generator.
        
        Parameters
        ----------
        dataset
        """        
        train_ds = dataset.sample(frac=0.8, random_state=0)
        train_labels = train_ds.pop("MPG")
        test_ds = dataset.drop(train_ds.index)
        test_labels = test_ds.pop("MPG")
        
        self.train_ds = {"inputs": train_ds.to_numpy(), "targets": train_labels.to_numpy()}
        self.test_ds = {"inputs": test_ds.to_numpy(), "targets": test_labels.to_numpy()}
        
        self.data_pool = self.train_ds["inputs"]
        

In [5]:
data_generator = MPGDataGenerator(dataset)

### Create a model

In [7]:
model = stax.serial(
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(32),
    stax.Relu(),
    stax.Dense(1),
)

In [10]:
ntk_network = znrnd.models.NTModel(
            nt_module=model,
            optimizer=optax.adam(learning_rate=0.001),
            loss_fn=znrnd.loss_functions.LPNormLoss(order=2),
            input_shape=(1,9),
            training_threshold=0.001
        )

In [None]:
ntk_network.train_model(train_ds=data_generator.train_ds, test_ds=data_generator.test_ds)

Epoch: 6:  10%|███▌                               | 5/50 [00:30<04:37,  6.17s/batch, test_loss=11.9]