In [1]:
import os

import composuite
from diffusion.utils import *
from CORL.algorithms.offline.td3_bc import *
from CORL.shared.buffer import *
from CORL.shared.logger import *

# Root of the project checkout. Override with the COMPOSITIONAL_RL_ROOT
# environment variable to run on another machine; the default keeps the
# original author's layout so existing runs behave exactly as before.
project_root = os.environ.get(
    'COMPOSITIONAL_RL_ROOT',
    '/Users/shubhankar/Developer/compositional-rl-synth-data',
)

# Directory that the CompoSuite datasets are loaded from.
base_data_path = os.path.join(project_root, 'data')
# Parent directory for per-task CORL results; the empty last component keeps
# the trailing separator that the original literal had.
base_results_folder = os.path.join(project_root, 'results', 'CORL', '')

In [2]:
def get_model_weights_norm(model):
    """
    Return the global L2 norm of a model's trainable weights.

    Every parameter yielded by ``model.parameters()`` that has
    ``requires_grad`` set contributes the square of its own 2-norm; the
    result is the square root of that sum. Frozen parameters are ignored,
    and a model without trainable parameters yields ``0.0``.
    """
    squared_sum = 0.0
    trainable = (weights for weights in model.parameters() if weights.requires_grad)
    for weights in trainable:
        # .item() converts the 0-dim norm tensor into a plain Python float.
        squared_sum += weights.data.norm(2).item() ** 2

    return squared_sum ** 0.5

In [None]:
# Which dataset variant to load for the task below.
dataset_type = 'expert'

# Task specification. Note that `obst` is the *string* 'None', not Python's None.
robot, obj, obst, task = 'IIWA', 'Box', 'None', 'Push'

# Build the environment for this robot/object/obstacle/task combination.
env = composuite.make(
    robot,
    obj,
    obst,
    task,
    use_task_id_obs=True,
    ignore_done=False,
)

# Sizes of the observation and action vectors (first axis of each space's shape).
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# Load the dataset for the same task and pass it through transitions_dataset
# before it is used for normalisation statistics and the replay buffer.
dataset = transitions_dataset(
    load_single_composuite_dataset(
        base_path=base_data_path,
        dataset_type=dataset_type,
        robot=robot,
        obj=obj,
        obst=obst,
        task=task,
    )
)

In [4]:
# Statistics of the dataset observations; the same values are later handed to
# the replay buffer's StateNormalizer.
state_mean, state_std = compute_mean_std(
    dataset["observations"],
    eps=1e-3,
)

# NOTE(review): this rebinds `env` to its wrapped version, so running this
# cell twice wraps the environment twice — re-create `env` before re-running.
env = wrap_env(
    env,
    state_mean=state_mean,
    state_std=state_std,
)

In [None]:
config = TrainConfig()

# Use the Apple-silicon MPS backend when PyTorch reports it as available,
# otherwise fall back to the CPU. (The earlier unconditional "cpu" assignment
# was dead code: it was overwritten on the very next line.)
config.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Derive the environment name and the diffusion-sample location from the task
# specification instead of repeating it as literals, so they cannot drift
# apart from robot/obj/obst/task. For IIWA/Box/None/Push these evaluate to
# "iiwa-box-none-push" and ".../results/Diffusion/IIWA_Box_None_Push/samples.npz".
config.env = f"{robot}-{obj}-{obst}-{task}".lower()
config.diffusion.path = os.path.join(
    # base_results_folder is ".../results/CORL/"; step up to ".../results".
    os.path.dirname(os.path.normpath(base_results_folder)),
    "Diffusion",
    f"{robot}_{obj}_{obst}_{task}",
    "samples.npz",
)

# Build the buffer that the training loop samples batches from. Rewards are
# only normalised when the config asks for it; states always use the dataset
# statistics that were also applied to the environment wrapper.
replay_buffer = prepare_replay_buffer(
    state_dim=state_dim,
    action_dim=action_dim,
    buffer_size=config.buffer_size,
    dataset=dataset,
    env_name=config.env,
    device=config.device,
    reward_normalizer=RewardNormalizer(dataset, config.env) if config.normalize_reward else None,
    state_normalizer=StateNormalizer(state_mean, state_std),
    diffusion_config=config.diffusion
    )

In [6]:
# Scalar action bound passed to the actor and used to scale the TD3 noise terms.
# NOTE(review): only the first entry of `high` is read, which assumes every
# action dimension shares the same upper bound — confirm for this environment.
max_action = float(env.action_space.high[0])

