In [1]:
import os
import pathlib

import torch
import wandb

from diffusion.utils import *
from corl.algorithms import iql
from corl.shared.buffer import *

# Root of the compositional-rl-synth-data checkout. Set COMPOSITIONAL_RL_ROOT to
# run this notebook from another machine/checkout; the default reproduces the
# original local paths exactly.
project_root = os.environ.get(
    'COMPOSITIONAL_RL_ROOT', '/Users/shubhankar/Developer/compositional-rl-synth-data'
)
base_agent_data_path = os.path.join(project_root, 'data')
base_synthetic_data_path = os.path.join(
    project_root, 'cluster_results', 'diffusion', 'monolithic_seed0_train98_1'
)
base_results_folder = os.path.join(project_root, 'local_results', 'offline_learning')

In [2]:
# --- Run configuration -------------------------------------------------------
config = iql.TrainConfig()
data_type = 'synthetic'  # 'agent' (expert dataset) or 'synthetic'

use_synthetic = data_type == 'synthetic'
if use_synthetic:
    synthetic_run_id = ''
    mode = ''  # train/test
    config.seed = 0
    config.n_episodes = 10
    config.batch_size = 1024
config.max_timesteps = 50000

# --- Task selection ----------------------------------------------------------
robot, obj, obst, subtask = 'Panda', 'Box', 'None', 'Push'
task_kwargs = dict(robot=robot, obj=obj, obst=obst, task=subtask)

# In this cell the env is only consulted for `modality_dims`, which
# remove_indicator_vectors uses (presumably to drop the task-indicator part of
# each observation).
env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=True, ignore_done=False)

agent_dataset = load_single_composuite_dataset(
    base_path=base_agent_data_path, dataset_type='expert', **task_kwargs
)
agent_dataset, _ = remove_indicator_vectors(env.modality_dims, transitions_dataset(agent_dataset))

# Observation dims detected as integer-valued / constant in the real data; the
# synthetic samples are post-processed against the same dims below.
integer_dims, constant_dims = identify_special_dimensions(agent_dataset['observations'])
print('Integer dimensions:', integer_dims)
print('Constant dimensions:', constant_dims)

if use_synthetic:
    synthetic_path = os.path.join(base_synthetic_data_path, synthetic_run_id, mode)
    synthetic_dataset = load_single_synthetic_dataset(base_path=synthetic_path, **task_kwargs)
    synthetic_dataset = process_special_dimensions(synthetic_dataset, integer_dims, constant_dims)

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


Integer dimensions: [14, 15, 16, 17, 18, 19, 20, 31, 32, 33, 34]
Constant dimensions: [28, 29, 30]


In [3]:
# Shape of the indicator-stripped expert observations, presumably (num_transitions, obs_dim).
agent_dataset["observations"].shape

(999999, 77)

In [4]:
# `synthetic_dataset` only exists when data_type == 'synthetic'; guard so this
# cell also runs (and displays nothing) for agent-data runs.
if data_type == 'synthetic':
    # Synthetic samples must share the per-transition observation layout of the
    # agent data whose integer/constant dims they were processed against.
    assert synthetic_dataset['observations'].shape[1:] == agent_dataset['observations'].shape[1:], (
        "synthetic and agent observation dimensions differ"
    )
synthetic_dataset['observations'].shape if data_type == 'synthetic' else None

(1000000, 77)

In [5]:
# Allocate a fresh auto-numbered results folder for this run:
# <base_results_folder>/offline_learning_<data_type>_<idx>, with idx starting at 1.
base_results_path = pathlib.Path(base_results_folder)
base_results_path.mkdir(parents=True, exist_ok=True)

idx = 1
while True:
    results_folder = base_results_path / f"offline_learning_{data_type}_{idx}"
    try:
        # mkdir without exist_ok claims the index atomically. If the folder already
        # exists (an earlier run, or a concurrent one that raced us), move on to the
        # next index instead of silently sharing its checkpoints.
        results_folder.mkdir()
    except FileExistsError:
        idx += 1
        continue
    break

