In [None]:
import os
import pickle

import torch

from diffusion.utils import *
from corl.algorithms.finetune.iql import *
from corl.shared.buffer import *
from corl.shared.logger import *

# Root of the project checkout. Set the COMPOSITIONAL_RL_ROOT environment variable to run
# the notebook on another machine; the default keeps the original absolute location so
# existing setups resolve to exactly the same paths as before.
project_root = os.environ.get(
    'COMPOSITIONAL_RL_ROOT',
    '/Users/shubhankar/Developer/compositional-rl-synth-data',
).rstrip('/')

# Location of the agent-collected datasets.
base_agent_data_path = f'{project_root}/data'
# Location of the diffusion-generated (synthetic) datasets.
base_synthetic_data_path = f'{project_root}/cluster_results/diffusion'

# Parent folder under which each training run creates its own results directory.
base_results_folder = f'{project_root}/local_results/offline_learning'

In [None]:
def combine_two_tensors(tensor1, tensor2):
    """Concatenate two tensors along their first dimension.

    Args:
        tensor1: Tensor whose entries come first in the result.
        tensor2: Tensor whose entries are appended after ``tensor1``.

    Returns:
        A single tensor holding ``tensor1`` followed by ``tensor2`` along dim 0.
    """
    stacked = torch.cat((tensor1, tensor2), dim=0)
    return stacked

class JointReplayBuffer:
    """Draws training batches from a diffusion buffer and an interaction buffer.

    Both buffers must expose ``sample(batch_size)`` returning an indexable batch whose
    first five entries are tensors, in the order (observations, actions, rewards,
    next_observations, terminals). The interaction buffer must also expose ``_size``.

    Args:
        diffusion_buffer: Buffer that is sampled for the diffusion share of each batch,
            and for the whole batch while the interaction buffer is too small.
        interaction_buffer: Buffer that is sampled for the remaining share of each batch.
        diffusion_sample_ratio: Fraction of each batch taken from ``diffusion_buffer``;
            must lie in [0, 1].
        device: Device every returned tensor is moved to.

    Raises:
        ValueError: If ``diffusion_sample_ratio`` is outside [0, 1].
    """

    # observations, actions, rewards, next_observations, terminals
    _NUM_FIELDS = 5

    def __init__(self, diffusion_buffer, interaction_buffer, diffusion_sample_ratio=0.5, device="cpu"):
        # A ratio outside [0, 1] would produce a negative sub-batch size in sample().
        if not 0.0 <= diffusion_sample_ratio <= 1.0:
            raise ValueError(
                f"diffusion_sample_ratio must be in [0, 1], got {diffusion_sample_ratio}"
            )
        self.diffusion_buffer = diffusion_buffer
        self.interaction_buffer = interaction_buffer
        self.diffusion_sample_ratio = diffusion_sample_ratio
        self.device = device

    def sample(self, batch_size):
        """Sample a batch of ``batch_size`` transitions.

        Returns:
            A tuple of tensors on ``self.device``. In the mixed case it holds
            (observations, actions, rewards, next_observations, terminals) with the
            interaction samples first and the diffusion samples after them. When the
            interaction buffer cannot supply its share (or its share is zero), the
            whole batch comes from the diffusion buffer.
        """
        diffusion_batch_size = int(batch_size * self.diffusion_sample_ratio)
        interaction_batch_size = batch_size - diffusion_batch_size

        # Diffusion-only path: nothing is requested from the interaction buffer, or it
        # does not hold enough transitions yet. The batch is moved to the target device
        # here as well so both paths return tensors on the same device.
        if interaction_batch_size == 0 or self.interaction_buffer._size < interaction_batch_size:
            batch = self.diffusion_buffer.sample(batch_size)
            return tuple(tensor.to(self.device) for tensor in batch)

        diffusion_batch = self.diffusion_buffer.sample(diffusion_batch_size)
        interaction_batch = self.interaction_buffer.sample(interaction_batch_size)

        # Field by field: interaction samples first, diffusion samples second.
        return tuple(
            torch.cat([interaction_batch[i], diffusion_batch[i]], dim=0).to(self.device)
            for i in range(self._NUM_FIELDS)
        )

In [None]:
# Training configuration for this run; the fields set below override TrainConfig's defaults.
config = TrainConfig()

# Selects which settings branch below is applied and is embedded in the results folder name.
data_type = 'synthetic'

config.batch_size = 1024
config.n_episodes = 10
config.offline_iterations = 5000
config.online_iterations = 0

config.iql_deterministic = True
config.expl_noise = 1e-3

# Synthetic-data settings. These intentionally override the offline/online iteration
# counts assigned above, so the order of these statements matters.
# NOTE(review): `mode` and `synthetic_run_id` are only referenced by the commented-out
# synthetic loader further down in this cell.
if data_type == 'synthetic':
    mode = 'train'
    synthetic_run_id = 'cond_diff_20'
    config.eval_freq = 1000
    config.offline_iterations = 10000
    config.online_iterations = 0

