In [1]:
import composuite
from diffusion.utils import *
from online.online_exp import *
from pathlib import Path

def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Per-feature mean and eps-padded standard deviation over the first axis.

    Args:
        states: Samples stacked along axis 0.
        eps: Added to the standard deviation so that dividing by it is safe.

    Returns:
        ``(mean, std)``, both reduced over axis 0.
    """
    return states.mean(axis=0), states.std(axis=0) + eps

def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardized and, optionally, rewards scaled.

    Args:
        env: Environment to wrap.
        state_mean: Subtracted from every observation.
        state_std: Divides every centred observation. Any epsilon must already be
            included (as ``compute_mean_std`` does).
        reward_scale: Multiplier applied to rewards. The reward wrapper is only
            added when this differs from 1.0.

    Returns:
        The wrapped environment.
    """
    # Nested defs instead of lambdas (PEP 8: E731).
    def normalize_state(state):
        centred = state - state_mean
        return centred / state_std

    def scale_reward(reward):
        # Careful: the reward is multiplied (not divided) by the scale.
        return reward_scale * reward

    wrapped = gym.wrappers.TransformObservation(env, normalize_state)
    if reward_scale == 1.0:
        return wrapped
    return gym.wrappers.TransformReward(wrapped, scale_reward)

In [None]:
import os

# Repository root. Set COMPOSITIONAL_RL_ROOT to run the notebook outside the
# original author's machine; the default keeps the previous hardcoded location.
repo_root = os.environ.get(
    "COMPOSITIONAL_RL_ROOT", "/Users/shubhankar/Developer/compositional-rl-synth-data"
)
gin.parse_config_file(os.path.join(repo_root, "config", "sac.gin"))

In [3]:
import os

# Data and results locations, all relative to the repository root. Set the
# COMPOSITIONAL_RL_ROOT environment variable to relocate them; the default keeps
# the previous hardcoded location.
repo_root = os.environ.get(
    'COMPOSITIONAL_RL_ROOT', '/Users/shubhankar/Developer/compositional-rl-synth-data'
)
base_agent_data_path = os.path.join(repo_root, 'data')
base_synthetic_data_path = os.path.join(repo_root, 'cluster_results', 'diffusion')

base_results_folder = os.path.join(repo_root, 'local_results', 'offline_learning')

In [None]:
import os

robot = 'Jaco'
obj = 'Plate'
obst = 'GoalWall'
subtask = 'PickPlace'

env_name = f"{robot}_{obj}_{obst}_{subtask}"

# 'agent' loads the recorded expert CompoSuite transitions;
# 'synthetic' loads diffusion-generated samples from a previous run.
data_type = 'agent'

if data_type == 'agent':
    env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=True, ignore_done=False)
    dataset = load_single_composuite_dataset(base_path=base_agent_data_path,
                                             dataset_type='expert',
                                             robot=robot, obj=obj,
                                             obst=obst, task=subtask)
    # Convert to transitions and strip the indicator-vector dimensions
    # (per `remove_indicator_vectors`, using the env's modality layout).
    dataset, _ = remove_indicator_vectors(env.modality_dims, transitions_dataset(dataset))
elif data_type == 'synthetic':
    synthetic_run_id = 20
    mode = 'train'
    # os.path.join only accepts str/bytes/PathLike parts, so the integer run id
    # must be stringified (passing it raw raises TypeError).
    synthetic_data_path = os.path.join(base_synthetic_data_path, str(synthetic_run_id), mode)
    dataset = load_single_synthetic_dataset(base_path=synthetic_data_path,
                                            robot=robot, obj=obj,
                                            obst=obst, task=subtask)
else:
    raise ValueError(f"Unknown data_type {data_type!r}; expected 'agent' or 'synthetic'.")

In [None]:
# Inspect which fields the loaded dataset provides.
dataset_keys = dataset.keys()
dataset_keys

In [6]:
# Experiment name used for logging, e.g. "redq_sac_Jaco_Plate_GoalWall_PickPlace".
exp_name_full = "_".join(("redq_sac", robot, obj, obst, subtask))

In [7]:
base_results_path = Path(base_results_folder)


def _run_dir(run_idx):
    """Return the candidate results directory for the given 1-based run index."""
    return base_results_path / f"offline_learning_redq_sac_{data_type}_{run_idx}"


# Pick the first unused run index so earlier runs are never overwritten.
idx = 1
while _run_dir(idx).exists():
    idx += 1