config.checkpoints_path = results_folder

In [6]:
# Pick the training dataset for this run. Fail loudly on an unknown data_type:
# otherwise `dataset` / `num_samples` would be undefined or, worse, silently
# left over from a previous execution of this cell.
if data_type == 'agent':
    dataset = agent_dataset
elif data_type == 'synthetic':
    dataset = synthetic_dataset
else:
    raise ValueError(f"Unknown data_type {data_type!r}; expected 'agent' or 'synthetic'")
num_samples = int(dataset['observations'].shape[0])
print("Samples:", num_samples)

Samples: 1000000


In [7]:
def make_task_env():
    """Create a new CompoSuite env for the selected robot/obj/obst/subtask.

    Uses use_task_id_obs=False, so observations carry no task-ID part
    (presumably matching the indicator-stripped datasets loaded earlier).
    """
    return composuite.make(
        robot, obj, obst, subtask, use_task_id_obs=False, has_renderer=True, ignore_done=False
    )


raw_env = make_task_env()
state_dim = raw_env.observation_space.shape[0]
action_dim = raw_env.action_space.shape[0]
state_mean, state_std = iql.compute_mean_std(dataset["observations"], eps=1e-3)

# Training env: states are normalized with the dataset statistics.
env = iql.wrap_env(raw_env, state_mean=state_mean, state_std=state_std)
# Evaluation env: wrap a *separate* raw env. Wrapping `env` again would normalize
# observations twice, and sharing one simulator would let evaluation resets
# clobber in-progress online-finetuning episodes.
eval_env = iql.wrap_env(make_task_env(), state_mean=state_mean, state_std=state_std)

In [8]:
# Storage device for the replay buffer; the training loops move each sampled
# batch to config.device.
device = "cpu"

reward_normalizer = RewardNormalizer(dataset, config.env) if config.normalize_reward else None
state_normalizer = StateNormalizer(state_mean, state_std)

replay_buffer = prepare_replay_buffer(
    state_dim=state_dim,
    action_dim=action_dim,
    dataset=dataset,
    num_samples=num_samples,
    device=device,
    reward_normalizer=reward_normalizer,
    state_normalizer=state_normalizer,
)

Limiting size of the data to 1000000 samples.
Dataset size: 1000000


In [9]:
action_high = env.action_space.high
action_low = env.action_space.low
# The actors take a single scalar max_action, and online exploration clips to
# [-max_action, max_action]. That is only correct when every action dimension
# shares the same symmetric range, so check it instead of silently using dim 0.
assert all(h == action_high[0] for h in action_high) and all(
    lo == -action_high[0] for lo in action_low
), f"expected uniform symmetric action bounds, got low={action_low}, high={action_high}"
max_action = float(action_high[0])

In [10]:
# Seed before any network is created so that, assuming iql.set_seed covers the
# torch RNG, initial weights are reproducible.
seed = config.seed
iql.set_seed(seed, env)

# Architecture / optimizer settings shared by every network in this run.
mlp_size = dict(hidden_dim=config.network_width, n_hidden=config.network_depth)
learning_rate = 3e-4

# Construction order (Q, then V, then actor) is kept as-is: weight init
# presumably draws from the seeded RNG, so reordering would change the weights.
q_network = iql.TwinQ(state_dim, action_dim, **mlp_size).to(config.device)
q_optimizer = torch.optim.Adam(q_network.parameters(), lr=learning_rate)

v_network = iql.ValueFunction(state_dim, **mlp_size).to(config.device)
v_optimizer = torch.optim.Adam(v_network.parameters(), lr=learning_rate)

policy_cls = iql.DeterministicPolicy if config.iql_deterministic else iql.GaussianPolicy
actor = policy_cls(state_dim, action_dim, max_action, **mlp_size).to(config.device)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=learning_rate)

