In [None]:
import os
import pathlib

import numpy as np
import wandb
from tqdm import trange

from diffusion.utils import *
from corl.algorithms import sac_n
from corl.shared.buffer import *

# All data/results locations hang off one project root. Set the COMPOSITIONAL_RL_ROOT
# environment variable to run on another machine; the default reproduces the original paths.
project_root = os.environ.get('COMPOSITIONAL_RL_ROOT', '/Users/shubhankar/Developer/compositional-rl-synth-data')
base_agent_data_path = os.path.join(project_root, 'data')
base_synthetic_data_path = os.path.join(project_root, 'cluster_results', 'diffusion', 'monolithic_seed0_train98_1')
base_results_folder = os.path.join(project_root, 'local_results', 'offline_learning')

In [None]:
# Which transitions to train on: 'agent' (the expert dataset loaded below) or
# 'synthetic' (presumably diffusion-generated samples under base_synthetic_data_path).
data_type = 'agent'
if data_type not in ('agent', 'synthetic'):
    raise ValueError(f"data_type must be 'agent' or 'synthetic', got {data_type!r}")

config = sac_n.TrainConfig()
config.seed = 0
config.batch_size = 1024
# The training loop in this notebook reads config.eval_episodes (never n_episodes) and runs
# num_epochs * num_updates_on_epoch gradient updates (it never reads max_timesteps).
# Translate the intended settings into the fields that are actually consumed.
config.n_episodes = 10
config.eval_episodes = config.n_episodes
config.max_timesteps = 50000
config.num_epochs = -(-config.max_timesteps // config.num_updates_on_epoch)  # ceil division

# Synthetic-data run selection; only used when data_type == 'synthetic'.
synthetic_run_id = ''
mode = ''  # train/test

# CompoSuite task to train on.
robot = 'Panda'
obj = 'Box'
obst = 'None'
subtask = 'Push'

# This env (task-id observations on) is used here only for env.modality_dims, which lets
# remove_indicator_vectors strip the task indicator entries from the agent transitions.
env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=True, ignore_done=False)
agent_dataset = load_single_composuite_dataset(base_path=base_agent_data_path, 
                                               dataset_type='expert', 
                                               robot=robot, obj=obj, 
                                               obst=obst, task=subtask)
agent_dataset, _ = remove_indicator_vectors(env.modality_dims, transitions_dataset(agent_dataset))

# Special observation dimensions found in the real data (presumably integer-valued and
# constant columns); the synthetic data below is post-processed using them.
integer_dims, constant_dims = identify_special_dimensions(agent_dataset['observations'])
print('Integer dimensions:', integer_dims)
print('Constant dimensions:', constant_dims)

if data_type == 'synthetic':
    synthetic_dataset = load_single_synthetic_dataset(base_path=os.path.join(base_synthetic_data_path, synthetic_run_id, mode), 
                                                      robot=robot, obj=obj, 
                                                      obst=obst, task=subtask)
    synthetic_dataset = process_special_dimensions(synthetic_dataset, integer_dims, constant_dims)

In [None]:
# Sanity check: size of the agent observation array (rows = transitions, presumably).
agent_dataset["observations"].shape

In [None]:
# synthetic_dataset['observations'].shape

In [None]:
# Allocate a fresh run directory named offline_learning_<data_type>_<n>, using the
# smallest n >= 1 that does not exist yet, and send checkpoints there.
base_results_path = pathlib.Path(base_results_folder)
idx = 1
results_folder = base_results_path / f"offline_learning_{data_type}_{idx}"
while results_folder.exists():
    idx += 1
    results_folder = base_results_path / f"offline_learning_{data_type}_{idx}"
results_folder.mkdir(parents=True, exist_ok=True)

config.checkpoints_path = results_folder

In [None]:
# Choose the transitions that will fill the replay buffer. Fail loudly on an unknown
# data_type instead of leaving `dataset`/`num_samples` undefined (NameError below).
if data_type == 'agent':
    dataset = agent_dataset
elif data_type == 'synthetic':
    dataset = synthetic_dataset
else:
    raise ValueError(f"Unknown data_type {data_type!r}; expected 'agent' or 'synthetic'")
num_samples = int(dataset['observations'].shape[0])
print("Samples:", num_samples)

In [None]:
# Evaluation env: task-id observations off, presumably matching the dataset after the
# indicator vectors were removed. has_renderer=True opens an on-screen viewer in
# robosuite-based envs — set it to False on headless machines (TODO confirm for composuite).
env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=False, has_renderer=True, ignore_done=False)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
# Catch a dataset/env observation-width mismatch here rather than as an opaque error
# inside normalization or the replay buffer.
if dataset["observations"].shape[-1] != state_dim:
    raise ValueError(
        f"Dataset observation dim {dataset['observations'].shape[-1]} does not match "
        f"env observation dim {state_dim}"
    )
# Normalization statistics come from the training data; the wrapped env applies them to
# observations during evaluation.
state_mean, state_std = sac_n.compute_mean_std(dataset["observations"], eps=1e-3)
env = sac_n.wrap_env(env, state_mean=state_mean, state_std=state_std)