# Task specification: robot, object, obstacle and subtask passed to composuite.make later on.
robot = 'IIWA'
obj = 'Dumbbell'
obst = 'ObjectDoor'
subtask = 'Trashcan'

# Alternative task specification (disabled).
# robot = 'Kinova3'
# obj = 'Hollowbox'
# obst = 'None'
# subtask = 'Trashcan'

# Disabled loaders for the agent and synthetic datasets; the dataset is currently read
# from the pickle file at the end of this cell instead.
# if data_type == 'agent':
#     env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=True, ignore_done=False)
#     dataset = load_single_composuite_dataset(base_path=base_agent_data_path,
#                                              dataset_type='expert',
#                                              robot=robot, obj=obj,
#                                              obst=obst, task=subtask)
#     dataset, _ = remove_indicator_vectors(env.modality_dims, transitions_dataset(dataset))

# if data_type == 'synthetic':
#     dataset = load_single_synthetic_dataset(base_path=os.path.join(base_synthetic_data_path, synthetic_run_id, mode),
#                                             robot=robot, obj=obj,
#                                             obst=obst, task=subtask)

# Load the transitions from a pickle file resolved relative to the current working directory.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only load files
# from a trusted source.
with open("filtered_transitions.pkl", "rb") as f:
    dataset = pickle.load(f)

In [None]:
# Show which fields the loaded dataset provides (displayed as the cell output).
dataset.keys()

In [None]:
# Pick the first run directory "<prefix>_<idx>" (idx = 1, 2, ...) that does not exist yet,
# create it, and use it as the checkpoint location for this run.
base_results_path = pathlib.Path(base_results_folder)
run_prefix = f"offline_learning_iql_filt_{data_type}"

idx = 1
results_folder = base_results_path / f"{run_prefix}_{idx}"
while results_folder.exists():
    idx += 1
    results_folder = base_results_path / f"{run_prefix}_{idx}"

results_folder.mkdir(parents=True, exist_ok=True)
config.checkpoints_path = results_folder

In [None]:
# Options shared by the training and the evaluation environment.
env_options = dict(use_task_id_obs=False, has_renderer=True, ignore_done=False)

# Training environment; state/action dimensions are read before the env is wrapped.
env = composuite.make(robot, obj, obst, subtask, **env_options)
state_dim, action_dim = env.observation_space.shape[0], env.action_space.shape[0]

# Observation statistics are computed from the dataset and reused for both environments.
state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
env = wrap_env(env, state_mean=state_mean, state_std=state_std)

# Separate environment instance for evaluation, wrapped with the same statistics.
eval_env = wrap_env(
    composuite.make(robot, obj, obst, subtask, **env_options),
    state_mean=state_mean,
    state_std=state_std,
)

In [None]:
config.device = "cpu"

# Number of samples handed to the replay buffers created later.
if data_type in ('agent', 'synthetic'):
    # Both data sources currently use the same fixed budget. To size the buffers from the
    # dataset instead, use: num_samples = int(dataset['observations'].shape[0])
    num_samples = 100000
else:
    # Fail here with a clear message instead of a NameError at the print below.
    raise ValueError(f"Unknown data_type: {data_type!r}; expected 'agent' or 'synthetic'")

print("Samples:", num_samples)

In [None]:
# Scalar action bound, taken from the upper limit of the first action dimension.
# NOTE(review): later code clamps actions to [-max_action, max_action], which assumes all
# action dimensions share this bound and that the bounds are symmetric — TODO confirm.
max_action = float(env.action_space.high[0])

In [None]:
# Logger writing into this run's results folder.
logger = Logger(results_folder, seed=config.seed)

# Seed the training and the evaluation environment with the configured seed.
seed = config.seed
set_seed(seed, env)
set_seed(seed, eval_env)

# Twin Q-function and its optimizer.
q_network = TwinQ(state_dim, action_dim).to(config.device)
q_optimizer = torch.optim.Adam(q_network.parameters(), lr=config.qf_lr)

# State-value function and its optimizer.
v_network = ValueFunction(state_dim).to(config.device)
v_optimizer = torch.optim.Adam(v_network.parameters(), lr=config.vf_lr)

# Policy: deterministic or Gaussian depending on the config flag; both take the same arguments.
policy_cls = DeterministicPolicy if config.iql_deterministic else GaussianPolicy
actor = policy_cls(
    state_dim, action_dim, max_action, dropout=config.actor_dropout
).to(config.device)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_lr)

In [None]:
# Constructor arguments for the IQL trainer. `max_steps` is tied to the number of offline
# iterations; a previously used alternative set it to config.online_iterations instead.
kwargs = dict(
    max_action=max_action,
    actor=actor,
    actor_optimizer=actor_optimizer,
    q_network=q_network,
    q_optimizer=q_optimizer,
    v_network=v_network,
    v_optimizer=v_optimizer,
    discount=config.discount,
    tau=config.tau,
    device=config.device,
    # IQL-specific settings
    beta=config.beta,
    iql_tau=config.iql_tau,
    max_steps=config.offline_iterations,
)

