In [6]:
from diffusion.utils import *
from corl.algorithms import td3_bc
from corl.shared.buffer import *
import wandb
import pathlib

import os

# Root of the compositional-rl-synth-data checkout. Set the
# COMPOSITIONAL_RL_ROOT environment variable to run this notebook on another
# machine; the default keeps the original author's local layout so existing
# runs resolve to exactly the same paths.
repo_root = os.environ.get(
    'COMPOSITIONAL_RL_ROOT',
    '/Users/shubhankar/Developer/compositional-rl-synth-data',
)

# Expert (agent) CompoSuite datasets.
base_agent_data_path = os.path.join(repo_root, 'data')
# Synthetic datasets produced by the diffusion runs.
base_synthetic_data_path = os.path.join(repo_root, 'cluster_results', 'diffusion')
# Parent folder for per-run offline-learning results/checkpoints.
base_results_folder = os.path.join(repo_root, 'local_results', 'offline_learning')

In [3]:
# TD3+BC hyper-parameters; a few fields are overridden below for this run.
config = td3_bc.TrainConfig()

# Source of the offline training data: 'agent' or 'synthetic'.
data_type = 'synthetic'

if data_type == 'synthetic':
    # Run-specific overrides used when training on synthetic data.
    config.seed = 0
    config.n_episodes = 10
    config.batch_size = 1024
    # Sub-directories of base_synthetic_data_path holding the samples.
    synthetic_run_id = ''
    mode = ''  # train/test

config.max_timesteps = 50000

# CompoSuite task: robot, object, obstacle, subtask.
robot, obj, obst, subtask = 'Panda', 'Box', 'None', 'PickPlace'
task_spec = dict(robot=robot, obj=obj, obst=obst, task=subtask)

# Load the expert data and drop the indicator vectors from it, using the
# modality layout of an env built with task-ID observations enabled.
env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=True, ignore_done=False)
agent_dataset = load_single_composuite_dataset(
    base_path=base_agent_data_path,
    dataset_type='expert',
    **task_spec,
)
agent_dataset, _ = remove_indicator_vectors(env.modality_dims, transitions_dataset(agent_dataset))

# Special observation dimensions found in the expert data by
# identify_special_dimensions; the synthetic samples are post-processed
# against these same index sets.
integer_dims, constant_dims = identify_special_dimensions(agent_dataset['observations'])
print('Integer dimensions:', integer_dims)
print('Constant dimensions:', constant_dims)

if data_type == 'synthetic':
    synthetic_path = os.path.join(base_synthetic_data_path, synthetic_run_id, mode)
    synthetic_dataset = load_single_synthetic_dataset(base_path=synthetic_path, **task_spec)
    synthetic_dataset = process_special_dimensions(synthetic_dataset, integer_dims, constant_dims)

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


Integer dimensions: [14, 15, 16, 17, 18, 19, 20, 31, 32, 33, 34]
Constant dimensions: [28, 29, 30]


In [4]:
# Sanity check: shape of the expert observation array.
agent_observations = agent_dataset["observations"]
agent_observations.shape

(999999, 77)

In [5]:
# Sanity check: synthetic observations should have the same width as the expert data.
synthetic_observations = synthetic_dataset["observations"]
synthetic_observations.shape

(1000000, 77)

In [7]:
# Create a fresh, uniquely numbered results folder for this run, e.g.
# <base_results_folder>/offline_learning_synthetic_5.
# The folder is claimed with mkdir(exist_ok=False) and we retry on
# FileExistsError. Probing with .exists() and then calling
# mkdir(exist_ok=True) is racy: two runs started at the same time could pick
# the same index and overwrite each other's checkpoints.
base_results_path = pathlib.Path(base_results_folder)
base_results_path.mkdir(parents=True, exist_ok=True)

idx = 1
while True:
    results_folder = base_results_path / f"offline_learning_{data_type}_{idx}"
    try:
        results_folder.mkdir(exist_ok=False)
        break
    except FileExistsError:
        # Index already taken (by an earlier or concurrent run); try the next one.
        idx += 1

# Checkpoints for this run are written into the freshly created folder.
config.checkpoints_path = results_folder

