In [1]:
from time import time

import jax
import jax.numpy as jnp
import pandas as pd
from orbax.checkpoint import PyTreeCheckpointer

from lotf import LOTF_PATH
from lotf.utils.residual_dynamics import create_vec_funcs

# Training an Ensemble of Residual Dynamics Networks

## 1. Set Number of Models, Create Vectorized Functions

In [2]:
# Number of networks in the ensemble. Further down it sets the length of the
# per-model integer array handed to init_fn.
num_models = 3

# NOTE: the vectorized functions returned here broadcast automatically over an
# arbitrary number of ensemble models, so nothing below loops over models.
(
    init_fn,
    train_fn,
    predict_fn,
) = create_vec_funcs()

## 2. Load Dataset

In [3]:
# Load the example dataset shipped with the repository. The CSV has no header
# row, so every line is read as data.
dataset_name = "example_dataset.csv"

file_path = LOTF_PATH + "/../examples/residual_dynamics/" + dataset_name
df = pd.read_csv(file_path, header=None)
dataset = df.to_numpy()
print(f"Loaded dataset shape: {dataset.shape}")

# Column layout used by the training cell: the first `input_dim` columns are
# the network inputs and the remaining columns are the regression targets.
input_dim = 19
output_dim = 3

# Fail loudly if the file does not match the declared layout. The later split
# `dataset[:, input_dim:]` would otherwise accept any number of target columns.
assert dataset.ndim == 2 and dataset.shape[1] == input_dim + output_dim, (
    f"Expected {input_dim + output_dim} columns "
    f"({input_dim} inputs + {output_dim} outputs), got shape {dataset.shape}"
)

Loaded dataset shape: (1000, 22)


## 3. Define Training Hyperparameters

In [4]:
# --- Initialization ---
# Scale of the weight initialization.
# NOTE(review): no later cell in this notebook reads this value -- confirm
# whether it is meant to be passed to init_fn.
weight_init_scale = 1.0

# --- Optimization ---
learning_rate = 1e-2  # optimizer learning rate (passed to init_fn)
lambda_reg = 1e-3     # weight-norm regularization coefficient (passed to train_fn)

# --- Training schedule ---
num_epochs = 100      # total number of training epochs (passed to train_fn)
# NOTE(review): batch_size is not read by any later cell in this notebook --
# confirm whether train_fn is expected to receive it.
batch_size = 256
eval_every = 10       # passed to train_fn; the recorded log reports metrics every 10 epochs

## 4. Initialize and Train

In [5]:
# Initialize model params and train states for the whole ensemble in one call.
# The second argument is the int32 array [0, ..., num_models - 1], i.e. one
# entry per ensemble member (presumably a per-model seed/index -- confirm
# against create_vec_funcs).
model_params, train_states = init_fn(
    learning_rate, jnp.arange(num_models, dtype=jnp.int32)
)

# Prepare dataset: the first `input_dim` columns are the inputs, every
# remaining column is a target. Both are converted to float32 JAX arrays.
X, y = dataset[:, :input_dim], dataset[:, input_dim:]
X, y = jnp.array(X, dtype=jnp.float32), jnp.array(y, dtype=jnp.float32)

tic = time()
train_states = train_fn(train_states, X, y, lambda_reg, num_epochs, eval_every)
# JAX dispatches computations asynchronously, so train_fn may return before
# the work has finished. Wait for the results before reading the clock,
# otherwise the reported duration can under-count the actual training time.
jax.block_until_ready(train_states)
print(f"Residual model training took {time() - tic:.2f} seconds")

Epoch 0/100 | Train MSE: 4.412898063659668 | Total Loss: 4.4180989265441895
Epoch 0/100 | Train MSE: 4.686639785766602 | Total Loss: 4.691734790802002
Epoch 0/100 | Train MSE: 2.74587082862854 | Total Loss: 2.750821352005005
Epoch 10/100 | Train MSE: 0.11662287265062332 | Total Loss: 0.12209775298833847
Epoch 10/100 | Train MSE: 0.08662877231836319 | Total Loss: 0.09204016625881195
Epoch 10/100 | Train MSE: 0.09959358721971512 | Total Loss: 0.10462775826454163
Epoch 20/100 | Train MSE: 0.02277565561234951 | Total Loss: 0.029141608625650406
Epoch 20/100 | Train MSE: 0.036559805274009705 | Total Loss: 0.04275143891572952
Epoch 20/100 | Train MSE: 0.02065705507993698 | Total Loss: 0.026527877897024155
Epoch 30/100 | Train MSE: 0.012339760549366474 | Total Loss: 0.019211947917938232
Epoch 30/100 | Train MSE: 0.015150896273553371 | Total Loss: 0.021953297778964043
Epoch 30/100 | Train MSE: 0.011963981203734875 | Total Loss: 0.018357669934630394
Epoch 40/100 | Train MSE: 0.005908987484872341

## 5. Save Model Params

In [6]:
# Name of the checkpoint directory (plain string: there is nothing to interpolate).
model_name = "my_residual_dynamics_params"

model_path = LOTF_PATH + "/../checkpoints/residual_dynamics/" + model_name
ckptr = PyTreeCheckpointer()
# Only the parameters of the trained states are persisted.
residual_params = train_states.params
# force=True overwrites an existing checkpoint at model_path. Without it the
# save raises when the destination already exists, so re-running the notebook
# would fail on this cell.
ckptr.save(model_path, residual_params, force=True)
print("Saved model params!")

Saved model params!
