In [None]:
import os
# Hide all CUDA devices so the JAX-based libraries imported below run on CPU.
# Set before those imports so it is in place before any device initialisation.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import copy
import znrnd as znrnd
import numpy as np

import matplotlib.pyplot as plt

from neural_tangents import stax
import optax

# Using different Training strategies

When training a model, using custom training strategies can be crucial, especially when re-training a model on new data.

This notebook shows how to use different training strategies for RND (random network distillation).

In [None]:
# MNIST data generator (the argument presumably sets the number of samples).
data_generator = znrnd.data.MNISTGenerator(50)

# Shape of a single input sample, keeping a leading batch axis of length 1;
# used to initialise the models below.
first_input = data_generator.train_ds["inputs"][:1, ...]
input_shape = first_input.shape

## Define the Networks

In [None]:
# Fully connected network: flatten the input, one hidden layer of width 128
# with a ReLU, and a linear output layer of width 128.
rnd_layers = (
    stax.Flatten(),
    stax.Dense(128),
    stax.Relu(),
    stax.Dense(128),
)
architecture = stax.serial(*rnd_layers)

In [None]:
def _build_nt_model():
    """Create an NTModel of `architecture` optimised with Adam (lr = 0.02)."""
    return znrnd.models.NTModel(
        nt_module=architecture,
        optimizer=optax.adam(learning_rate=0.02),
        input_shape=input_shape,
        batch_size=10,
    )


# Target and predictor share the same architecture; each call builds a
# separate model instance with its own optimizer.
target_model = _build_nt_model()
predictor_model = _build_nt_model()

In [None]:
# Data set used by the recorders: every training input paired with the
# target network's output for that input.
train_inputs = data_generator.train_ds["inputs"]
dataset = {
    "inputs": train_inputs,
    "targets": target_model(train_inputs),
}

## Define the training strategies and corresponding recorders

Here, three different training strategies are presented. 
For each, the training loss is recorded to show the differences between the strategies. 

1. Simple Training 
2. Partitioned Training
3. Loss aware reservoir Training

### Simple Training

In [None]:
# Recorder settings: record the loss with update_rate=1; the huge chunk size
# keeps the recordings from being saved.
recorder_settings = dict(loss=True, update_rate=1, chunk_size=1e10)

simple_recorder = znrnd.training_recording.JaxRecorder(
    name="simple_recorder", **recorder_settings
)
simple_recorder.instantiate_recorder(data_set=dataset)

# model=None: no model is attached here; presumably the RND agent provides it.
simple_trainer = znrnd.training_strategies.SimpleTraining(
    model=None,
    loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
    recorders=[simple_recorder],
)

print("SimpleTraining: \n", simple_trainer.__doc__)

### Partitioned Training

We use this method to train only on the most recently chosen data point in RND. 

In [None]:
# Recorder for the partitioned-training run. It gets its own name so its
# recordings are distinguishable from those of the other strategies.
partitioned_recorder = znrnd.training_recording.JaxRecorder(
    name="partitioned_recorder",
    loss=True,
    update_rate=1,
    chunk_size=1e10,  # Big chunk size to prevent saving the recordings.
)
partitioned_recorder.instantiate_recorder(data_set=dataset)

# model=None: no model is attached here; presumably the RND agent provides it.
partitioned_trainer = znrnd.training_strategies.PartitionedTraining(
    model=None,
    loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
    recorders=[partitioned_recorder],
)

# Show PartitionedTraining's own docstring.
print("PartitionedTraining: \n", partitioned_trainer.__doc__)

### Loss aware reservoir Training

In [None]:
# Recorder for the loss-aware-reservoir run. It gets its own name so its
# recordings are distinguishable from those of the other strategies.
LaR_recorder = znrnd.training_recording.JaxRecorder(
    name="LaR_recorder",
    loss=True,
    update_rate=1,
    chunk_size=1e10,  # Big chunk size to prevent saving the recordings.
)
LaR_recorder.instantiate_recorder(data_set=dataset)

# model=None: no model is attached here; presumably the RND agent provides it.
# NOTE(review): the exact meaning of reservoir_size / latest_points is defined
# by znrnd.training_strategies.LossAwareReservoir — confirm there.
LaR_trainer = znrnd.training_strategies.LossAwareReservoir(
    model=None,
    loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
    reservoir_size=2,
    latest_points=1,
    recorders=[LaR_recorder],
)

# Show LossAwareReservoir's own docstring.
print("LossAwareReservoir: \n", LaR_trainer.__doc__)

In [None]:
# Components shared by all three RND agents. Each agent receives its own deep
# copy, so the agents do not share network or data-generator objects.
agent_dict = {
    "data_generator": data_generator,
    "target_network": target_model,
    "predictor_network": predictor_model,
    "distance_metric": znrnd.distance_metrics.OrderNDifference(order=2),
    "point_selector": znrnd.point_selection.GreedySelection(),
}