In [8]:
# Pick the training data according to `data_type`.
# An unrecognised value fails loudly here. Otherwise both branches would be
# skipped and the cell would crash later with a confusing NameError on
# `dataset` / `num_samples`.
if data_type == 'agent':
    dataset = agent_dataset
elif data_type == 'synthetic':
    dataset = synthetic_dataset
else:
    raise ValueError(
        f"Unknown data_type {data_type!r}; expected 'agent' or 'synthetic'."
    )

# Number of transitions to load into the replay buffer (the whole dataset).
num_samples = int(dataset['observations'].shape[0])
print("Samples:", num_samples)

Samples: 1000000


In [9]:
# Environment used for evaluation. Task-ID observations are disabled,
# presumably so that its observation width matches the datasets, whose
# indicator vectors were removed on load (TODO confirm).
env = composuite.make(
    robot, obj, obst, subtask,
    use_task_id_obs=False,
    has_renderer=True,
    ignore_done=False,
)
state_dim, action_dim = env.observation_space.shape[0], env.action_space.shape[0]

# Observation statistics come from the selected training dataset. The same
# mean/std are handed to the env wrapper.
state_mean, state_std = td3_bc.compute_mean_std(dataset["observations"], eps=1e-3)
env = td3_bc.wrap_env(env, state_mean=state_mean, state_std=state_std)

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [10]:
# The replay buffer is kept on the CPU. Sampled batches are moved to
# config.device inside the training loop.
device = "cpu"

# Reward normalisation is optional (config.normalize_reward). The state
# normaliser always uses the dataset statistics computed above.
reward_normalizer = (
    RewardNormalizer(dataset, config.env) if config.normalize_reward else None
)
state_normalizer = StateNormalizer(state_mean, state_std)

replay_buffer = prepare_replay_buffer(
    state_dim=state_dim,
    action_dim=action_dim,
    dataset=dataset,
    num_samples=num_samples,
    device=device,
    reward_normalizer=reward_normalizer,
    state_normalizer=state_normalizer,
)

Limiting size of the data to 1000000 samples.
Dataset size: 1000000


In [11]:
# A single scalar action bound is passed to the actor and trainer. It is
# taken from the first action dimension. This assumes every dimension shares
# the same bound (TODO confirm for CompoSuite).
action_upper_bounds = env.action_space.high
max_action = float(action_upper_bounds[0])

In [12]:
# Seed the RNGs (and the env) before building any network so that parameter
# initialisation is reproducible.
seed = config.seed
td3_bc.set_seed(seed, env)

# Architecture and optimiser settings shared by the actor and both critics.
net_kwargs = dict(hidden_dim=config.network_width, n_hidden=config.network_depth)
optimizer_lr = 3e-4

actor = td3_bc.Actor(state_dim, action_dim, max_action, **net_kwargs).to(config.device)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=optimizer_lr)

# Twin critics used by TD3. They are constructed in this fixed order after
# the actor so the seeded RNG stream yields the same initial weights every run.
critic_1 = td3_bc.Critic(state_dim, action_dim, **net_kwargs).to(config.device)
critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=optimizer_lr)

critic_2 = td3_bc.Critic(state_dim, action_dim, **net_kwargs).to(config.device)
critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=optimizer_lr)

In [13]:
# Keyword arguments for the TD3+BC trainer. The target-policy noise and its
# clip value from the config are scaled by max_action here.
kwargs = dict(
    max_action=max_action,
    actor=actor,
    actor_optimizer=actor_optimizer,
    critic_1=critic_1,
    critic_1_optimizer=critic_1_optimizer,
    critic_2=critic_2,
    critic_2_optimizer=critic_2_optimizer,
    discount=config.discount,
    tau=config.tau,
    policy_noise=config.policy_noise * max_action,
    noise_clip=config.noise_clip * max_action,
    policy_freq=config.policy_freq,
    alpha=config.alpha,
)

print("----------------------------------------------------")
print(f"Training TD3 + BC, Env: {config.env}, Seed: {seed}")
print("----------------------------------------------------")

trainer = td3_bc.TD3_BC(**kwargs)

----------------------------------------------------
Training TD3 + BC, Env: , Seed: 0
----------------------------------------------------


