In [None]:
import time

import jax
import optax
from orbax.checkpoint import PyTreeCheckpointer
from flax.training.train_state import TrainState
from flax.core import freeze, unfreeze

from lotf import LOTF_PATH
from lotf.algos import bptt
from lotf.envs import HoveringStateEnv
from lotf.envs.wrappers import MinMaxObservationWrapper, LogWrapper, VecEnv
from lotf.modules import MLP, LoraMLP
from lotf.objects import Quadrotor

from lotf.utils.lora import (
    lora_only_mask,
    partition_params,
    recursive_merge,
)

# (LoRA) Finetuning a Trained State-Based Hovering Policy With BPTT

## 1. Seed

In [2]:
# Fixed integer seed -> root PRNG key for the whole notebook.
seed = 0
key = jax.random.key(seed)

# Derive two sub-keys from the root key: `key_init` is consumed when the LoRA
# parameters are initialized, `key_bptt` is threaded through environment reset
# and BPTT training.
key_init, key_bptt = jax.random.split(key)

## 2. Define Simulation Dynamics Config and Training Params

In [3]:
# training parameters
num_envs = 200     # number of environments; one reset key is drawn per environment
max_epochs = 200   # passed to bptt.train as `num_epochs`

# simulation dynamics config (handed to Quadrotor.from_name)
sim_dyn_config = {
    # whether to use high-fidelity dynamics in forward simulation
    "use_high_fidelity": False,
    # whether to use residual dynamics in forward simulation
    "use_forward_residual": False,
}

## 3. Create Quadrotor Object and Simulation Environment

In [4]:
# simulation parameters
sim_dt = 0.02        # forwarded to the environment as `dt`
max_sim_time = 3.0   # divided by sim_dt below to obtain the per-episode step budget

# quadrotor object, created by name with the dynamics config defined earlier
quad_obj = Quadrotor.from_name("example_quad", sim_dyn_config)

# simulation environment
# NOTE: int() truncates toward zero, so max_steps_in_episode is the whole
# number of sim_dt steps in max_sim_time.
env = HoveringStateEnv(
    max_steps_in_episode=int(max_sim_time / sim_dt),
    dt=sim_dt,
    delay=0.04,
    yaw_scale=1.0,
    pitch_roll_scale=0.1,
    velocity_std=0.1,
    omega_std=0.1,
    quad_obj=quad_obj,
    reward_sharpness=3.0,
    action_penalty_weight=0.5,
    margin=0.5,
    hover_target=[1.5, 0.0, 1.5],
)

# apply min-max observation wrapper
env = MinMaxObservationWrapper(env)

# read the action / observation dimensions here, before the logging and
# vectorization wrappers are stacked on top
action_dim = env.action_space.shape[0]
obs_dim = env.observation_space.shape[0]

# apply additional wrappers: logging first, then vectorization
env = VecEnv(LogWrapper(env))

# one-line-per-item summary of the environment
print(
    "====== env info ======",
    f"action_dim: {action_dim}",
    f"obs_dim: {obs_dim}",
    f"target hover goal: {env.goal}",
    sep="\n",
)

action_dim: 4
obs_dim: 27
target hover goal: [1.5 0.  1.5]


## 4. Define Base Policy Network and Load Its Parameters

In [5]:
policy_name = "state_hovering_params"

# base policy network, built from the size list [obs_dim, 512, 512, action_dim]
# and biased with the environment's hovering action
base_policy_net = MLP(
    [obs_dim, 512, 512, action_dim],
    initial_scale=0.01,
    action_bias=env.hovering_action,
)

# restore the trained base-policy parameters from the policy checkpoint
# directory located relative to LOTF_PATH
ckptr = PyTreeCheckpointer()
path = f"{LOTF_PATH}/../checkpoints/policy/{policy_name}"
base_policy_params = ckptr.restore(path)

## 5. Define LoRA Policy Network, Create Optimizer and Train State

In [6]:
# LoRA hyper-parameters
# (three ranks are given; presumably one per layer of the base MLP -- TODO confirm)
lora_ranks = [1, 1, 1]
lora_alpha = 1.0

# LoRA policy network wrapping the base MLP; its parameters are initialized
# from the restored base-policy parameters using `key_init`
policy_net = LoraMLP(
    base_mlp=base_policy_net,
    lora_ranks=lora_ranks,
    lora_alpha=lora_alpha,
)
policy_params = policy_net.initialize_with_base(key_init, base_policy_params)