results_folder = _run_dir(idx)
results_folder.mkdir(parents=True, exist_ok=True)

logger_kwargs = setup_logger_kwargs(exp_name_full, data_dir=results_folder)

In [8]:
# Run and evaluation schedule.
seed = 0
epochs = 100
steps_per_epoch = 1_000
max_ep_len = 1_000
n_evals_per_epoch = 1

# Agent hyperparameters.
hidden_sizes = (256, 256)
replay_size = 1_000_000
batch_size = 1_024
lr = 3e-4
gamma = 0.99
polyak = 0.995
alpha = 0.2
auto_alpha = True
target_entropy = 'auto'
start_steps = 5_000
delay_update_steps = 'auto'
utd_ratio = 20  # update-to-data ratio
num_Q = 10
num_min = 2
q_target_mode = 'min'
policy_update_delay = 20
diffusion_buffer_size = 1_000_000
diffusion_sample_ratio = 0.5

# Diffusion hyperparameters.
retrain_diffusion_every = 10_000
num_samples = 100_000
diffusion_start = 0
disable_diffusion = True
print_buffer_stats = True
skip_reward_norm = True
model_terminals = False

# Bias evaluation.
evaluate_bias = True
n_mc_eval = 1_000
n_mc_cutoff = 350
reseed_each_epoch = True

# Weights & Biases.
project_name = 'diffusion_online'

In [9]:
# Prefer the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [10]:
# One loop iteration per environment step. The extra "+ 1" step lands just after
# the final epoch boundary, so (for steps_per_epoch > 1) it never triggers another
# end-of-epoch evaluation.
total_steps = epochs * steps_per_epoch + 1

In [None]:
""" Set up logger. """
logger = EpochLogger(**logger_kwargs)

In [12]:
""" Set up environment/s. """
state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
env_fn = lambda: wrap_env(composuite.make(robot, obj, obst, subtask), state_mean=state_mean, state_std=state_std)
env, test_env, bias_eval_env = env_fn(), env_fn(), env_fn()

In [None]:
""" Set up seeding. """
# Seed torch and numpy.
torch.manual_seed(seed)
np.random.seed(seed)

# seed environment along with env action space so that everything is properly seeded for reproducibility
def seed_all(epoch):
    seed_shift = epoch * 9999
    mod_value = 999999
    env_seed = (seed + seed_shift) % mod_value
    test_env_seed = (seed + 10000 + seed_shift) % mod_value
    bias_eval_env_seed = (seed + 20000 + seed_shift) % mod_value
    torch.manual_seed(env_seed)
    np.random.seed(env_seed)
    env.seed(env_seed)
    env.action_space.np_random.seed(env_seed)
    test_env.seed(test_env_seed)
    test_env.action_space.np_random.seed(test_env_seed)
    bias_eval_env.seed(bias_eval_env_seed)
    bias_eval_env.action_space.np_random.seed(bias_eval_env_seed)

seed_all(epoch=0)

In [14]:
""" Prepare to initialize agent. """
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
max_ep_len = 1000
# Action limit for clamping; assumes all dimensions share the same bound!
# Need .item() to convert it from NumPy float to Python float.
act_limit = env.action_space.high[0].item()
start_time = time.time()
sys.stdout.flush()

In [None]:
""" Initialize agent + buffer and begin training. """
agent_config = {
    'env_name': env_name,
    'hidden_sizes': hidden_sizes,
    'replay_size': replay_size,
    'batch_size': batch_size,
    'lr': lr,
    'gamma': gamma,
    'polyak': polyak,
    'alpha': alpha,
    'auto_alpha': auto_alpha,
    'target_entropy': target_entropy,
    'start_steps': start_steps,
    'delay_update_steps': delay_update_steps,
    'utd_ratio': utd_ratio,
    'num_Q': num_Q,
    'num_min': num_min,
    'q_target_mode': q_target_mode,
    'policy_update_delay': policy_update_delay,
}

wandb.init(project=project_name, name=logger_kwargs['exp_name'])
wandb.config.update(agent_config)

agent = REDQRLPDAgent(diffusion_buffer_size, diffusion_sample_ratio, 
                      env_name, obs_dim, act_dim, act_limit, device,
                      hidden_sizes, replay_size, batch_size,
                      lr, gamma, polyak,
                      alpha, auto_alpha, target_entropy,
                      start_steps, delay_update_steps,
                      utd_ratio, num_Q, num_min, q_target_mode,
                      policy_update_delay)

