In [2]:
import pandas as pd
from time import time

import jax.numpy as jnp
from orbax.checkpoint import PyTreeCheckpointer

from lotf import LOTF_PATH
from lotf.utils.residual_dynamics import create_vec_funcs

# Training an Ensemble of Residual Dynamics Networks

## 1. Set Number of Models, Create Vectorized Functions

In [3]:
# Number of residual dynamics networks in the ensemble.
num_models = 3

# NOTE: these vectorized functions automatically broadcast over an arbitrary
# number of ensemble models, so they are built once without needing num_models.
(init_fn, train_fn, predict_fn) = create_vec_funcs()

## 2. Load Dataset

In [4]:
dataset_name = "example_dataset.csv"

# Expected column layout of the CSV: the first `input_dim` columns are the
# network inputs, the remaining `output_dim` columns are the regression targets.
input_dim = 19
output_dim = 3

# The CSV is read with header=None, i.e. every row is treated as data.
file_path = f"{LOTF_PATH}/../examples/residual_dynamics/{dataset_name}"
df = pd.read_csv(file_path, header=None)
dataset = df.to_numpy()
print(f"Loaded dataset shape: {dataset.shape}")

# Fail early if the file does not match the declared layout; otherwise the
# later `dataset[:, input_dim:]` slice would silently yield wrong-width targets.
assert dataset.ndim == 2 and dataset.shape[1] == input_dim + output_dim, (
    f"Expected {input_dim + output_dim} columns "
    f"({input_dim} inputs + {output_dim} outputs), got shape {dataset.shape}"
)

Loaded dataset shape: (1000, 22)


## 3. Define Training Hyperparams

In [5]:
# --- model initialization ---
# scale of weight initialization
# NOTE(review): not passed to init_fn in the visible cells — confirm it is
# actually consumed somewhere, or wire it in.
weight_init_scale = 1.0

# --- optimization ---
# optimizer learning rate
learning_rate = 1e-2
# weight norm regularization coefficient
lambda_reg = 1e-3

# --- training schedule ---
num_epochs = 100
# NOTE(review): not passed to train_fn in the visible cells — confirm whether
# training is mini-batched or full-batch.
batch_size = 256
# interval (in epochs) between reported metrics, judging by the training log
eval_every = 10

## 4. Initialize and Train

In [6]:
# Build the initial parameters and optimizer/train states for the whole
# ensemble in one vectorized call: one int32 index (0 .. num_models-1) per model.
# (presumably each index seeds / identifies one ensemble member — confirm in init_fn)
model_indices = jnp.arange(num_models, dtype=jnp.int32)
model_params, train_states = init_fn(learning_rate, model_indices)

# Split the raw array column-wise into inputs and targets, as float32 JAX arrays.
X = jnp.array(dataset[:, :input_dim], dtype=jnp.float32)
y = jnp.array(dataset[:, input_dim:], dtype=jnp.float32)

# Train all ensemble members and report wall-clock duration.
# NOTE(review): JAX dispatches work asynchronously; if train_fn does not block
# on its result, this timing may under-report — confirm.
tic = time()
train_states = train_fn(
    train_states,
    X,
    y,
    lambda_reg,
    num_epochs,
    eval_every,
)
print(f"Residual model training took {time() - tic:.2f} seconds")

Epoch 0/100 | Train MSE: 4.412898540496826 | Total Loss: 4.418099403381348
Epoch 0/100 | Train MSE: 4.686639785766602 | Total Loss: 4.691734790802002
Epoch 0/100 | Train MSE: 2.74587082862854 | Total Loss: 2.750821352005005
Epoch 10/100 | Train MSE: 0.11661765724420547 | Total Loss: 0.12209253758192062
Epoch 10/100 | Train MSE: 0.08638668060302734 | Total Loss: 0.09179800003767014
Epoch 10/100 | Train MSE: 0.09987227618694305 | Total Loss: 0.10490693151950836
Epoch 20/100 | Train MSE: 0.022784674540162086 | Total Loss: 0.02915060706436634
Epoch 20/100 | Train MSE: 0.031967174261808395 | Total Loss: 0.03817407041788101
Epoch 20/100 | Train MSE: 0.02059531770646572 | Total Loss: 0.02646797150373459
Epoch 30/100 | Train MSE: 0.012355873361229897 | Total Loss: 0.019228212535381317
Epoch 30/100 | Train MSE: 0.013430574908852577 | Total Loss: 0.020221399143338203
Epoch 30/100 | Train MSE: 0.011919140815734863 | Total Loss: 0.018313447013497353
Epoch 40/100 | Train MSE: 0.005905880592763424 |

## 5. Save Model Params

In [7]:
# Name of the checkpoint directory (plain string: no interpolation needed).
model_name = "my_residual_dynamics_params"

model_path = LOTF_PATH + "/../checkpoints/residual_dynamics/" + model_name
ckptr = PyTreeCheckpointer()

# Only the trained parameters are persisted, not the full train state.
residual_params = train_states.params

# force=True lets the notebook be re-run end to end: without it, orbax refuses
# to write into an already existing checkpoint directory. Note that a previous
# checkpoint with the same name is overwritten.
ckptr.save(model_path, residual_params, force=True)
print("Saved model params!")

Saved model params!
