In [1]:
import os

from accelerate import Accelerator
from diffusion.utils import *
from CORL.algorithms.offline.td3_bc import *
from CORL.shared.buffer import *
from CORL.shared.logger import *

# Project checkout root. Set the COMPOSITIONAL_RL_ROOT environment variable to
# point at another machine's checkout instead of editing absolute paths here.
project_root = os.environ.get(
    'COMPOSITIONAL_RL_ROOT',
    '/Users/shubhankar/Developer/compositional-rl-synth-data',
)
base_data_path = os.path.join(project_root, 'data')
# Trailing '' keeps the trailing separator of the original path string.
base_results_folder = os.path.join(project_root, 'results', 'CORL', '')

In [None]:
dataset_type = 'expert'

# CompoSuite task components to train on.
robot, obj, obst, task = 'IIWA', 'Box', 'None', 'Push'

# Environment used for evaluation rollouts.
env = composuite.make(
    robot,
    obj,
    obst,
    task,
    use_task_id_obs=True,
    ignore_done=False,
)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# Offline dataset for the same task, presumably flattened into
# (s, a, r, s', done) transitions by `transitions_dataset`.
dataset = transitions_dataset(
    load_single_composuite_dataset(
        base_path=base_data_path,
        dataset_type=dataset_type,
        robot=robot,
        obj=obj,
        obst=obst,
        task=task,
    )
)

In [3]:
# Observation statistics from the offline data; the same mean/std are handed to
# the env wrapper below and to the replay buffer's StateNormalizer later on.
# NOTE(review): `env` is rebound to the wrapped env, so re-running this cell
# stacks a second normalisation wrapper — re-run from the env-creation cell.
observations = dataset["observations"]
state_mean, state_std = compute_mean_std(observations, eps=1e-3)
env = wrap_env(
    env,
    state_mean=state_mean,
    state_std=state_std,
)

In [4]:
# Notebook-level Accelerate handle; its `.device` is used as config.device below.
accelerator = Accelerator()

In [None]:
config = TrainConfig()

config.device = accelerator.device

# Derive run identifiers from the task selection instead of hard-coding them,
# so changing robot/obj/obst/task cannot silently pair the environment with
# another task's synthetic (diffusion) samples.
task_name = '_'.join([robot, obj, obst, task])  # e.g. IIWA_Box_None_Push
config.env = '-'.join([robot, obj, obst, task]).lower()  # e.g. iiwa-box-none-push
# Diffusion samples live next to the CORL results: <results>/Diffusion/<task>/samples.npz
config.diffusion.path = os.path.join(
    os.path.dirname(os.path.normpath(base_results_folder)),
    'Diffusion',
    task_name,
    'samples.npz',
)

# Replay buffer built from the offline dataset; states are normalised with the
# same statistics as the evaluation env wrapper, and the diffusion config tells
# it where the synthetic samples are.
replay_buffer = prepare_replay_buffer(
    state_dim=state_dim,
    action_dim=action_dim,
    buffer_size=config.buffer_size,
    dataset=dataset,
    env_name=config.env,
    device=config.device,
    reward_normalizer=RewardNormalizer(dataset, config.env) if config.normalize_reward else None,
    state_normalizer=StateNormalizer(state_mean, state_std),
    diffusion_config=config.diffusion,
)

In [6]:
# Action bound used by TD3 to scale target-policy noise and to clip actions to
# [-max_action, max_action]; read from the first action dimension.
# NOTE(review): assumes a symmetric action box with the same bound on every dim.
action_high = env.action_space.high
max_action = float(action_high[0])

In [7]:
# Per-task output directory, e.g. <base_results_folder>/IIWA_Box_None_Push
run_dir_name = '_'.join((robot, obj, obst, task))
results_folder = os.path.join(base_results_folder, run_dir_name)

In [8]:
logger = Logger(results_folder, seed=config.seed)