In [None]:
# Current fill level of the interaction (replay) buffer.
replay_buffer_size = agent.replay_buffer.size
replay_buffer_size

In [None]:
# Fraction of each buffer's capacity currently in use.
interaction_fraction = agent.replay_buffer.size / agent.replay_buffer.max_size
print(f'Interaction Buffer Fraction: {interaction_fraction}')
diffusion_fraction = agent.diffusion_buffer.size / agent.diffusion_buffer.max_size
print(f'Diffusion Buffer Fraction: {diffusion_fraction}')

In [18]:
# Preload the diffusion buffer with the offline dataset's transitions.
observations = dataset['observations']
actions = dataset['actions']
rewards = dataset['rewards']
next_observations = dataset['next_observations']
terminals = dataset['terminals']

transitions = zip(observations, actions, rewards, next_observations, terminals)
for obs, act, rew, next_obs, done in transitions:
    agent.diffusion_buffer.store(obs, act, rew, next_obs, done)

In [19]:
# Set up diffusion model inputs. Each sample is presumably a flat
# [obs, action, reward, next_obs(, terminal)] vector; the reward index used
# for skip_dims below relies on that layout.
diff_dims = 2 * obs_dim + act_dim + 1 + (1 if model_terminals else 0)
inputs = torch.zeros(128, diff_dims, dtype=torch.float32)
# Optionally exclude the reward column from normalization.
skip_dims = [obs_dim + act_dim] if skip_reward_norm else []

In [None]:
def _print_buffer_stats(synth_observations, synth_actions, synth_rewards):
    """Print per-dimension mean/std of freshly sampled synthetic data next to the real replay data.

    Args:
        synth_observations: Sampled observations, indexed as [sample, dim].
        synth_actions: Sampled actions, indexed as [sample, dim].
        synth_rewards: Sampled rewards.
    """
    # Slice by `size` rather than `ptr` so the stats still cover every stored
    # transition once the buffer wraps (assumes size is capped at max_size).
    n_real = agent.replay_buffer.size
    real_observations = agent.replay_buffer.obs1_buf[:n_real]
    real_actions = agent.replay_buffer.acts_buf[:n_real]
    real_rewards = agent.replay_buffer.rews_buf[:n_real]
    print('Buffer stats:')
    for i in range(synth_observations.shape[1]):
        print(f'Diffusion Obs {i}: {np.mean(synth_observations[:, i]):.2f} {np.std(synth_observations[:, i]):.2f}')
        print(
            f'     Real Obs {i}: {np.mean(real_observations[:, i]):.2f} {np.std(real_observations[:, i]):.2f}')
    for i in range(synth_actions.shape[1]):
        print(f'Diffusion Action {i}: {np.mean(synth_actions[:, i]):.2f} {np.std(synth_actions[:, i]):.2f}')
        print(f'     Real Action {i}: {np.mean(real_actions[:, i]):.2f} {np.std(real_actions[:, i]):.2f}')
    print(f'Diffusion Reward: {np.mean(synth_rewards):.2f} {np.std(synth_rewards):.2f}')
    print(f'     Real Reward: {np.mean(real_rewards):.2f} {np.std(real_rewards):.2f}')
    print(f'Replay buffer size: {n_real}')
    print(f'Diffusion buffer size: {agent.diffusion_buffer.size}')


def _retrain_diffusion(step):
    """Fit a new diffusion model on the replay buffer and refill the diffusion buffer with its samples.

    Args:
        step: Number of environment steps taken so far (only used in the log line).
    """
    print(f'Retraining diffusion model at step {step}.')

    # Train a new diffusion model on the real interaction data. The notebook has
    # no `args` namespace, so this run's results folder is passed directly.
    diffusion_trainer = REDQTrainer(
        construct_diffusion_model(
            inputs=inputs,
            skip_dims=skip_dims,
            disable_terminal_norm=model_terminals,
        ),
        results_folder=results_folder,
        model_terminals=model_terminals,
    )
    diffusion_trainer.update_normalizer(agent.replay_buffer, device=device)
    diffusion_trainer.train_from_redq_buffer(agent.replay_buffer)
    agent.reset_diffusion_buffer()

    # Sample synthetic transitions. The loop variables are local to this function,
    # so they cannot overwrite the training loop's current observation `o`.
    generator = SimpleDiffusionGenerator(env=env, ema_model=diffusion_trainer.ema.ema_model)
    synth_obs, synth_act, synth_rew, synth_next_obs, synth_term = generator.sample(num_samples=num_samples)

    print(f'Adding {num_samples} samples to diffusion buffer.')
    for s_obs, s_act, s_rew, s_next_obs, s_term in zip(synth_obs, synth_act, synth_rew, synth_next_obs, synth_term):
        agent.diffusion_buffer.store(s_obs, s_act, s_rew, s_next_obs, s_term)

    if print_buffer_stats:
        _print_buffer_stats(synth_obs, synth_act, synth_rew)


