In [1]:
import dataclasses
import os
import pathlib

from diffusion.utils import *
from CORL.algorithms.offline.td3_bc import *
from CORL.shared.buffer import *
from CORL.shared.logger import *

# Root of the compositional-rl-synth-data checkout. Set the COMPOSITIONAL_RL_ROOT
# environment variable to run on another machine instead of editing this cell;
# without it, the original author's local paths are used unchanged.
PROJECT_ROOT = os.environ.get(
    'COMPOSITIONAL_RL_ROOT',
    '/Users/shubhankar/Developer/compositional-rl-synth-data',
)

# Expert CompoSuite datasets (used when data_type == 'agent').
base_agent_data_path = os.path.join(PROJECT_ROOT, 'data')
# Synthetic datasets produced by diffusion runs (used when data_type == 'synthetic').
base_synthetic_data_path = os.path.join(PROJECT_ROOT, 'cluster_results', 'diffusion')

# Parent folder under which each run gets its own numbered results directory.
base_results_folder = os.path.join(PROJECT_ROOT, 'local_results', 'offline_learning')

In [2]:
config = TrainConfig()

# Which offline dataset to train on:
#   'agent'     -> expert CompoSuite transitions (dataset_type='expert')
#   'synthetic' -> transitions loaded from a synthetic run under base_synthetic_data_path
data_type = 'synthetic'

# CompoSuite task specification. Note obst is the *string* 'None', passed
# straight through to composuite.make / the dataset loaders.
robot = 'Kinova3'
obj = 'Hollowbox'
obst = 'None'
subtask = 'Trashcan'

if data_type == 'agent':
    # This env is only used for its modality_dims; the evaluation env is built later.
    env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=True, ignore_done=False)
    dataset = load_single_composuite_dataset(base_path=base_agent_data_path,
                                             dataset_type='expert',
                                             robot=robot, obj=obj,
                                             obst=obst, task=subtask)
    # Strip the task-indicator part of the observation so it matches the
    # use_task_id_obs=False evaluation env.
    dataset, _ = remove_indicator_vectors(env.modality_dims, transitions_dataset(dataset))
elif data_type == 'synthetic':
    synthetic_run_id = 'cond_diff_20'
    mode = 'train'
    # Training-budget overrides specific to synthetic-data runs.
    config.max_timesteps = 50000
    config.n_episodes = 10
    config.batch_size = 1024
    dataset = load_single_synthetic_dataset(base_path=os.path.join(base_synthetic_data_path, synthetic_run_id, mode),
                                            robot=robot, obj=obj,
                                            obst=obst, task=subtask)
else:
    # Fail here rather than with a NameError on `dataset` several cells later.
    raise ValueError(f"Unknown data_type {data_type!r}; expected 'agent' or 'synthetic'")

In [3]:
base_results_path = pathlib.Path(base_results_folder)
base_results_path.mkdir(parents=True, exist_ok=True)

# Claim the first unused run index. mkdir(exist_ok=False) is atomic, so two
# runs started at the same time can never end up sharing one results folder
# (an exists() check followed by mkdir(exist_ok=True) could).
idx = 1
while True:
    results_folder = base_results_path / f"offline_learning_{data_type}_{idx}"
    try:
        results_folder.mkdir(exist_ok=False)
        break
    except FileExistsError:
        idx += 1

# Checkpoints for this run are written into the freshly created folder.
config.checkpoints_path = results_folder

In [None]:
# Observation normalization statistics are taken from the training dataset.
state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)

# Evaluation/visualization env: no task-id observation, on-screen renderer enabled.
env = composuite.make(
    robot, obj, obst, subtask,
    use_task_id_obs=False, has_renderer=True, ignore_done=False,
)
state_dim, action_dim = env.observation_space.shape[0], env.action_space.shape[0]

# Wrap so the agent sees observations normalized with the dataset statistics.
env = wrap_env(env, state_mean=state_mean, state_std=state_std)