# Seed before any network is constructed; set_seed (from CORL) is expected to
# seed torch as well, which makes the initial weights reproducible.
seed = config.seed
set_seed(seed, env)

# Architecture and optimiser settings shared by the actor and both critics.
network_kwargs = {'hidden_dim': config.network_width, 'n_hidden': config.network_depth}
LEARNING_RATE = 3e-4

actor = Actor(state_dim, action_dim, max_action, **network_kwargs).to(config.device)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=LEARNING_RATE)

# Twin critics for clipped double-Q learning. critic_1 is built before
# critic_2, so the sequence of random draws for their weights is fixed.
critic_1, critic_2 = (
    Critic(state_dim, action_dim, **network_kwargs).to(config.device)
    for _ in range(2)
)
critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=LEARNING_RATE)
critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=LEARNING_RATE)

In [9]:
class TD3_BC:  # noqa
    """TD3+BC trainer (TD3 plus a behaviour-cloning term), driven through HF Accelerate.

    Owns an actor and two critics, a target copy of each, and their optimizers.
    Every ``train`` call does one critic step toward a clipped double-Q target.
    Every ``policy_freq``-th call also does one actor step on
    ``-lambda * Q(s, pi(s)) + MSE(pi(s), a)`` and Polyak-averages all three
    target networks with rate ``tau``.

    All ``prepare``/``backward``/device handling goes through the single
    ``self.accelerator``. Mixing two Accelerator objects (one for ``prepare``,
    another for ``backward``) breaks fp16 training, because each instance keeps
    its own gradient scaler.
    """

    def __init__(
            self,
            max_action: float,
            actor: nn.Module,
            actor_optimizer: torch.optim.Optimizer,
            critic_1: nn.Module,
            critic_1_optimizer: torch.optim.Optimizer,
            critic_2: nn.Module,
            critic_2_optimizer: torch.optim.Optimizer,
            discount: float = 0.99,
            tau: float = 0.005,
            policy_noise: float = 0.2,
            noise_clip: float = 0.5,
            policy_freq: int = 2,
            alpha: float = 2.5,
            accelerator: "Optional[Accelerator]" = None,
    ):
        """
        Args:
            max_action: Symmetric action bound; target actions are clipped to
                ``[-max_action, max_action]``.
            actor, critic_1, critic_2: Online networks. Target copies are
                deep-copied from them before ``prepare``.
            *_optimizer: Optimizers over the matching online network's parameters.
            discount: Reward discount factor.
            tau: Polyak averaging rate for the target networks.
            policy_noise: Std of the Gaussian noise added to target actions.
            noise_clip: Absolute clip applied to that noise.
            policy_freq: The actor/targets are updated every ``policy_freq`` calls.
            alpha: Weight of the Q term relative to the BC term in the actor loss.
            accelerator: Accelerator to use. If omitted, a new one is created.
                Pass the notebook's accelerator to keep a single instance.
        """
        self.accelerator = accelerator if accelerator is not None else Accelerator()

        # Targets are copied before prepare, so they hold independent weights.
        self.actor = actor
        self.actor_target = copy.deepcopy(actor)
        self.actor_optimizer = actor_optimizer
        self.actor, self.actor_target = self.accelerator.prepare(self.actor, self.actor_target)
        self.actor_optimizer = self.accelerator.prepare(self.actor_optimizer)

        self.critic_1 = critic_1
        self.critic_1_target = copy.deepcopy(critic_1)
        self.critic_1_optimizer = critic_1_optimizer
        self.critic_1, self.critic_1_target = self.accelerator.prepare(self.critic_1, self.critic_1_target)
        self.critic_1_optimizer = self.accelerator.prepare(self.critic_1_optimizer)

        self.critic_2 = critic_2
        self.critic_2_target = copy.deepcopy(critic_2)
        self.critic_2_optimizer = critic_2_optimizer
        self.critic_2, self.critic_2_target = self.accelerator.prepare(self.critic_2, self.critic_2_target)
        self.critic_2_optimizer = self.accelerator.prepare(self.critic_2_optimizer)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.alpha = alpha

        # Number of train() calls so far; drives the delayed actor update.
        self.total_it = 0

    def train(self, batch: TensorBatch) -> Dict[str, float]:
        """Run one TD3+BC update.

        Args:
            batch: Sequence of five tensors ``(state, action, reward,
                next_state, done)``. They are moved to the accelerator's device
                (a no-op if they are already there).

        Returns:
            ``{"critic_loss": ...}``, plus ``"actor_loss"`` on iterations where
            the delayed actor update runs.
        """
        log_dict = {}
        self.total_it += 1

        # Accelerator.prepare only wraps models/optimizers/dataloaders/schedulers
        # and returns plain tensors untouched, so move them explicitly.
        state, action, reward, next_state, done = (
            tensor.to(self.accelerator.device) for tensor in batch
        )
        not_done = 1 - done

        with torch.no_grad():
            # Target policy smoothing: add clipped Gaussian noise to the target action.
            noise = (torch.randn_like(action) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip
            )

            next_action = (self.actor_target(next_state) + noise).clamp(
                -self.max_action, self.max_action
            )

            # Clipped double-Q target: min over the two target critics.
            target_q1 = self.critic_1_target(next_state, next_action)
            target_q2 = self.critic_2_target(next_state, next_action)
            target_q = torch.min(target_q1, target_q2)
            target_q = reward + not_done * self.discount * target_q

        # Get current Q estimates
        current_q1 = self.critic_1(state, action)
        current_q2 = self.critic_2(state, action)

        # Both critics regress onto the same target.
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
        log_dict["critic_loss"] = critic_loss.item()

        self.critic_1_optimizer.zero_grad()
        self.critic_2_optimizer.zero_grad()
        self.accelerator.backward(critic_loss)
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.step()

        # Delayed actor and target updates.
        if self.total_it % self.policy_freq == 0:
            pi = self.actor(state)
            q = self.critic_1(state, pi)
            # Normalise the Q term by its mean magnitude (treated as a constant)
            # so alpha sets the Q-vs-BC trade-off independently of reward scale.
            lmbda = self.alpha / q.abs().mean().detach()

            actor_loss = -lmbda * q.mean() + F.mse_loss(pi, action)
            log_dict["actor_loss"] = actor_loss.item()
            # Critic grads produced by this backward are cleared by the critic
            # zero_grad() on the next call.
            self.actor_optimizer.zero_grad()
            self.accelerator.backward(actor_loss)
            self.actor_optimizer.step()

            # Polyak-average the frozen target models.
            soft_update(self.critic_1_target, self.critic_1, self.tau)
            soft_update(self.critic_2_target, self.critic_2, self.tau)
            soft_update(self.actor_target, self.actor, self.tau)

        return log_dict

    def state_dict(self) -> Dict[str, Any]:
        """Checkpoint of the online networks, their optimizers, and ``total_it``.

        Weights come from the unwrapped modules, so keys carry no
        distributed-wrapper prefix. Target networks are not saved;
        ``load_state_dict`` resets them from the online weights.
        """
        unwrap = self.accelerator.unwrap_model
        return {
            "critic_1": unwrap(self.critic_1).state_dict(),
            "critic_1_optimizer": self.critic_1_optimizer.state_dict(),
            "critic_2": unwrap(self.critic_2).state_dict(),
            "critic_2_optimizer": self.critic_2_optimizer.state_dict(),
            "actor": unwrap(self.actor).state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "total_it": self.total_it,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore from a dict produced by ``state_dict``.

        Each target network is set to the same weights as its loaded online
        network. Weights are loaded into the existing (already prepared) target
        modules rather than deep-copying a possibly wrapped online module.
        """
        unwrap = self.accelerator.unwrap_model

        unwrap(self.critic_1).load_state_dict(state_dict["critic_1"])
        self.critic_1_optimizer.load_state_dict(state_dict["critic_1_optimizer"])
        unwrap(self.critic_1_target).load_state_dict(state_dict["critic_1"])

        unwrap(self.critic_2).load_state_dict(state_dict["critic_2"])
        self.critic_2_optimizer.load_state_dict(state_dict["critic_2_optimizer"])
        unwrap(self.critic_2_target).load_state_dict(state_dict["critic_2"])

        unwrap(self.actor).load_state_dict(state_dict["actor"])
        self.actor_optimizer.load_state_dict(state_dict["actor_optimizer"])
        unwrap(self.actor_target).load_state_dict(state_dict["actor"])

        self.total_it = state_dict["total_it"]

In [None]:
# TD3_BC constructor arguments. The config expresses the target-policy noise
# and its clip relative to the action bound, hence the scaling by max_action.
kwargs = dict(
    max_action=max_action,
    actor=actor,
    actor_optimizer=actor_optimizer,
    critic_1=critic_1,
    critic_1_optimizer=critic_1_optimizer,
    critic_2=critic_2,
    critic_2_optimizer=critic_2_optimizer,
    discount=config.discount,
    tau=config.tau,
    policy_noise=config.policy_noise * max_action,
    noise_clip=config.noise_clip * max_action,
    policy_freq=config.policy_freq,
    alpha=config.alpha,  # Q-term weight in the TD3+BC actor loss
)

banner = "----------------------------------------------------"
print(banner)
print(f"Training TD3 + BC, Env: {config.env}, Seed: {seed}")
print(banner)

# Build the trainer (wraps actor, critics, targets and optimizers).
trainer = TD3_BC(**kwargs)

In [None]:
wandb_project = 'offline_rl_diffusion'
wandb_entity = ''  # leave empty to use the logged-in user's default entity
wandb_group = 'corl_training'

wandb.init(
    project=wandb_project,
    # wandb documents None as "use the default entity"; an empty string is not
    # guaranteed to behave the same, so normalise it.
    entity=wandb_entity or None,
    group=wandb_group,
    # Run name = task folder name (e.g. IIWA_Box_None_Push); normpath/basename
    # tolerate trailing separators and platform path separators.
    name=os.path.basename(os.path.normpath(results_folder)),
)

In [None]:
evaluations = []
# range() needs an int; compare the final-step check against the same value, so
# a float config (e.g. 1e6) cannot skip the last evaluation.
max_timesteps = int(config.max_timesteps)
save_checkpoints = config.checkpoints_path is not None and config.save_checkpoints
if save_checkpoints:
    # torch.save does not create missing parent directories.
    os.makedirs(config.checkpoints_path, exist_ok=True)

for t in range(max_timesteps):
    # Device placement of the batch is handled by the buffer/trainer.
    batch = replay_buffer.sample(config.batch_size)
    log_dict = trainer.train(batch)

    if t % config.log_every == 0:
        wandb.log(log_dict, step=trainer.total_it)
        logger.log({'step': trainer.total_it, **log_dict}, mode='train')

    # Evaluate periodically and always on the final step.
    if t % config.eval_freq == 0 or t == max_timesteps - 1:
        print(f"Time steps: {t + 1}")
        eval_scores = eval_actor(
            env,
            actor,
            device=config.device,
            n_episodes=config.n_episodes,
            seed=config.seed,
        )
        eval_score = eval_scores.mean()
        evaluations.append(eval_score)
        print("------------------------------------------------")
        print(
            f"Evaluation over {config.n_episodes} episodes: "
            f"{eval_score:.3f}"
        )
        print("------------------------------------------------")
        if save_checkpoints:
            torch.save(
                trainer.state_dict(),
                os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
            )
        log_dict = {"Score": eval_score}
        wandb.log(log_dict, step=trainer.total_it)
        logger.log({'step': trainer.total_it, **log_dict}, mode='eval')