def _log_epoch(epoch, t):
    """Write the end-of-epoch tabular row, mirror it to W&B, then dump it.

    Args:
        epoch: 0-based epoch index.
        t: 0-based index of the environment step that ended the epoch.
    """
    logger.log_tabular('Epoch', epoch)
    # Logged as the 0-based step index (one less than the interaction count), kept as-is.
    logger.log_tabular('TotalEnvInteracts', t)
    logger.log_tabular('Time', time.time() - start_time)
    logger.log_tabular('EpRet', with_min_and_max=True)
    logger.log_tabular('EpLen', average_only=True)
    logger.log_tabular('TestEpRet', with_min_and_max=True)
    logger.log_tabular('TestEpLen', average_only=True)
    logger.log_tabular('Q1Vals', with_min_and_max=True)
    logger.log_tabular('LossQ1', average_only=True)
    logger.log_tabular('LogPi', with_min_and_max=True)
    logger.log_tabular('LossPi', average_only=True)
    logger.log_tabular('Alpha', with_min_and_max=True)
    logger.log_tabular('LossAlpha', average_only=True)
    logger.log_tabular('PreTanh', with_min_and_max=True)

    if evaluate_bias:
        logger.log_tabular("MCDisRet", with_min_and_max=True)
        logger.log_tabular("MCDisRetEnt", with_min_and_max=True)
        logger.log_tabular("QPred", with_min_and_max=True)
        logger.log_tabular("QBias", with_min_and_max=True)
        logger.log_tabular("QBiasAbs", with_min_and_max=True)
        logger.log_tabular("NormQBias", with_min_and_max=True)
        logger.log_tabular("QBiasSqr", with_min_and_max=True)
        logger.log_tabular("NormQBiasSqr", with_min_and_max=True)

    # W&B gets the row before dump_tabular is called.
    wandb.log(logger.log_current_row, step=t)
    logger.dump_tabular()


o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

for t in range(total_steps):
    a = agent.get_exploration_action(o, env)  # get action from agent
    o2, r, d, _ = env.step(a)  # step the env, get next observation, reward and done signal

    # Very important: before the agent stores this transition, ignore the "done"
    # signal if it comes from hitting the time horizon (an artificial terminal
    # signal that isn't based on the agent's state).
    ep_len += 1
    d = False if ep_len == max_ep_len else d

    agent.store_data(o, a, r, o2, d)  # give new data to replay buffer
    agent.train(logger)  # let agent update
    o = o2  # set obs to next obs
    ep_ret += r

    if d or (ep_len == max_ep_len):
        logger.store(EpRet=ep_ret, EpLen=ep_len)  # store episode return and length to logger
        o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0  # reset environment

    if not disable_diffusion and (t + 1) % retrain_diffusion_every == 0 and (t + 1) >= diffusion_start:
        _retrain_diffusion(t + 1)

    # End of epoch wrap-up.
    if (t + 1) % steps_per_epoch == 0:
        epoch = t // steps_per_epoch

        # Test the performance of the deterministic version of the agent.
        test_agent(agent, test_env, max_ep_len, logger, n_evals_per_epoch)
        if evaluate_bias:
            # NOTE(review): this passes the configured `alpha`; with auto_alpha the
            # agent's current temperature may differ.
            log_bias_evaluation(bias_eval_env, agent, logger, max_ep_len, alpha, gamma, n_mc_eval, n_mc_cutoff)

        # Reseeding makes results the same whether bias evaluation is on or not.
        # Seed for the *next* epoch: seed_all(0) already ran before training, so
        # seed_all(epoch) here would give epoch 1 the same random streams as epoch 0.
        if reseed_each_epoch:
            seed_all(epoch + 1)

        _log_epoch(epoch, t)

        sys.stdout.flush()  # flush logged information to disk