In [1]:
import os

from CompoSuite import composuite
from diffusion.utils import *
from CORL.algorithms.offline.td3_bc import *
from CORL.shared.buffer import *
from CORL.shared.logger import *

# Dataset and results roots. Each can be overridden through an environment
# variable so the notebook is not tied to one machine's home directory; when
# the variable is unset or empty, the original location is used.
base_data_path = (
    os.environ.get('COMPOSUITE_DATA_PATH')
    or '/Users/shubhankar/Developer/compositional-rl-synth-data/data'
)
base_results_folder = (
    os.environ.get('COMPOSUITE_RESULTS_PATH')
    or '/Users/shubhankar/Developer/compositional-rl-synth-data/results'
)

In [2]:
def get_model_weights_norm(model):
    """
    Return the global L2 norm over the trainable parameters of ``model``.

    Every parameter yielded by ``model.parameters()`` that has
    ``requires_grad`` set contributes the square of its own 2-norm; the
    result is the square root of that sum. Parameters with
    ``requires_grad == False`` are skipped. The value is a Python float
    (``0.0`` when nothing is trainable).
    """

    squared_sum = 0.0
    for tensor in model.parameters():
        # Frozen parameters do not count towards the norm.
        if not tensor.requires_grad:
            continue
        squared_sum += tensor.data.norm(2).item() ** 2

    return squared_sum ** 0.5

In [3]:
# Which offline dataset to load, and the CompoSuite task combination it belongs to.
dataset_type = 'expert-iiwa-offline-comp-data'
robot, obj, obst, task = 'IIWA', 'Box', 'None', 'Push'

# Environment for the selected robot / object / obstacle / task.
env = composuite.make(
    robot, obj, obst, task,
    use_task_id_obs=True,
    ignore_done=False,
)

# First dimension of the observation and action spaces.
state_dim, action_dim = (
    env.observation_space.shape[0],
    env.action_space.shape[0],
)

# Load the dataset for this task and pass it through transitions_dataset.
dataset = transitions_dataset(
    load_single_composuite_dataset(
        base_path=base_data_path,
        dataset_type=dataset_type,
        robot=robot,
        obj=obj,
        obst=obst,
        task=task,
    )
)

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [4]:
# Observation mean/std computed from the dataset's "observations" entry
# (eps is forwarded to compute_mean_std; presumably it keeps the std away
# from zero — TODO confirm).
state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
# NOTE(review): `env` is rebound to a wrapper around itself, so this cell is
# not idempotent — re-running it without first re-creating `env` wraps the
# already-wrapped environment again.
env = wrap_env(env, state_mean=state_mean, state_std=state_std)

In [5]:
# Training configuration: start from the defaults, then override the
# run-specific fields.
config = TrainConfig()
config.device = "cpu"
config.env = "iiwa-box-none-push"
# NOTE(review): this is a relative path (presumably resolved against the
# kernel's working directory — confirm).
config.diffusion.path = "results/IIWA_Box_None_Push/samples.npz"

# Rewards are only normalized when the config asks for it.
if config.normalize_reward:
    reward_normalizer = RewardNormalizer(dataset, config.env)
else:
    reward_normalizer = None
state_normalizer = StateNormalizer(state_mean, state_std)

# Build the replay buffer from the dataset, normalizers and diffusion config.
replay_buffer = prepare_replay_buffer(
    state_dim=state_dim,
    action_dim=action_dim,
    buffer_size=config.buffer_size,
    dataset=dataset,
    env_name=config.env,
    device=config.device,
    reward_normalizer=reward_normalizer,
    state_normalizer=state_normalizer,
    diffusion_config=config.diffusion,
)

Loading diffusion dataset.
Dataset size: 100000


In [6]:
# Scalar action bound: the upper bound of the first action dimension only.
# NOTE(review): assumes every action dimension shares this bound — TODO confirm.
max_action = float(env.action_space.high[0])