In [None]:
# Buffer storage is kept on the CPU regardless of config.device; batches are handed to
# trainer.update as sampled (NOTE(review): presumably SACN.update moves them to its device).
# prepare_replay_buffer (reward/state-normalizer variant) is an alternative loader that is
# deliberately not used here.
device = "cpu"

# buffer_size equals the number of dataset transitions, so the buffer is filled exactly.
replay_buffer = sac_n.ReplayBuffer(
    state_dim=state_dim, action_dim=action_dim, buffer_size=num_samples, device=device
)
replay_buffer.load_d4rl_dataset(dataset)

In [None]:
# sac_n.Actor receives one scalar bound (config.max_action) and presumably scales its output
# to [-max_action, max_action], so every action dimension must share that symmetric box.
max_action = float(env.action_space.high[0])
if not (
    np.allclose(env.action_space.high, max_action)
    and np.allclose(env.action_space.low, -max_action)
):
    raise ValueError(
        "Action space is not a uniform symmetric box: "
        f"low={env.action_space.low}, high={env.action_space.high}"
    )
config.max_action = max_action

In [None]:
# Set seeds
seed = config.seed
sac_n.set_seed(seed, env)

actor = sac_n.Actor(state_dim, action_dim, config.hidden_dim, config.max_action)
actor.to(config.device)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_learning_rate)
critic = sac_n.VectorizedCritic(state_dim, action_dim, config.hidden_dim, config.num_critics)
critic.to(config.device)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=config.critic_learning_rate)

In [None]:
# Banner describing this run. config.env is never set in this notebook, so the actual
# CompoSuite task and data source are reported instead.
task_name = f"{robot}_{obj}_{obst}_{subtask}"
print("----------------------------------------------------")
print(f"Training SAC-N, Task: {task_name}, Data: {data_type}, Seed: {seed}")
print("----------------------------------------------------")

# SAC-N trainer wrapping the actor/critic and their optimizers; the entropy temperature
# gets its own learning rate.
trainer = sac_n.SACN(
    actor=actor,
    actor_optimizer=actor_optimizer,
    critic=critic,
    critic_optimizer=critic_optimizer,
    gamma=config.gamma,
    tau=config.tau,
    alpha_learning_rate=config.alpha_learning_rate,
    device=config.device,
)

In [None]:
wandb_project = 'offline_rl_diffusion'
wandb_entity = ''  # empty -> passed as None so wandb uses the default entity
wandb_group = 'corl_training'

# Record every hyperparameter plus the task/data selection with the run, so it can be
# reproduced from wandb alone. Path values are stringified to keep the config JSON-friendly.
run_config = {
    key: str(value) if isinstance(value, pathlib.PurePath) else value
    for key, value in vars(config).items()
}
run_config.update(
    robot=robot,
    obj=obj,
    obst=obst,
    subtask=subtask,
    data_type=data_type,
    synthetic_run_id=synthetic_run_id,
)

wandb.init(
    project=wandb_project,
    entity=wandb_entity or None,
    group=wandb_group,
    name=results_folder.name,
    config=run_config,
)

In [None]:
# Where checkpoints for this run will be written.
print(f"{config.checkpoints_path}")

In [None]:
# Offline training: config.num_epochs epochs of config.num_updates_on_epoch gradient steps,
# with wandb logging every config.log_every updates, plus evaluation and checkpointing
# every config.eval_every epochs and after the final epoch.
total_updates = 0
evaluations = []  # mean evaluation return, one entry per evaluation
for epoch in trange(config.num_epochs, desc="Training"):
    # ---- gradient updates on batches sampled from the offline replay buffer ----
    for _ in trange(config.num_updates_on_epoch, desc="Epoch", leave=False):
        batch = replay_buffer.sample(config.batch_size)
        update_info = trainer.update(batch)
        if not total_updates % config.log_every:
            wandb.log({"epoch": epoch, **update_info})
        total_updates += 1

    # ---- skip evaluation unless this is an eval epoch or the very last one ----
    is_final_epoch = epoch == config.num_epochs - 1
    if epoch % config.eval_every != 0 and not is_final_epoch:
        continue

    print(f"Running evaluation at epoch {epoch + 1}")
    eval_returns = sac_n.eval_actor(
        env=env,
        actor=actor,
        n_episodes=config.eval_episodes,
        seed=config.eval_seed,
        device=config.device,
    )
    eval_score = np.mean(eval_returns)
    evaluations.append(eval_score)
    print("------------------------------------------------")
    print(f"Evaluation over {config.eval_episodes} episodes: {eval_score:.3f}")
    print("------------------------------------------------")
    eval_log = {
        "eval/reward_mean": eval_score,
        "eval/reward_std": np.std(eval_returns),
        "epoch": epoch,
    }
    wandb.log(eval_log)

    # One checkpoint per evaluation, named by the 0-based epoch index.
    if config.checkpoints_path is not None:
        checkpoint_file = os.path.join(config.checkpoints_path, f"{epoch}.pt")
        torch.save(trainer.state_dict(), checkpoint_file)