banner = "------------------------------------------------"
print(banner)
print(f"Training IQL, Env: {config.env}, Seed: {seed}")
print(banner)

# Build the trainer from the networks and optimizers assembled above.
trainer = ImplicitQLearning(**kwargs)

In [None]:
# Weights & Biases run settings; the run is named after this run's results folder.
wandb_project, wandb_entity, wandb_group = 'offline_rl_diffusion', '', 'corl_training'

wandb.init(project=wandb_project, entity=wandb_entity, group=wandb_group, name=results_folder.name)

In [None]:
# Arguments shared by both buffers.
buffer_common = dict(state_dim=state_dim, action_dim=action_dim, device=config.device)

# Buffer built from the loaded dataset. Each buffer gets its own normalizer instances.
diffusion_buffer = prepare_replay_buffer(
    dataset=dataset,
    num_samples=num_samples,
    reward_normalizer=(
        RewardNormalizer(dataset, config.env) if config.normalize_reward else None
    ),
    state_normalizer=StateNormalizer(state_mean, state_std),
    **buffer_common,
)

# Buffer that receives transitions collected from the environment during training.
interaction_buffer = ReplayBuffer(
    buffer_size=num_samples,
    reward_normalizer=(
        RewardNormalizer(dataset, config.env) if config.normalize_reward else None
    ),
    state_normalizer=StateNormalizer(state_mean, state_std),
    **buffer_common,
)

# Joint sampler over both buffers (default diffusion share of each batch).
replay_buffer = JointReplayBuffer(diffusion_buffer, interaction_buffer, device=config.device)

In [None]:
# Mean evaluation score recorded at every evaluation step.
evaluations = []

# Rollout state for the online phase.
state, done = env.reset(), False
episode_return = 0
episode_step = 0

# NOTE(review): these two lists are not filled anywhere in this cell.
eval_successes = []
train_successes = []

print("Offline pretraining.")
# One loop covers both phases: the first `offline_iterations` steps only train on sampled
# batches; every later step also performs one environment interaction.
for t in range(int(config.offline_iterations) + int(config.online_iterations)):
    if t == config.offline_iterations:
        print("Online finetuning.")
    online_log = {}
    if t >= config.offline_iterations:
        episode_step += 1
        # Action selection is inference only, so no autograd graph is needed.
        with torch.no_grad():
            action = actor(
                torch.tensor(
                    state.reshape(1, -1), device=config.device, dtype=torch.float32
                )
            )
            if not config.iql_deterministic:
                # The stochastic policy output is sampled to obtain an action.
                action = action.sample()
            else:
                # Deterministic policy: add clipped Gaussian exploration noise.
                noise = (torch.randn_like(action) * config.expl_noise).clamp(
                    -config.noise_clip, config.noise_clip
                )
                action += noise
            action = torch.clamp(max_action * action, -max_action, max_action)
        action = action.cpu().data.numpy().flatten()
        next_state, reward, done, env_infos = env.step(action)
        episode_return += reward

        replay_buffer.interaction_buffer.add_transition(state, action, reward, next_state, done)
        state = next_state
        if done:
            state, done = env.reset(), False
            online_log["train/episode_return"] = episode_return
            online_log["train/episode_length"] = episode_step
            episode_return = 0
            episode_step = 0

    # One gradient step on a batch from the joint buffer.
    batch = replay_buffer.sample(config.batch_size)
    batch = [b.to(config.device) for b in batch]
    log_dict = trainer.train(batch)

    # Iteration counter relative to the start of the current phase.
    log_dict["offline_iter" if t < config.offline_iterations else "online_iter"] = (
        t if t < config.offline_iterations else t - config.offline_iterations
    )
    # Log on the regular cadence, and also whenever an episode just finished: otherwise
    # the episode statistics would be lost unless the episode ended on a logging step.
    if t % config.log_every == 0 or online_log:
        log_dict.update(online_log)
        wandb.log(log_dict, step=trainer.total_it)
    # Evaluate episode
    if (t + 1) % config.eval_freq == 0:
        print('Diffusion buffer size:', diffusion_buffer._size, 'Interaction buffer size:', interaction_buffer._size)
        print(f"Time steps: {t + 1}")
        # NOTE(review): `success_rate` is returned by eval_actor but not used or logged here.
        eval_scores, success_rate = eval_actor(
            eval_env,
            actor,
            device=config.device,
            n_episodes=config.n_episodes,
            seed=config.seed,
        )
        eval_score = eval_scores.mean()
        eval_log = {'Score': eval_score}
        evaluations.append(eval_score)
        print("---------------------------------------")
        print(
            f"Evaluation over {config.n_episodes} episodes: "
            f"{eval_score:.3f}"
        )
        print("---------------------------------------")
        # Checkpoint files are named with the zero-based step index `t`.
        if config.checkpoints_path is not None:
            torch.save(
                trainer.state_dict(),
                os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
            )
        wandb.log(eval_log, step=trainer.total_it)