# split the parameter tree with the LoRA-only mask; only `trainable_params`
# is given to the train state further down, `frozen_params` stays fixed
mask = lora_only_mask(policy_params)
frozen_params, trainable_params = partition_params(policy_params, mask)

def apply_combined(params, x):
    """Apply the LoRA policy network using frozen + trainable parameters.

    Args:
        params: the trainable parameter partition (the part held by the
            train state).
        x: input forwarded unchanged to ``policy_net.apply``.

    Returns:
        Whatever ``policy_net.apply`` returns for the merged parameter tree.
    """
    # Both partitions are unfrozen before merging; the module-level
    # `frozen_params` is captured from the enclosing notebook scope.
    frozen_tree = unfreeze(frozen_params)
    trainable_tree = unfreeze(params)
    merged_tree = recursive_merge(frozen_tree, trainable_tree)
    # Re-freeze the merged tree before handing it to the network.
    return policy_net.apply(freeze(merged_tree), x)

# optimizer and train state
# Adam with learning rate 1e-3; the train state holds only the trainable
# partition and evaluates the policy through `apply_combined`.
tx = optax.adam(1e-3)
train_state = TrainState.create(
    apply_fn=apply_combined,
    params=trainable_params,
    tx=tx,
)

## 6. Load Residual Dynamics Network Parameters

In [7]:
residual_dynamics_name = "example_params"

# restore the residual-dynamics parameters from their checkpoint directory
# located relative to LOTF_PATH; they are later passed to bptt.train as
# `res_model_params`
# NOTE(review): sim_dyn_config sets "use_forward_residual" to False, so these
# are presumably placeholders (hence "dummy") -- confirm against bptt.train.
ckptr = PyTreeCheckpointer()
path = f"{LOTF_PATH}/../checkpoints/residual_dynamics/{residual_dynamics_name}"
dummy_residual_params = ckptr.restore(path)

## 7. Train

In [8]:
# initialize environments
# split off a fresh sub-key, then expand it into one reset key per environment
key_bptt, key_ = jax.random.split(key_bptt)
key_reset = jax.random.split(key_, num_envs)
init_env_state, init_obs = env.reset(key_reset, None)

# training loop
# the timer brackets the whole bptt.train call; per the printed label below,
# the measured duration covers compilation as well as training
time_start = time.time()
res_dict = bptt.train(
    env,
    init_env_state,
    init_obs,
    train_state,
    num_epochs=max_epochs,
    num_steps_per_epoch=env.max_steps_in_episode,
    num_envs=num_envs,
    res_model_params=dummy_residual_params,
    key=key_bptt,
)
time_train_compile = time.time() - time_start
print(f"Compile + Training time: {time_train_compile}")

Episode: 0, Grad max: 0.0198
Episode: 0, Loss: 0.78
Episode: 10, Grad max: 0.0151
Episode: 10, Loss: 0.73
Episode: 20, Grad max: 0.0291
Episode: 20, Loss: 0.72
Episode: 30, Grad max: 0.0227
Episode: 30, Loss: 0.77
Episode: 40, Grad max: 0.0171
Episode: 40, Loss: 0.75
Episode: 50, Grad max: 0.0122
Episode: 50, Loss: 0.74
Episode: 60, Grad max: 0.0112
Episode: 60, Loss: 0.76
Episode: 70, Grad max: 0.0271
Episode: 70, Loss: 0.74
Episode: 80, Grad max: 0.0122
Episode: 80, Loss: 0.76
Episode: 90, Grad max: 0.0396
Episode: 90, Loss: 0.76
Episode: 100, Grad max: 0.0484
Episode: 100, Loss: 0.77
Episode: 110, Grad max: 0.0358
Episode: 110, Loss: 0.81
Episode: 120, Grad max: 0.0264
Episode: 120, Loss: 0.75
Episode: 130, Grad max: 0.0117
Episode: 130, Loss: 0.75
Episode: 140, Grad max: 0.0230
Episode: 140, Loss: 0.76
Episode: 150, Grad max: 0.0325
Episode: 150, Loss: 0.75
Episode: 160, Grad max: 0.0159
Episode: 160, Loss: 0.76
Episode: 170, Grad max: 0.0143
Episode: 170, Loss: 0.73
Episode: 180, 