In [None]:
import time

import jax
import optax
from orbax.checkpoint import PyTreeCheckpointer
from flax.training.train_state import TrainState

from lotf import LOTF_PATH
from lotf.algos import bptt
from lotf.envs import HoveringFeaturesEnv
from lotf.envs.wrappers import LogWrapper, VecEnv
from lotf.modules import MLP
from lotf.objects import Quadrotor

# Finetuning a Trained Feature-Based Hovering Policy With BPTT

## 1. Seed

In [2]:
# Fixed seed so that every fresh-kernel run draws the same random numbers.
seed = 0
key = jax.random.key(seed)

# Fork the root key into two independent streams (split() yields two keys by
# default): `key_init` is kept separate, `key_bptt` feeds the training cell.
key_init, key_bptt = jax.random.split(key)

## 2. Define Simulation Dynamics Config and Training Params

In [3]:
# training parameters
max_epochs = 200    # passed to the trainer as the epoch count and to the LR schedule as its decay length
num_envs = 300      # one reset key is generated per environment in the training cell

# simulation dynamics config (handed to the quadrotor object on construction)
sim_dyn_config = dict(
    use_high_fidelity=False,        # whether to use high-fidelity dynamics in forward simulation
    use_forward_residual=False,     # whether to use residual dynamics in forward simulation
)

## 3. Create Quadrotor Object and Simulation Environment

In [4]:
# simulation parameters
sim_dt = 0.02
max_sim_time = 3.0

# quadrotor object, configured with the dynamics options chosen above
quad_obj = Quadrotor.from_name("example_quad", sim_dyn_config)

# simulation environment: collect the constructor arguments in one place so
# they can be inspected or tweaked before the environment is built
env_kwargs = dict(
    # episode length in steps; with the values above this evaluates to 150
    # NOTE(review): int() truncates, so a time/dt pair that is not an exact
    # multiple in floating point would lose a step -- confirm this is intended
    max_steps_in_episode=int(max_sim_time / sim_dt),
    dt=sim_dt,
    delay=0.04,
    yaw_scale=1.0,
    pitch_roll_scale=0.1,
    velocity_std=0.1,
    omega_std=0.1,
    quad_obj=quad_obj,
    reward_sharpness=2.0,
    action_penalty_weight=0.5,
    num_last_quad_states=15,
    skip_frames=3,
    margin=0.5,
    hover_target=[1.5, 0.0, 1.5],
)
env = HoveringFeaturesEnv(**env_kwargs)

# get dimensions from the bare environment, before any wrapper is applied
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# apply additional wrappers: logging first, then vectorization on the outside
env = VecEnv(LogWrapper(env))

# summary of the environment that the following cells will train on
print(
    "====== env info ======",
    f"action_dim: {action_dim}",
    f"obs_dim: {obs_dim}",
    f"target hover goal: {env.goal}",
    sep="\n",
)

action_dim: 4
obs_dim: 82
target hover goal: [1.5 0.  1.5]


## 4. Load Policy Parameters, Create Optimizer and Train State

In [5]:
policy_name = "vision_hovering_params"

# policy network: two hidden layers of 512 units between the observation and
# action dimensions, with the environment's hovering action as output bias
layer_sizes = [obs_dim, 512, 512, action_dim]
policy_net = MLP(
    layer_sizes,
    initial_scale=0.01,
    action_bias=env.hovering_action,
)

# the parameters are not freshly initialized here; they are restored from the
# checkpoint directory of the previously trained policy
ckptr = PyTreeCheckpointer()
path = "/".join((LOTF_PATH, "..", "checkpoints", "policy", policy_name))
policy_params = ckptr.restore(path)

# optimizer: Adam whose learning rate follows a cosine decay starting at 1e-3
# and spanning `max_epochs` schedule steps
scheduler = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=max_epochs)
tx = optax.adam(learning_rate=scheduler)

# train state object bundling the apply function, restored params and optimizer
train_state = TrainState.create(
    apply_fn=policy_net.apply,
    params=policy_params,
    tx=tx,
)

## 5. Load Residual Dynamics Network Parameters

In [6]:
residual_dynamics_name = "example_params"

# restore the residual dynamics parameters from their checkpoint directory;
# they are later handed to the trainer as `res_model_params`
ckptr = PyTreeCheckpointer()
path = "/".join(
    (LOTF_PATH, "..", "checkpoints", "residual_dynamics", residual_dynamics_name)
)
dummy_residual_params = ckptr.restore(path)

## 6. Train

In [7]:
# initialize environments: draw a fresh subkey, expand it into one reset key
# per environment and reset all environments in a single vectorized call
# NOTE: `key_bptt` is re-bound here, so re-running this cell without re-running
# the seed cell continues the key stream instead of reproducing the same run.
key_bptt, key_ = jax.random.split(key_bptt)
key_reset = jax.random.split(key_, num_envs)
# NOTE(review): the meaning of the second `reset` argument (None) is defined by
# the environment wrapper, which is not visible here -- confirm against its API
init_env_state, init_obs = env.reset(key_reset, None)

# training loop
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
time_start = time.perf_counter()
res_dict = bptt.train(
    env,
    init_env_state,
    init_obs,
    train_state,
    num_epochs=max_epochs,
    num_steps_per_epoch=env.max_steps_in_episode,  # one full episode per epoch
    num_envs=num_envs,
    res_model_params=dummy_residual_params,
    key=key_bptt,
)
# JAX dispatches work asynchronously: wait for every array in the result to be
# computed before stopping the timer, otherwise the measurement may only cover
# dispatch. Leaves that are not JAX arrays are passed through untouched.
jax.block_until_ready(res_dict)
time_train_compile = time.perf_counter() - time_start
print(f"Compile + Training time: {time_train_compile}")

Episode: 0, Grad max: 0.1935
Episode: 0, Loss: 1.06
Episode: 10, Grad max: 0.8374
Episode: 10, Loss: 2.30
Episode: 20, Grad max: 0.4450
Episode: 20, Loss: 1.59
Episode: 30, Grad max: 0.4201
Episode: 30, Loss: 1.28
Episode: 40, Grad max: 0.2799
Episode: 40, Loss: 1.10
Episode: 50, Grad max: 0.2603
Episode: 50, Loss: 1.00
Episode: 60, Grad max: 0.2331
Episode: 60, Loss: 1.04
Episode: 70, Grad max: 0.1395
Episode: 70, Loss: 1.02
Episode: 80, Grad max: 0.1219
Episode: 80, Loss: 0.98
Episode: 90, Grad max: 0.1135
Episode: 90, Loss: 0.94
Episode: 100, Grad max: 0.1169
Episode: 100, Loss: 0.97
Episode: 110, Grad max: 0.1180
Episode: 110, Loss: 0.92
Episode: 120, Grad max: 0.1710
Episode: 120, Loss: 0.91
Episode: 130, Grad max: 0.0906
Episode: 130, Loss: 0.93
Episode: 140, Grad max: 0.0786
Episode: 140, Loss: 0.92
Episode: 150, Grad max: 0.0966
Episode: 150, Loss: 0.87
Episode: 160, Grad max: 0.0682
Episode: 160, Loss: 0.90
Episode: 170, Grad max: 0.0814
Episode: 170, Loss: 0.89
Episode: 180, 