In [None]:
# Sanity-check the size of the loaded observation array.
observation_shape = dataset["observations"].shape
observation_shape

In [None]:
# Storage device for the replay buffer. Sampled batches are moved to
# config.device inside the training loop, so the buffer can live on the CPU.
device = "cpu"

# Fraction of the expert dataset used when training on agent data.
AGENT_DATA_FRACTION = 0.005

n_transitions = dataset['observations'].shape[0]
if data_type == 'agent':
    # Clamp to >= 1 so a small dataset can't round down to an empty buffer.
    num_samples = max(1, int(AGENT_DATA_FRACTION * n_transitions))
elif data_type == 'synthetic':
    num_samples = int(n_transitions)
else:
    raise ValueError(f"Unknown data_type {data_type!r}; expected 'agent' or 'synthetic'")

print("Samples:", num_samples)

replay_buffer = prepare_replay_buffer(
    state_dim=state_dim,
    action_dim=action_dim,
    dataset=dataset,
    num_samples=num_samples,
    device=device,
    # NOTE(review): config.env is never set in this notebook; confirm
    # RewardNormalizer does something sensible with it for CompoSuite.
    reward_normalizer=RewardNormalizer(dataset, config.env) if config.normalize_reward else None,
    # Same statistics used by wrap_env, so buffer states match eval-time states.
    state_normalizer=StateNormalizer(state_mean, state_std),
    )

In [7]:
# Actor output scale, read from the first action dimension's upper bound.
# NOTE(review): assumes all action dims share symmetric bounds [-max_action, max_action] — confirm.
action_upper_bounds = env.action_space.high
max_action = float(action_upper_bounds[0])

In [8]:
logger = Logger(results_folder, seed=config.seed)

# Seed before building any network; set_seed also receives the env
# (presumably it seeds torch/numpy as well — confirm in its definition).
seed = config.seed
set_seed(seed, env)

# Every network shares the same MLP width/depth from the config.
mlp_shape = dict(hidden_dim=config.network_width, n_hidden=config.network_depth)

# Keep the construction order actor -> critic_1 -> critic_2 so seeded weight
# initialization draws from the RNG in the same sequence.
actor = Actor(state_dim, action_dim, max_action, **mlp_shape).to(config.device)
critic_1 = Critic(state_dim, action_dim, **mlp_shape).to(config.device)
critic_2 = Critic(state_dim, action_dim, **mlp_shape).to(config.device)

# One independent Adam optimizer per network, all with lr=3e-4.
actor_optimizer, critic_1_optimizer, critic_2_optimizer = (
    torch.optim.Adam(net.parameters(), lr=3e-4) for net in (actor, critic_1, critic_2)
)

In [None]:
# Trainer arguments. policy_noise and noise_clip are configured relative to
# the action range and scaled by max_action here.
kwargs = {
    "max_action": max_action,
    "actor": actor,
    "actor_optimizer": actor_optimizer,
    "critic_1": critic_1,
    "critic_1_optimizer": critic_1_optimizer,
    "critic_2": critic_2,
    "critic_2_optimizer": critic_2_optimizer,
    "discount": config.discount,
    "tau": config.tau,
    "policy_noise": config.policy_noise * max_action,
    "noise_clip": config.noise_clip * max_action,
    "policy_freq": config.policy_freq,
    "alpha": config.alpha,
}

# config.env is not set anywhere in this notebook, so printing it would report
# a misleading environment; report the CompoSuite task actually being trained.
task_name = f"{robot}_{obj}_{obst}_{subtask}"
print("----------------------------------------------------")
print(f"Training TD3 + BC, Env: {task_name}, Seed: {seed}")
print("----------------------------------------------------")

# Build the TD3+BC trainer around the actor/critics created above.
trainer = TD3_BC(**kwargs)

In [None]:
wandb_project = 'offline_rl_diffusion'
wandb_entity = ''  # leave empty to log under the default wandb entity
wandb_group = 'corl_training'