In [7]:
# Run seed, taken from the config and shared by the logger and set_seed.
seed = config.seed
logger = Logger('/tmp', seed=seed)

# set_seed runs before any network is constructed (presumably so that weight
# initialization is reproducible — confirm against set_seed).
set_seed(seed, env)

# Architecture and optimizer settings shared by the actor and both critics.
net_arch = dict(hidden_dim=config.network_width, n_hidden=config.network_depth)
adam_lr = 3e-4

# Policy network and its optimizer.
actor = Actor(state_dim, action_dim, max_action, **net_arch).to(config.device)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=adam_lr)

# Two critics, each with its own optimizer.
critic_1 = Critic(state_dim, action_dim, **net_arch).to(config.device)
critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=adam_lr)
critic_2 = Critic(state_dim, action_dim, **net_arch).to(config.device)
critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=adam_lr)

In [8]:
# Arguments for the TD3_BC trainer, keyed by name.
kwargs = dict(
    max_action=max_action,
    actor=actor,
    actor_optimizer=actor_optimizer,
    critic_1=critic_1,
    critic_1_optimizer=critic_1_optimizer,
    critic_2=critic_2,
    critic_2_optimizer=critic_2_optimizer,
    discount=config.discount,
    tau=config.tau,
    device=config.device,
    # TD3: both noise settings are scaled by the action bound.
    policy_noise=config.policy_noise * max_action,
    noise_clip=config.noise_clip * max_action,
    policy_freq=config.policy_freq,
    # TD3 + BC
    alpha=config.alpha,
)

# Run banner: separator, description, separator.
banner = "----------------------------------------------------"
print(
    banner,
    f"Training TD3 + BC, Env: {config.env}, Seed: {seed}",
    banner,
    sep="\n",
)

# Build the trainer from the assembled arguments.
trainer = TD3_BC(**kwargs)

----------------------------------------------------
Training TD3 + BC, Env: iiwa-box-none-push, Seed: 0
----------------------------------------------------


In [9]:
# Training loop with periodic evaluation, logging and optional checkpointing.
evaluations = []  # mean evaluation score recorded at each evaluation point

# Truncate once so the loop bound and the "last step" test below agree even
# when config.max_timesteps is not stored as an int.
total_steps = int(config.max_timesteps)

# torch.save does not create missing directories, so make sure the checkpoint
# directory exists before the first checkpoint is written.
save_checkpoints = config.checkpoints_path is not None and config.save_checkpoints
if save_checkpoints:
    os.makedirs(config.checkpoints_path, exist_ok=True)

for t in range(total_steps):
    # One trainer update on a batch sampled from the replay buffer.
    batch = replay_buffer.sample(config.batch_size)
    batch = [b.to(config.device) for b in batch]
    log_dict = trainer.train(batch)

    if t % config.log_every == 0:
        # wandb.log(log_dict, step=trainer.total_it)
        logger.log({'step': trainer.total_it, **log_dict}, mode='train')

    # Evaluate every config.eval_freq steps (step 0 included) and on the final step.
    if t % config.eval_freq == 0 or t == total_steps - 1:
        # Weight norms are printed alongside each evaluation to track their growth.
        print(f"Actor norm: {get_model_weights_norm(actor)}")
        print(f"Critic 1 norm: {get_model_weights_norm(critic_1)}")
        print(f"Critic 2 norm: {get_model_weights_norm(critic_2)}")
        print(f"Time steps: {t + 1}")
        eval_scores = eval_actor(
            env,
            actor,
            device=config.device,
            n_episodes=config.n_episodes,
            seed=config.seed,
        )
        eval_score = eval_scores.mean()
        evaluations.append(eval_score)
        print("------------------------------------------------")
        print(
            f"Evaluation over {config.n_episodes} episodes: "
            f"{eval_score:.3f}"
        )
        print("------------------------------------------------")
        if save_checkpoints:
            torch.save(
                trainer.state_dict(),
                os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
            )
        # NOTE: this rebinds log_dict (the training metrics) to the eval metrics.
        log_dict = {"Score": eval_score}
        # wandb.log(log_dict, step=trainer.total_it)
        logger.log({'step': trainer.total_it, **log_dict}, mode='eval')