In [11]:
# Constructor arguments for the IQL trainer (kept in `kwargs` so later cells can
# inspect what the trainer was built with).
kwargs = dict(
    max_action=max_action,
    actor=actor,
    actor_optimizer=actor_optimizer,
    q_network=q_network,
    q_optimizer=q_optimizer,
    v_network=v_network,
    v_optimizer=v_optimizer,
    discount=config.discount,
    tau=config.tau,
    device=config.device,
    # IQL-specific
    beta=config.beta,
    iql_tau=config.iql_tau,
    max_steps=config.max_timesteps,
)

separator = "----------------------------------------------------"
print(separator)
print(f"Training IQL, Env: {config.env}, Seed: {seed}")
print(separator)

trainer = iql.ImplicitQLearning(**kwargs)

----------------------------------------------------
Training IQL, Env: , Seed: 0
----------------------------------------------------


In [12]:
# Weights & Biases run setup; the run is named after this run's results folder.
wandb_project = 'offline_rl_diffusion'
wandb_entity = ''
wandb_group = 'corl_training'
run_name = results_folder.name

wandb.init(project=wandb_project, entity=wandb_entity, group=wandb_group, name=run_name)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
  from IPython.core.display import HTML, display  # type: ignore


In [13]:
# Folder that training-loop checkpoints are written to for this run.
print(f"{config.checkpoints_path}")

/Users/shubhankar/Developer/compositional-rl-synth-data/local_results/offline_learning/offline_learning_synthetic_23


In [14]:
config.offline_iterations = 25000
config.online_iterations = 25000

# The actor has already been built, and its class was chosen from
# config.iql_deterministic at that time. Setting the flag here cannot change the
# policy type. A flag that disagrees with the actor makes the online loop either
# call `.sample()` on a tensor or add noise to a distribution, so keep the flag
# consistent with the actor that exists, and warn if the request can't be honoured.
requested_iql_deterministic = True
config.iql_deterministic = isinstance(actor, iql.DeterministicPolicy)
if config.iql_deterministic != requested_iql_deterministic:
    print(
        f"Warning: iql_deterministic={requested_iql_deterministic} was requested after the actor "
        f"({type(actor).__name__}) was built; set it before the network-setup cell to take effect."
    )

In [15]:
config.max_timesteps = 25000
# `trainer` captured max_steps=kwargs["max_steps"] at construction. Changing
# config.max_timesteps afterwards does not reach it, so anything the trainer
# derives from max_steps (e.g. an LR schedule, if it has one) keeps the old horizon.
if kwargs["max_steps"] != config.max_timesteps:
    print(
        f"Warning: trainer was built with max_steps={kwargs['max_steps']} but "
        f"config.max_timesteps is now {config.max_timesteps}; re-run the trainer cell "
        f"so they agree."
    )

In [16]:
# Offline-only IQL training from the replay buffer, evaluating the actor on `env`
# every config.eval_freq steps and on the final step.
evaluations = []
for t in range(int(config.max_timesteps)):
    batch = [tensor.to(config.device) for tensor in replay_buffer.sample(config.batch_size)]
    log_dict = trainer.train(batch)
    if t % config.log_every == 0:
        wandb.log(log_dict, step=trainer.total_it)

    is_final_step = t == config.max_timesteps - 1
    if t % config.eval_freq != 0 and not is_final_step:
        continue

    print(f"Time steps: {t + 1}")
    eval_scores = iql.eval_actor(
        env, actor, device=config.device, n_episodes=config.n_episodes, seed=config.seed
    )
    eval_score = eval_scores.mean()
    evaluations.append(eval_score)
    print("------------------------------------------------")
    print(f"Evaluation over {config.n_episodes} episodes: {eval_score:.3f}")
    print("------------------------------------------------")

    if config.checkpoints_path is not None and config.save_checkpoints:
        checkpoint_file = os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt")
        torch.save(trainer.state_dict(), checkpoint_file)

    log_dict = {"Score": eval_score}
    wandb.log(log_dict, step=trainer.total_it)