# Log hyperparameters and the CompoSuite task with the run so runs can be
# compared/filtered in the wandb UI (previously only metrics were logged).
run_config = dataclasses.asdict(config) if dataclasses.is_dataclass(config) else dict(vars(config))
run_config.update(data_type=data_type, robot=robot, obj=obj, obst=obst, subtask=subtask)
# Paths (e.g. checkpoints_path) are stored as plain strings.
run_config = {
    key: str(value) if isinstance(value, pathlib.PurePath) else value
    for key, value in run_config.items()
}

wandb.init(
    project=wandb_project,
    # Pass None rather than '' so wandb falls back to the default entity.
    entity=wandb_entity or None,
    group=wandb_group,
    name=results_folder.name,
    config=run_config,
)

In [None]:
# Show where this run's checkpoints will be written.
print(f"{config.checkpoints_path}")

In [None]:
# TD3+BC training: one gradient update per iteration, train metrics every
# config.log_every steps, and evaluation (+ optional checkpoint) every
# config.eval_freq steps and on the final step.
evaluations = []
for t in range(int(config.max_timesteps)):
    batch = [tensor.to(config.device) for tensor in replay_buffer.sample(config.batch_size)]
    log_dict = trainer.train(batch)

    if t % config.log_every == 0:
        wandb.log(log_dict, step=trainer.total_it)
        logger.log({'step': trainer.total_it, **log_dict}, mode='train')

    is_eval_step = t % config.eval_freq == 0 or t == config.max_timesteps - 1
    if not is_eval_step:
        continue

    print(f"Time steps: {t + 1}")
    eval_scores = eval_actor(
        env,
        actor,
        device=config.device,
        n_episodes=config.n_episodes,
        seed=config.seed,
    )
    eval_score = eval_scores.mean()
    evaluations.append(eval_score)

    print("------------------------------------------------")
    print(f"Evaluation over {config.n_episodes} episodes: {eval_score:.3f}")
    print("------------------------------------------------")

    if config.checkpoints_path is not None and config.save_checkpoints:
        checkpoint_file = os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt")
        torch.save(trainer.state_dict(), checkpoint_file)

    log_dict = {"Score": eval_score}
    wandb.log(log_dict, step=trainer.total_it)
    logger.log({'step': trainer.total_it, **log_dict}, mode='eval')

In [None]:
# Quick look at the overall scale of the observation-normalization statistics.
print(*(stats.mean() for stats in (state_mean, state_std)))

In [None]:
def get_weights_norm(model):
    """Return the global L2 norm of all trainable parameters of ``model``.

    Parameters with ``requires_grad=False`` are ignored. A model with no
    trainable parameters yields ``0.0``.
    """
    squared_sum = 0.0
    trainable = (weight for weight in model.parameters() if weight.requires_grad)
    for weight in trainable:
        squared_sum = squared_sum + weight.norm(2).item() ** 2
    return squared_sum ** 0.5

# Global L2 norm of the actor's trainable weights after training.
actor_weight_norm = get_weights_norm(actor)
print('After:', actor_weight_norm)

In [17]:
# Roll out the trained actor in the on-screen viewer for a qualitative check.
VIS_MAX_STEPS = 1000  # cap on rollout length if the episode never signals done
VIS_CAMERA_ID = 3     # camera index passed to env.viewer.set_camera

state = env.reset()
env.viewer.set_camera(camera_id=VIS_CAMERA_ID)

for _ in range(VIS_MAX_STEPS):
    # Pass the device the actor lives on; the original call used act()'s
    # default device, which breaks when config.device is a GPU.
    # NOTE(review): assumes Actor.act(state, device=...) as in CORL's TD3+BC — confirm.
    action = actor.act(state, device=config.device)
    state, _, done, _ = env.step(action)
    # Render before the done check so the terminal frame is shown too.
    env.render()
    if done:
        break