def _make_rnd_agent(strategy):
    """Build an RND agent using `strategy` and a fresh deep copy of `agent_dict`."""
    return znrnd.agents.RND(
        training_strategy=strategy,
        **copy.deepcopy(agent_dict),
    )


simple_agent = _make_rnd_agent(simple_trainer)
partitioned_agent = _make_rnd_agent(partitioned_trainer)
LaR_agent = _make_rnd_agent(LaR_trainer)

## Execution

In [None]:
target_size = 5
batch_size = 20
epochs = 50

# Settings shared by every agent run.
shared_settings = {"target_size": target_size, "seed_randomly": False}

# Simple and loss-aware-reservoir training take scalar epochs / batch size.
_ = simple_agent.build_dataset(
    epochs=epochs, batch_size=batch_size, **shared_settings
)

# Partitioned training takes one entry per partition; the single partition
# np.array([-1]) selects the last (presumably most recently chosen) point.
_ = partitioned_agent.build_dataset(
    epochs=[epochs],
    batch_size=[batch_size],
    train_ds_selection=[np.array([-1])],
    **shared_settings,
)

_ = LaR_agent.build_dataset(
    epochs=epochs, batch_size=batch_size, **shared_settings
)

## Plot the data

In [None]:
# Collect the recorded loss histories of the three strategies (in order).
simple_report, pertitioned_report, LaR_report = (
    recorder.gather_recording()
    for recorder in (simple_recorder, partitioned_recorder, LaR_recorder)
)

In [None]:
# Compare the recorded training-loss histories of the three strategies.
# (mfc is dropped: it only styles markers, and these are marker-less lines.)
fig, ax = plt.subplots()
ax.plot(simple_report.loss, "-", label="SimpleTraining")
ax.plot(pertitioned_report.loss, "-", label="PartitionedTraining")
ax.plot(LaR_report.loss, "-", label="LossAwareReservoir")
ax.set_title("Training loss while building the RND data set")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

# Show Loss equalizing
The loss-aware reservoir and partitioned training strategies are designed to equalize initial loss differences across the data. 
Here, we pre-train on part of the data and check the loss on new, unseen data. 
We then watch the loss decrease when training on all data with the different training strategies. 

In [None]:
# Fresh MNIST generator for the loss-equalizing experiment. Only the first
# `n_pre_train` samples are used for pre-training; the rest stay unseen.
n_pre_train = 40
data_generator = znrnd.data.MNISTGenerator(50)
pre_train_ds = {
    key: values[:n_pre_train, ...]
    for key, values in data_generator.train_ds.items()
}

### Model

In [None]:
# Fully connected network: flatten the input, two hidden layers of width 128
# with ReLUs, and a linear output layer of width 10 (presumably one per
# MNIST class).
pre_train_layers = (
    stax.Flatten(),
    stax.Dense(128),
    stax.Relu(),
    stax.Dense(128),
    stax.Relu(),
    stax.Dense(10),
)
architecture = stax.serial(*pre_train_layers)

In [None]:
# Model for the loss-equalizing experiment. The input shape is derived from
# this section's data generator (one sample, leading batch axis of 1) rather
# than relying on `input_shape` left over from the RND section above.
model_input_shape = data_generator.train_ds["inputs"][:1, ...].shape

model = znrnd.models.NTModel(
    nt_module=architecture,
    optimizer=optax.adam(learning_rate=0.02),
    input_shape=model_input_shape,
    batch_size=10,
)

### Pre-Training

In [None]:
# Recorder for the pre-training run; named after its own role so it is not
# confused with the recorders of the strategies compared later.
pre_train_recorder = znrnd.training_recording.JaxRecorder(
    name="pre_train_recorder",
    loss=True,
    update_rate=1,
    chunk_size=1e10,  # Big chunk size to prevent saving the recordings.
)
pre_train_recorder.instantiate_recorder(data_set=pre_train_ds)

# Plain training of `model` on the pre-training subset only.
pre_trainer = znrnd.training_strategies.SimpleTraining(
    model=model,
    loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
    recorders=[pre_train_recorder],
)

In [None]:
# Pre-train on the first part of the data only, evaluating on the same subset.
_ = pre_trainer.train_model(
    train_ds=pre_train_ds,
    test_ds=pre_train_ds,
    epochs=50,
    batch_size=10,
)

### Check if the training has converged 

In [None]:
# Fetch the loss history recorded during pre-training.
pre_train_report = pre_train_recorder.gather_recording()