Actor norm: 13.19299666969193
Critic 1 norm: 13.155180289335997
Critic 2 norm: 13.114130705451856
Time steps: 1


Evaluating actor: 100%|██████████| 10/10 [00:18<00:00,  1.88s/it]


------------------------------------------------
Evaluation over 10 episodes: 0.089
------------------------------------------------
Actor norm: 15.08174433063142
Critic 1 norm: 65.83032761685512
Critic 2 norm: 65.75891002053643
Time steps: 5001


Evaluating actor: 100%|██████████| 10/10 [00:16<00:00,  1.67s/it]


------------------------------------------------
Evaluation over 10 episodes: 9.201
------------------------------------------------
Actor norm: 17.62735111869298
Critic 1 norm: 236.27942610319982
Critic 2 norm: 236.77336742861152
Time steps: 10001


Evaluating actor: 100%|██████████| 10/10 [00:15<00:00,  1.53s/it]


------------------------------------------------
Evaluation over 10 episodes: 11.653
------------------------------------------------
Actor norm: 20.29941948910086
Critic 1 norm: 711.4182520323676
Critic 2 norm: 708.0150425994506
Time steps: 15001


Evaluating actor: 100%|██████████| 10/10 [00:13<00:00,  1.39s/it]


------------------------------------------------
Evaluation over 10 episodes: 11.311
------------------------------------------------
Actor norm: 22.82956916520786
Critic 1 norm: 1195.0958473827764
Critic 2 norm: 1187.5355087988928
Time steps: 20001


Evaluating actor: 100%|██████████| 10/10 [00:14<00:00,  1.48s/it]


------------------------------------------------
Evaluation over 10 episodes: 9.194
------------------------------------------------
Actor norm: 25.129045527570405
Critic 1 norm: 1615.0450040124801
Critic 2 norm: 1604.727885858945
Time steps: 25001


Evaluating actor: 100%|██████████| 10/10 [00:13<00:00,  1.35s/it]


------------------------------------------------
Evaluation over 10 episodes: 10.441
------------------------------------------------
Actor norm: 27.227453842504868
Critic 1 norm: 2009.531343939449
Critic 2 norm: 1997.0903095080546
Time steps: 30001


Evaluating actor: 100%|██████████| 10/10 [00:13<00:00,  1.38s/it]


------------------------------------------------
Evaluation over 10 episodes: 11.799
------------------------------------------------
Actor norm: 29.151000781507
Critic 1 norm: 2392.631730409874
Critic 2 norm: 2378.321068114729
Time steps: 35001


Evaluating actor: 100%|██████████| 10/10 [00:15<00:00,  1.58s/it]


------------------------------------------------
Evaluation over 10 episodes: 11.161
------------------------------------------------
Actor norm: 30.927704393427643
Critic 1 norm: 2771.5984894606568
Critic 2 norm: 2755.5558700656416
Time steps: 40001


Evaluating actor: 100%|██████████| 10/10 [00:12<00:00,  1.24s/it]


------------------------------------------------
Evaluation over 10 episodes: 16.428
------------------------------------------------
Actor norm: 32.57533102698499
Critic 1 norm: 3149.5098039892114
Critic 2 norm: 3131.6713722371856
Time steps: 45001


Evaluating actor: 100%|██████████| 10/10 [00:13<00:00,  1.40s/it]


------------------------------------------------
Evaluation over 10 episodes: 16.040
------------------------------------------------
Actor norm: 34.118729094739656
Critic 1 norm: 3527.123686184603
Critic 2 norm: 3507.3105798111656
Time steps: 50000


Evaluating actor: 100%|██████████| 10/10 [00:15<00:00,  1.52s/it]

------------------------------------------------
Evaluation over 10 episodes: 12.623
------------------------------------------------