Time steps: 1
------------------------------------------------
Evaluation over 10 episodes: 3.708
------------------------------------------------
Time steps: 5001
------------------------------------------------
Evaluation over 10 episodes: 132.575
------------------------------------------------
Time steps: 10001
------------------------------------------------
Evaluation over 10 episodes: 59.372
------------------------------------------------
Time steps: 15001
------------------------------------------------
Evaluation over 10 episodes: 60.118
------------------------------------------------
Time steps: 20001
------------------------------------------------
Evaluation over 10 episodes: 14.553
------------------------------------------------
Time steps: 25000
------------------------------------------------
Evaluation over 10 episodes: 16.123
------------------------------------------------


In [None]:
# Offline pretraining followed by online finetuning, continuing from the current
# `trainer` / `replay_buffer` state. Online transitions are collected from `env`;
# evaluation runs on `eval_env`.
evaluations = []
total_iterations = int(config.offline_iterations) + int(config.online_iterations)


def select_exploration_action(obs):
    """Return a flat numpy exploration action for a single observation.

    The branch is chosen from what the actor actually returns, not from
    ``config.iql_deterministic``, because that flag can be changed after the
    actor was constructed:

    * a tensor (deterministic actor) gets Gaussian noise with std
      ``config.expl_noise``, clipped to ``[-config.noise_clip, config.noise_clip]``;
    * anything else (e.g. the distribution a Gaussian actor returns) is sampled.

    The result is scaled by ``max_action`` and clipped to ``[-max_action, max_action]``.
    """
    obs_tensor = torch.tensor(obs.reshape(1, -1), device=config.device, dtype=torch.float32)
    # Acting needs no autograd graph (and must not mutate graph tensors in place).
    with torch.no_grad():
        policy_out = actor(obs_tensor)
        if isinstance(policy_out, torch.Tensor):
            noise = (torch.randn_like(policy_out) * config.expl_noise).clamp(
                -config.noise_clip, config.noise_clip
            )
            action = policy_out + noise
        else:
            action = policy_out.sample()
        action = torch.clamp(max_action * action, -max_action, max_action)
    return action.cpu().numpy().flatten()


state, done = env.reset(), False
episode_return = 0
episode_step = 0

print("Offline pretraining.")
for t in range(total_iterations):
    if t == config.offline_iterations:
        print("Online finetuning.")

    if t >= config.offline_iterations:
        episode_step += 1
        action = select_exploration_action(state)
        next_state, reward, done, _ = env.step(action)
        episode_return += reward
        replay_buffer.add_transition(state, action, reward, next_state, done)
        state = next_state
        if done:
            wandb.log(
                {
                    "online_train/episode_return": episode_return,
                    "online_train/episode_length": episode_step,
                },
                step=trainer.total_it,
            )
            state, done = env.reset(), False
            episode_return = 0
            episode_step = 0

    batch = replay_buffer.sample(config.batch_size)
    batch = [b.to(config.device) for b in batch]
    log_dict = trainer.train(batch)

    if t % config.log_every == 0:
        wandb.log(log_dict, step=trainer.total_it)

    # Evaluate periodically and on the last iteration of *this* loop. The loop
    # runs total_iterations steps; config.max_timesteps bounds only the
    # offline-only loop.
    if t % config.eval_freq == 0 or t == total_iterations - 1:
        print(f"Time steps: {t + 1}")
        eval_scores = iql.eval_actor(
            eval_env,
            actor,
            device=config.device,
            n_episodes=config.n_episodes,
            seed=config.seed,
        )
        eval_score = eval_scores.mean()
        evaluations.append(eval_score)
        print("------------------------------------------------")
        print(
            f"Evaluation over {config.n_episodes} episodes: "
            f"{eval_score:.3f}"
        )
        print("------------------------------------------------")
        if config.checkpoints_path is not None and config.save_checkpoints:
            torch.save(
                trainer.state_dict(),
                os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
            )
        log_dict = {"Score": eval_score}
        wandb.log(log_dict, step=trainer.total_it)