In [None]:
# Pre-training loss on a log scale; a flattening curve indicates convergence.
fig, ax = plt.subplots()
ax.plot(pre_train_report.loss, "-", label="Pre-training")
ax.set_title("Pre-training loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_yscale("log")
ax.legend()
plt.show()

### Check the loss for all data, including unseen data

In [None]:
# Loss of the pre-trained model on every data point, including the points that
# were held out from pre-training.
# NOTE(review): assumes MeanPowerLoss(...).metric returns one value per sample
# (the plot against "Data index" relies on this) — confirm in znrnd.
loss_metric = znrnd.loss_functions.MeanPowerLoss(order=2).metric

diff = loss_metric(
    pre_trainer.model(data_generator.train_ds["inputs"]),
    data_generator.train_ds["targets"],
)

fig, ax = plt.subplots()
ax.plot(diff, "o")
ax.set_title("Loss of the pre-trained model per data point")
ax.set_xlabel("Data index")
ax.set_ylabel("Loss")
plt.show()

One can clearly see which part of the data was trained on and which was not. 

## Prepare the different training strategies

### Simple Training

In [None]:
simple_recorder = znrnd.training_recording.JaxRecorder(
    name="simple_recorder",
    loss=True,
    update_rate=1,
    chunk_size=1e10,  # Big chunk size to prevent saving the recordings.
)
# Record the loss on the full data set, including the unseen points.
simple_recorder.instantiate_recorder(data_set=data_generator.train_ds)

# Start from a deep copy of the pre-trained model. `pre_trainer.model` is the
# object evaluated as "pre-trained" above, so copying it guarantees every
# strategy starts from that state without sharing the model object.
simple_trainer = znrnd.training_strategies.SimpleTraining(
    model=copy.deepcopy(pre_trainer.model),
    loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
    recorders=[simple_recorder],
)

### Partitioned Training

We use this method to first train only on the new, unseen data points (indices 40–49) and then on the full data set. 

In [None]:
# Recorder for the partitioned-training run, with its own name so its
# recordings are distinguishable from those of the other strategies.
partitioned_recorder = znrnd.training_recording.JaxRecorder(
    name="partitioned_recorder",
    loss=True,
    update_rate=1,
    chunk_size=1e10,  # Big chunk size to prevent saving the recordings.
)
# Record the loss on the full data set, including the unseen points.
partitioned_recorder.instantiate_recorder(data_set=data_generator.train_ds)

# Start from a deep copy of the pre-trained model (`pre_trainer.model`, the
# object evaluated as "pre-trained" above).
partitioned_trainer = znrnd.training_strategies.PartitionedTraining(
    model=copy.deepcopy(pre_trainer.model),
    loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
    recorders=[partitioned_recorder],
)

### Loss aware reservoir Training

In [None]:
# Recorder for the loss-aware-reservoir run, with its own name so its
# recordings are distinguishable from those of the other strategies.
LaR_recorder = znrnd.training_recording.JaxRecorder(
    name="LaR_recorder",
    loss=True,
    update_rate=1,
    chunk_size=1e10,  # Big chunk size to prevent saving the recordings.
)
# Record the loss on the full data set, including the unseen points.
LaR_recorder.instantiate_recorder(data_set=data_generator.train_ds)

# Start from a deep copy of the pre-trained model (`pre_trainer.model`, the
# object evaluated as "pre-trained" above).
# NOTE(review): reservoir_size / latest_points semantics are defined by
# znrnd.training_strategies.LossAwareReservoir — confirm there.
LaR_trainer = znrnd.training_strategies.LossAwareReservoir(
    model=copy.deepcopy(pre_trainer.model),
    loss_fn=znrnd.loss_functions.MeanPowerLoss(order=2),
    reservoir_size=10,
    latest_points=5,
    recorders=[LaR_recorder],
)

### Execution

In [None]:
# Every strategy is trained and evaluated on the full data set.
full_ds_kwargs = {
    "train_ds": data_generator.train_ds,
    "test_ds": data_generator.train_ds,
}

_ = simple_trainer.train_model(epochs=100, batch_size=10, **full_ds_kwargs)

# Partitioned training: 50 epochs on the unseen points 40-49 (batch size 5),
# then 50 epochs on all data (batch size 10).
_ = partitioned_trainer.train_model(
    epochs=[50, 50],
    batch_size=[5, 10],
    train_ds_selection=[slice(40, 50), slice(None)],
    **full_ds_kwargs,
)

_ = LaR_trainer.train_model(epochs=100, batch_size=10, **full_ds_kwargs)

## Plot the data

In [None]:
# Collect the recorded loss histories of the three strategies (in order).
simple_report, pertitioned_report, LaR_report = (
    recorder.gather_recording()
    for recorder in (simple_recorder, partitioned_recorder, LaR_recorder)
)

In [None]:
# Compare how quickly each strategy reduces the loss on the full data set
# after pre-training. (mfc is dropped: it only styles markers, and these are
# marker-less lines.)
fig, ax = plt.subplots()
ax.plot(simple_report.loss, "-", label="SimpleTraining")
ax.plot(pertitioned_report.loss, "-", label="PartitionedTraining")
ax.plot(LaR_report.loss, "-", label="LossAwareReservoir")
ax.set_yscale("log")
ax.set_title("Re-training on all data after pre-training")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()