In [14]:
# Weights & Biases run settings. The run is named after the local results
# folder, so the dashboard entry and the checkpoint directory line up.
wandb_project, wandb_entity, wandb_group = 'offline_rl_diffusion', '', 'corl_training'

wandb.init(
    name=results_folder.name,
    group=wandb_group,
    entity=wandb_entity,
    project=wandb_project,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mspatank[0m ([33mspatank-upenn[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011123030088886784, max=1.0…

In [15]:
# Show where this run's checkpoints will be written.
checkpoint_dir = config.checkpoints_path
print(checkpoint_dir)

/Users/shubhankar/Developer/compositional-rl-synth-data/local_results/offline_learning/offline_learning_synthetic_5


In [16]:
# Main TD3+BC loop. Each iteration does one gradient step and logs the
# training metrics every config.log_every steps. Every config.eval_freq steps
# (including step 0) and on the final step it evaluates the actor and
# optionally saves a checkpoint.
evaluations = []
final_step = config.max_timesteps - 1

for t in range(int(config.max_timesteps)):
    # Sample from the (CPU) replay buffer, then move the batch to the training device.
    batch = [tensor.to(config.device) for tensor in replay_buffer.sample(config.batch_size)]
    log_dict = trainer.train(batch)

    if t % config.log_every == 0:
        wandb.log(log_dict, step=trainer.total_it)

    # Evaluate only on eval_freq boundaries and on the very last step.
    is_eval_step = t % config.eval_freq == 0 or t == final_step
    if not is_eval_step:
        continue

    print(f"Time steps: {t + 1}")
    eval_scores = td3_bc.eval_actor(
        env,
        actor,
        device=config.device,
        n_episodes=config.n_episodes,
        seed=config.seed,
    )
    eval_score = eval_scores.mean()
    evaluations.append(eval_score)

    print("------------------------------------------------")
    print(f"Evaluation over {config.n_episodes} episodes: {eval_score:.3f}")
    print("------------------------------------------------")

    # Checkpoint files are named by the 0-based loop index t.
    if config.checkpoints_path is not None and config.save_checkpoints:
        checkpoint_file = os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt")
        torch.save(trainer.state_dict(), checkpoint_file)

    log_dict = {"Score": eval_score}
    wandb.log(log_dict, step=trainer.total_it)

Time steps: 1


Evaluating actor: 100%|██████████| 10/10 [00:19<00:00,  1.91s/it]


------------------------------------------------
Evaluation over 10 episodes: 0.552
------------------------------------------------
Time steps: 5001


Evaluating actor: 100%|██████████| 10/10 [00:20<00:00,  2.04s/it]


------------------------------------------------
Evaluation over 10 episodes: 29.586
------------------------------------------------
Time steps: 10001


Evaluating actor: 100%|██████████| 10/10 [00:20<00:00,  2.02s/it]


------------------------------------------------
Evaluation over 10 episodes: 31.599
------------------------------------------------
Time steps: 15001


Evaluating actor: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]


------------------------------------------------
Evaluation over 10 episodes: 28.938
------------------------------------------------
Time steps: 20001


Evaluating actor: 100%|██████████| 10/10 [00:20<00:00,  2.04s/it]


------------------------------------------------
Evaluation over 10 episodes: 39.300
------------------------------------------------
Time steps: 25001


Evaluating actor: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]


------------------------------------------------
Evaluation over 10 episodes: 33.712
------------------------------------------------
Time steps: 30001


Evaluating actor: 100%|██████████| 10/10 [00:19<00:00,  1.92s/it]


------------------------------------------------
Evaluation over 10 episodes: 26.810
------------------------------------------------
Time steps: 35001


Evaluating actor: 100%|██████████| 10/10 [00:19<00:00,  1.93s/it]


------------------------------------------------
Evaluation over 10 episodes: 28.458
------------------------------------------------
Time steps: 40001


Evaluating actor: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]


------------------------------------------------
Evaluation over 10 episodes: 41.391
------------------------------------------------
Time steps: 45001


Evaluating actor: 100%|██████████| 10/10 [00:19<00:00,  1.98s/it]


------------------------------------------------
Evaluation over 10 episodes: 46.364
------------------------------------------------
Time steps: 50000


Evaluating actor: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]

------------------------------------------------
Evaluation over 10 episodes: 49.011
------------------------------------------------