In [7]:
# Per-task output directory, e.g. <base_results_folder>/IIWA_Box_None_Push.
results_folder = os.path.join(base_results_folder, '_'.join([robot, obj, obst, task]))

In [8]:
# File logger for this run, written under the per-task results folder.
logger = Logger(results_folder, seed=config.seed)

# Seed everything before any network is constructed.
seed = config.seed
set_seed(seed, env)

# Networks are created in the same order as before (actor, critic 1, critic 2);
# presumably their initial weights depend on the seeded RNG state, so the order
# is kept to leave initialisation unchanged.
actor = Actor(
    state_dim,
    action_dim,
    max_action,
    hidden_dim=config.network_width,
    n_hidden=config.network_depth,
).to(config.device)

critic_1 = Critic(
    state_dim,
    action_dim,
    hidden_dim=config.network_width,
    n_hidden=config.network_depth,
).to(config.device)

critic_2 = Critic(
    state_dim,
    action_dim,
    hidden_dim=config.network_width,
    n_hidden=config.network_depth,
).to(config.device)

# One Adam optimiser per network, all with the same learning rate.
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4)
critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=3e-4)
critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=3e-4)

In [None]:
# Everything the TD3+BC trainer needs: the networks with their optimisers,
# plus the hyper-parameters taken from the config.
kwargs = dict(
    max_action=max_action,
    actor=actor,
    actor_optimizer=actor_optimizer,
    critic_1=critic_1,
    critic_1_optimizer=critic_1_optimizer,
    critic_2=critic_2,
    critic_2_optimizer=critic_2_optimizer,
    discount=config.discount,
    tau=config.tau,
    # TD3 (noise settings are scaled by the action bound)
    policy_noise=config.policy_noise * max_action,
    noise_clip=config.noise_clip * max_action,
    policy_freq=config.policy_freq,
    # TD3 + BC
    alpha=config.alpha,
)

print("----------------------------------------------------")
print(f"Training TD3 + BC, Env: {config.env}, Seed: {seed}")
print("----------------------------------------------------")

# Create the trainer that owns the update step used in the training loop.
trainer = TD3_BC(**kwargs)

In [None]:
# Weights & Biases run settings.
wandb_project = 'offline_rl_diffusion'
wandb_entity = ''  # left empty here; fill in to log under a specific entity
wandb_group = 'corl_training'

# The run is named after the last path component of the results folder
# (i.e. the "<robot>_<obj>_<obst>_<task>" directory name).
wandb.init(
    project=wandb_project,
    entity=wandb_entity,
    group=wandb_group,
    name=results_folder.rsplit('/', 1)[-1],
)

In [None]:
# Create the checkpoint directory up front: torch.save does not create missing
# parent directories, so without this the first checkpoint write would fail.
if config.checkpoints_path is not None and config.save_checkpoints:
    os.makedirs(config.checkpoints_path, exist_ok=True)

# Mean evaluation score recorded at every evaluation point, in order.
evaluations = []
for t in range(int(config.max_timesteps)):
    # One gradient update on a batch sampled from the replay buffer.
    batch = replay_buffer.sample(config.batch_size)
    log_dict = trainer.train(batch)

    # Periodically push the training metrics to W&B and the file logger.
    if t % config.log_every == 0:
        wandb.log(log_dict, step=trainer.total_it)
        logger.log({'step': trainer.total_it, **log_dict}, mode='train')

    # Evaluate episode: every `eval_freq` steps and once more on the final step.
    if t % config.eval_freq == 0 or t == config.max_timesteps - 1:
        # Weight norms are printed as a quick sanity check on training health.
        print(f"Actor norm: {get_model_weights_norm(actor)}")
        print(f"Critic 1 norm: {get_model_weights_norm(critic_1)}")
        print(f"Critic 2 norm: {get_model_weights_norm(critic_2)}")
        print(f"Time steps: {t + 1}")
        eval_scores = eval_actor(
            env,
            actor,
            device=config.device,
            n_episodes=config.n_episodes,
            seed=config.seed,
        )
        eval_score = eval_scores.mean()
        evaluations.append(eval_score)
        print("------------------------------------------------")
        print(
            f"Evaluation over {config.n_episodes} episodes: "
            f"{eval_score:.3f}"
        )
        print("------------------------------------------------")
        # Optionally persist the trainer state for this step.
        if config.checkpoints_path is not None and config.save_checkpoints:
            torch.save(
                trainer.state_dict(),
                os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
            )
        # Record the evaluation score under the same step counter as training.
        log_dict = {"Score": eval_score}
        wandb.log(log_dict, step=trainer.total_it)
        logger.log({'step': trainer.total_it, **log_dict}, mode='eval')