# Homework 2 Part 1 - Behavior Cloning and Offline RL

***

Written by Albert Wilcox

In this homework, you'll implement behavior cloning (BC) and Implicit Q-Learning (IQL) on the `halfcheetah-medium-replay-v2` task from the [D4RL benchmark](https://github.com/Farama-Foundation/D4RL).

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

import gym
import numpy as np
from loguru import logger
import matplotlib.pyplot as plt
from IPython.display import Image
from tqdm import tqdm
import einops
import os
import copy

from src.utils import (
    get_device,
    set_seed,
    eval_policy,
    demo_policy,
    plot_returns,
    save_frames_as_gif,
    update_exponential_moving_average,
    return_range
)
from src.d4rl_dataset import D4RLSampler

plt.ion()

***

### Part 0 - Setting up D4RL and Dataset

The first step for training on the D4RL benchmark is to set up the environment. Unfortunately, MuJoCo can be difficult to install. Run the following block to test your installation. If you run into issues, Google is your friend :)

In [None]:
import d4rl

Next, we initialize the environment and set the random seed for several libraries so that our experiments are reproducible.
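The provided `set_seed` helper in `src/utils.py` takes care of this. Its exact implementation isn't shown here; the sketch below assumes it seeds Python's `random`, NumPy, and PyTorch, and shows what such a helper typically looks like. The name `set_seed_sketch` is hypothetical, and seeding the Gym environment itself is a separate step that this sketch does not cover.

```python
import random

import numpy as np
import torch


def set_seed_sketch(seed: int) -> None:
    """Hypothetical stand-in for src.utils.set_seed; the real helper may do more."""
    random.seed(seed)        # Python's built-in RNG
    np.random.seed(seed)     # NumPy's global RNG
    torch.manual_seed(seed)  # PyTorch RNGs on all devices


set_seed_sketch(42)
first = np.random.rand(3)
set_seed_sketch(42)
second = np.random.rand(3)
print(np.allclose(first, second))  # True: re-seeding reproduces the same draws
```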

For this homework, we'll be using the `halfcheetah-medium-replay-v2` environment. This environment involves training a two-legged cheetah to run. Its dataset consists of rollouts from a suboptimal SAC agent together with exploration data from the SAC replay buffer, so it comes from a wide distribution of policies and contains a good amount of suboptimal data.

In [None]:
SEED: int = 42
ENVIRONMENT_NAME: str = 'halfcheetah-medium-replay-v2'

# torch related defaults
DEVICE = get_device()
torch.set_default_dtype(torch.float32)

# Use random seeds for reproducibility
set_seed(SEED)

# instantiate the environment
env = gym.make(ENVIRONMENT_NAME)

# get the state and action dimensions
action_dimension = env.action_space.shape[0]
state_dimension = env.observation_space.shape[0]

logger.info(f'Action Dimension: {action_dimension}')
logger.info(f'Action High: {env.action_space.high}')
logger.info(f'Action Low: {env.action_space.low}')
logger.info(f'State Dimension: {state_dimension}')


Next, we need a dataset. Luckily for us, D4RL provides datasets that are convenient to download and train on. Running the following cell downloads and caches the dataset, loads it as a dictionary of NumPy arrays, and prints some useful information about it.

In [None]:
dataset = d4rl.qlearning_dataset(env)

logger.info(f'Dataset type: {type(dataset)}')
logger.info(f'Dataset keys: {dataset.keys()}')
logger.info(f'# Samples: {len(dataset["observations"])}')

In this cell, we wrap the D4RL dataset in a sampler. You can uncomment the lines at the bottom to make sure everything runs smoothly.
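`D4RLSampler` is provided in `src/d4rl_dataset.py`. Judging only from how batches are consumed in the training loops below (keys `state`, `action`, `reward`, `next_state`, `not_done`), a minimal sketch of such a sampler might look like the following. The class name `MinibatchSamplerSketch` is hypothetical, and the real sampler may differ, for example in how it shuffles or handles the final partial batch. The input keys are the ones returned by `d4rl.qlearning_dataset`.

```python
import numpy as np
import torch


class MinibatchSamplerSketch:
    """Hypothetical stand-in for D4RLSampler: yields shuffled minibatches once per epoch."""

    def __init__(self, dataset, batch_size, device="cpu"):
        def to_tensor(x):
            return torch.as_tensor(x, dtype=torch.float32, device=device)

        # Rename D4RL's keys to the ones the training loops read.
        self.data = {
            "state": to_tensor(dataset["observations"]),
            "action": to_tensor(dataset["actions"]),
            "reward": to_tensor(dataset["rewards"]),
            "next_state": to_tensor(dataset["next_observations"]),
            "not_done": 1.0 - to_tensor(dataset["terminals"]),
        }
        self.size = len(dataset["observations"])
        self.batch_size = batch_size

    def __len__(self):
        # Number of minibatches per epoch (the last one may be smaller).
        return (self.size + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        perm = torch.randperm(self.size)
        for start in range(0, self.size, self.batch_size):
            idx = perm[start:start + self.batch_size]
            yield {key: value[idx] for key, value in self.data.items()}


# Usage on a tiny fake dataset with HalfCheetah's dimensions (17-D states, 6-D actions).
fake = {
    "observations": np.zeros((10, 17), dtype=np.float32),
    "actions": np.zeros((10, 6), dtype=np.float32),
    "rewards": np.ones(10, dtype=np.float32),
    "next_observations": np.zeros((10, 17), dtype=np.float32),
    "terminals": np.zeros(10, dtype=bool),
}
fake_sampler = MinibatchSamplerSketch(fake, batch_size=4)
batches = list(fake_sampler)
```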

In [None]:

sampler = D4RLSampler(dataset, 256, DEVICE)

# Uncomment the following lines to iterate through the dataset and make sure everything runs smoothly
# for _ in tqdm(sampler):
#     pass

### Part 1 - Behavior Cloning

In this part of the homework, you'll implement behavior cloning (BC).

Train a BC agent by minimizing the negative log likelihood (NLL) of the dataset actions under the policy's predicted action distribution.

You should achieve a maximum normalized reward greater than 0.35 with the provided hyperparameters.
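Concretely, the BC objective is $\mathcal{L}_{\text{BC}}(\phi) = -\mathbb{E}_{(s,a)\sim\mathcal{D}}\left[\log \pi_\phi(a \mid s)\right]$. The sketch below shows a gradient step on this loss for a diagonal-Gaussian policy. It assumes the policy maps a batch of states to a `torch.distributions.Distribution` whose `log_prob` returns one value per sample; the provided `GaussianPolicy` may expose a different interface, and `TinyGaussianPolicy` and `bc_step` are hypothetical names.

```python
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal


class TinyGaussianPolicy(nn.Module):
    """Hypothetical diagonal-Gaussian policy, used only to illustrate the loss."""

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )
        # State-independent log standard deviation, one per action dimension.
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, state):
        mean = self.net(state)
        # Independent(..., 1) sums the per-dimension log-probs into one value per sample.
        return Independent(Normal(mean, self.log_std.exp()), 1)


def bc_step(policy, optimizer, state, action):
    """One gradient step on the mean negative log-likelihood of dataset actions."""
    loss = -policy(state).log_prob(action).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Usage on a fixed random batch (HalfCheetah has 17-D states and 6-D actions).
torch.manual_seed(0)
policy = TinyGaussianPolicy(state_dim=17, action_dim=6)
opt = torch.optim.Adam(policy.parameters(), lr=1e-2)
states = torch.randn(256, 17)
actions = torch.randn(256, 6).clamp(-1.0, 1.0)
losses = [bc_step(policy, opt, states, actions) for _ in range(200)]
```

If `GaussianPolicy` returns raw outputs (for example a mean and a log-std) rather than a distribution object, the log-likelihood has to be computed from those outputs instead.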

In [None]:
from src.networks import GaussianPolicy

################################## Hyper-parameters #########################################

EPOCHS: int = 50
EVAL_FREQ = 5
LOAD_CKPT = False

hidden_dim: int = 256
n_hidden: int = 3
lr: float = 3e-4
WEIGHT_DECAY: float = 3e-4

#############################################################################################

bc_policy = GaussianPolicy(state_dimension, action_dimension, hidden_dim, n_hidden).to(DEVICE)
optimizer = Adam(bc_policy.parameters(), lr)

if LOAD_CKPT and os.path.exists('bc_policy.pth'):
    ckpt = torch.load('bc_policy.pth')
    bc_policy.load_state_dict(ckpt['state_dict'])
    means = ckpt['means']
    stds = ckpt['stds']
else:
    means, stds = [], []
    for epoch in range(EPOCHS):
        total_loss = 0
        for batch in tqdm(sampler):
            state = batch['state'].to(DEVICE)
            action = batch['action'].to(DEVICE)
            
            # TODO: compute negative log likelihood loss on this batch
            loss = torch.tensor(0)

            total_loss += loss.item()

        if (epoch + 1) % EVAL_FREQ == 0:
            rew_mean, rew_std = eval_policy(bc_policy, environment_name=ENVIRONMENT_NAME, eval_episodes=50)
            logger.info(f'Epoch: {epoch + 1}. Loss: {total_loss / len(sampler):.4f}. Reward: {rew_mean:.4f} +/- {rew_std:.4f}')
            means.append(rew_mean)
            stds.append(rew_std)
    # Save the policy and learning curve in case there is an issue so you can plot without retraining
    exp_state = {
        'state_dict': bc_policy.state_dict(),
        'means': means,
        'stds': stds
    }
    torch.save(exp_state, 'bc_policy.pth')
epochs = np.arange(EVAL_FREQ, EPOCHS + EVAL_FREQ, step=EVAL_FREQ)
plot_returns(means, stds, 'Behavior Cloning', epochs=epochs, goal=0.35)

Now that we've finished training, use the following block to visualize the policy you trained with BC.

In [None]:
bc_policy.load_state_dict(torch.load('bc_policy.pth')['state_dict'])
frames, total_reward = demo_policy(bc_policy, environment_name=ENVIRONMENT_NAME, steps=200)
gif_path = save_frames_as_gif(frames, method_name='bc')
Image(open(gif_path,'rb').read())

<!-- ### Part 2 - DAgger

BC is great at replicating supervisor actions when the agent is in the data distribution, but this assumption is not always true. Sometimes the agent may enter an out of distribution state and output bad actions. A popular method to handle this issue is [Dataset Aggregation (DAgger)](https://arxiv.org/abs/1011.0686). The key idea behind DAgger is to roll out the learned policy while querying an expert policy on the states the agent encounters, adding the state-expert action pairs to the dataset. 

Luckily for you, we're providing an expert pretrained using the Soft Actor Critic Algorithm, which we'll load and test in the following block. -->

### Part 2 - Implicit Q Learning

In this part you'll implement Implicit Q-Learning (Kostrikov et al., 2021), a popular offline RL algorithm. 

The key idea behind IQL is to use expectile regression to fit the value functions so that they estimate the values of the higher-performing actions in the dataset, rather than the values of the current policy. Because the value targets only use actions that appear in the dataset, you can learn a value function without ever querying the policy, which avoids evaluating out-of-distribution (OOD) actions. We suggest having a look at Kostrikov et al. for a more thorough description of the algorithm.
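To build intuition for why expectile regression targets the better actions, the sketch below computes the $\tau$-expectile of a set of sampled returns directly. The $\tau$-expectile $m_\tau$ minimizes $\mathbb{E}\left[|\tau - \mathbb{1}(x - m < 0)|\,(x - m)^2\right]$; it equals the mean at $\tau = 0.5$ and moves toward the maximum as $\tau \to 1$. The helper `sample_expectile` is hypothetical and simply bisects on the first-order optimality condition.

```python
import numpy as np


def sample_expectile(x, tau, iters=100):
    """Return the tau-expectile of samples x via bisection (illustrative helper)."""
    x = np.asarray(x, dtype=np.float64)
    lo, hi = x.min(), x.max()
    for _ in range(iters):
        m = 0.5 * (lo + hi)
        # Optimality condition: tau * (mass above m) == (1 - tau) * (mass below m).
        g = tau * np.maximum(x - m, 0.0).sum() - (1 - tau) * np.maximum(m - x, 0.0).sum()
        if g > 0:
            lo = m  # the expectile lies above m
        else:
            hi = m
    return 0.5 * (lo + hi)


returns = np.array([1.0, 2.0, 3.0, 10.0])
print(sample_expectile(returns, 0.5))  # approximately 4.0, the sample mean
print(sample_expectile(returns, 0.9))  # approximately 8.0, pulled toward the best return
```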

To start, implement a double Q-function below. This can be similar to your code from HW1, but note that the constructor has a different signature.
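For reference, a double Q-network is two independently initialized MLPs over the concatenated state-action input. The sketch below builds them with plain `nn.Sequential` rather than the provided `network` helper, whose signature isn't shown here. `mlp` and `DoubleQSketch` are hypothetical names, and reading `n_hidden` as the number of hidden layers is an assumption.

```python
import torch
import torch.nn as nn


def mlp(in_dim, out_dim, hidden_dim, n_hidden):
    """Hypothetical MLP builder: n_hidden ReLU layers of width hidden_dim."""
    layers, dim = [], in_dim
    for _ in range(n_hidden):
        layers += [nn.Linear(dim, hidden_dim), nn.ReLU()]
        dim = hidden_dim
    layers.append(nn.Linear(dim, out_dim))
    return nn.Sequential(*layers)


class DoubleQSketch(nn.Module):
    def __init__(self, state_dimension, action_dimension, hidden_dim, n_hidden):
        super().__init__()
        in_dim = state_dimension + action_dimension
        # Two critics with identical architecture but separate parameters.
        self.q1 = mlp(in_dim, 1, hidden_dim, n_hidden)
        self.q2 = mlp(in_dim, 1, hidden_dim, n_hidden)

    def forward(self, state, action):
        sa = torch.cat([state, action], dim=-1)
        return self.q1(sa), self.q2(sa)


q = DoubleQSketch(17, 6, hidden_dim=256, n_hidden=2)
q1, q2 = q(torch.randn(8, 17), torch.randn(8, 6))
print(q1.shape, q2.shape)  # torch.Size([8, 1]) torch.Size([8, 1])
```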

In [None]:
from src.networks import network

class QNetwork(nn.Module):
    def __init__(self, state_dimension, action_dimension, hidden_dim, n_hidden):
        super(QNetwork, self).__init__()

        # TODO: fill in your code here

    def forward(self, state, action):

        # TODO: fill in your code here to query the critic
        return q1, q2

Next, implement a value network below. It should be similar to the Q-network, but it conditions only on the state and uses a single network rather than two.
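A minimal sketch of such a value network, again using plain `nn.Sequential` and assuming `n_hidden` counts hidden layers (the name `ValueSketch` is hypothetical):

```python
import torch
import torch.nn as nn


class ValueSketch(nn.Module):
    """Hypothetical state-value network V(s): one MLP with a scalar output."""

    def __init__(self, state_dimension, hidden_dim, n_hidden):
        super().__init__()
        layers, dim = [], state_dimension
        for _ in range(n_hidden):
            layers += [nn.Linear(dim, hidden_dim), nn.ReLU()]
            dim = hidden_dim
        layers.append(nn.Linear(dim, 1))
        self.v = nn.Sequential(*layers)

    def forward(self, state):
        return self.v(state)


value_net = ValueSketch(17, hidden_dim=256, n_hidden=2)
print(value_net(torch.randn(8, 17)).shape)  # torch.Size([8, 1])
```

Keeping the output shape `(batch, 1)` lines up with the `(batch, 1)` reward and `not_done` tensors built in the training loop, so the Bellman backup broadcasts without reshaping.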

In [None]:
class VNetwork(nn.Module):
    def __init__(self, state_dimension, hidden_dim, n_hidden):
        super(VNetwork, self).__init__()

        # TODO: your code here

    def forward(self, state):
        
        # TODO: your code here
        return v

Next, implement the expectile loss, $L^{\tau}_2$, which is used to optimize the value function. This loss is described in Section 4.1 of Kostrikov et al.
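The asymmetric squared loss is $L_2^\tau(u) = |\tau - \mathbb{1}(u < 0)|\,u^2$, which IQL applies to $u = Q_{\hat\theta}(s, a) - V_\psi(s)$. A minimal sketch, assuming `diff` is a tensor of such differences and that the loss is averaged over the batch (the name `expectile_loss_sketch` is hypothetical):

```python
import torch


def expectile_loss_sketch(diff: torch.Tensor, expectile: float = 0.8) -> torch.Tensor:
    """Mean of |tau - 1(u < 0)| * u^2 over all elements of diff."""
    # Positive residuals (Q above V) get weight tau; negative ones get 1 - tau.
    weight = torch.abs(expectile - (diff < 0).float())
    return (weight * diff.pow(2)).mean()


u = torch.tensor([1.0, -1.0])
print(expectile_loss_sketch(u, expectile=0.8))  # tensor(0.5000)
```

With `expectile=0.5` every residual gets weight 0.5, so the loss is half the mean squared error and its minimizer is the ordinary mean; larger expectiles push $V$ toward the upper part of the Q-value distribution.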

In [None]:
def expectile_loss(diff, expectile=0.8):
    # TODO: fill in this function
    return None

Finally, it's time to implement the IQL training loop. There are several steps here:
 * Implement the value function update by applying the `expectile_loss` function implemented above to the difference between the target Q-value (the minimum of the two target critics) and $V(s)$.
 * Implement the $Q$ function update. The targets for this update should be a Bellman backup based on the value function, that is, $r + \gamma V(s')$ with the $V(s')$ term masked out at terminal states. Don't forget to update the EMA target!
 * Implement the policy update. This should be an NLL loss weighted by clipped, exponentiated advantage estimates.

More details about all of these steps can be found in Kostrikov et al.
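Putting the three updates together, one IQL gradient step might look like the sketch below. It follows the losses in Kostrikov et al.: $V$ is regressed toward the minimum of the two target Q-values with the expectile loss, each Q-head is regressed toward the Bellman backup through $V(s')$, the target critic tracks the online critic with an exponential moving average at rate `alpha`, and the policy minimizes NLL weighted by $\min(\exp(\beta \cdot \text{advantage}), \text{max})$. Assumptions: the critics return `(batch, 1)` tensors, the policy returns a `torch.distributions.Distribution` whose `log_prob` has shape `(batch,)`, and batches use the notebook's sampler keys. `iql_update_sketch` and the tiny networks are hypothetical, and `tau`/`beta` are left as required arguments because choosing them is part of the assignment.

```python
import copy

import torch
import torch.nn as nn
from torch.distributions import Independent, Normal


def expectile_weighted_sq(diff, tau):
    # |tau - 1(u < 0)| * u^2, averaged over the batch.
    return (torch.abs(tau - (diff < 0).float()) * diff.pow(2)).mean()


def iql_update_sketch(batch, q_critic, q_target, v_critic, policy,
                      q_opt, v_opt, pi_opt, tau, beta,
                      discount=0.99, alpha=0.005, exp_adv_max=100.0):
    """One hypothetical IQL step: V update, Q update + EMA target, then policy update."""
    s, a, r = batch["state"], batch["action"], batch["reward"]
    s_next, not_done = batch["next_state"], batch["not_done"]

    # 1) Value update: expectile regression of V(s) toward min of the two target Q-values.
    with torch.no_grad():
        target_q = torch.min(*q_target(s, a))
        next_v = v_critic(s_next)  # V(s') for the Bellman backup
    adv = target_q - v_critic(s)
    v_loss = expectile_weighted_sq(adv, tau)
    v_opt.zero_grad()
    v_loss.backward()
    v_opt.step()

    # 2) Q update: both heads regress toward r + gamma * not_done * V(s').
    backup = r + discount * not_done * next_v
    q1, q2 = q_critic(s, a)
    q_loss = ((q1 - backup) ** 2).mean() + ((q2 - backup) ** 2).mean()
    q_opt.zero_grad()
    q_loss.backward()
    q_opt.step()

    # EMA target update: theta_hat <- (1 - alpha) * theta_hat + alpha * theta.
    with torch.no_grad():
        for p_t, p in zip(q_target.parameters(), q_critic.parameters()):
            p_t.mul_(1.0 - alpha).add_(p, alpha=alpha)

    # 3) Policy update: NLL weighted by clipped exp(beta * advantage); no gradient into V.
    weights = torch.exp(beta * adv.detach()).clamp(max=exp_adv_max).squeeze(-1)
    policy_loss = -(weights * policy(s).log_prob(a)).mean()
    pi_opt.zero_grad()
    policy_loss.backward()
    pi_opt.step()

    return v_loss.item(), q_loss.item(), policy_loss.item()


# Usage with tiny stand-in networks (3-D states, 2-D actions).
class TinyDoubleQ(nn.Module):
    def __init__(self):
        super().__init__()
        self.q1, self.q2 = nn.Linear(5, 1), nn.Linear(5, 1)

    def forward(self, s, a):
        sa = torch.cat([s, a], dim=-1)
        return self.q1(sa), self.q2(sa)


class TinyPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.mean, self.log_std = nn.Linear(3, 2), nn.Parameter(torch.zeros(2))

    def forward(self, s):
        return Independent(Normal(self.mean(s), self.log_std.exp()), 1)


torch.manual_seed(0)
q_net = TinyDoubleQ()
q_tgt = copy.deepcopy(q_net).requires_grad_(False)
v_net, pi_net = nn.Linear(3, 1), TinyPolicy()
batch = {"state": torch.randn(16, 3), "action": torch.randn(16, 2),
         "reward": torch.randn(16, 1), "next_state": torch.randn(16, 3),
         "not_done": torch.ones(16, 1)}
target_before = [p.clone() for p in q_tgt.parameters()]
# tau and beta below are arbitrary illustration values, not recommendations.
step_losses = iql_update_sketch(
    batch, q_net, q_tgt, v_net, pi_net,
    torch.optim.Adam(q_net.parameters(), lr=1e-3),
    torch.optim.Adam(v_net.parameters(), lr=1e-3),
    torch.optim.Adam(pi_net.parameters(), lr=1e-3),
    tau=0.9, beta=1.0,
)
```

Computing $V(s')$ before the value step and detaching the advantage keeps the three losses decoupled: the Q-loss does not backpropagate into $V$, and the policy weights do not backpropagate into either critic.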

Once you've finished implementing the training loop, run the cell to train your IQL policy. With a correct implementation and good hyperparameters, your policy should achieve a normalized reward greater than 0.43.

In [None]:
EPOCHS = 150
EVAL_FREQ = 15
LOAD_FROM_CKPT = False

# These parameters should work fine but you may tune them if you want to
hidden_dim: int = 256
n_hidden: int = 2
lr: float = 3e-4
discount = 0.99
alpha = 0.005
exp_advantage_max = 100

# TODO: you'll need to choose your own value for the following parameters
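tau = None  # TODO: expectile used in the value (V) loss
beta = None  # TODO: inverse temperature for the advantage weights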

min_rew, max_rew = return_range(dataset, 1000)

#############################################################################################

sampler = D4RLSampler(dataset, 256, DEVICE)

iql_policy = GaussianPolicy(state_dimension, action_dimension, hidden_dim, n_hidden).to(DEVICE)
policy_optimizer = Adam(iql_policy.parameters(), lr)
policy_lr_schedule = CosineAnnealingLR(policy_optimizer, EPOCHS * len(sampler))

v_critic = VNetwork(state_dimension, hidden_dim, n_hidden).to(DEVICE)
v_optimizer = Adam(v_critic.parameters(), lr)

q_critic = QNetwork(state_dimension, action_dimension, hidden_dim, n_hidden).to(DEVICE)
q_critic_target = copy.deepcopy(q_critic)
q_critic_target.requires_grad_(False)
q_optimizer = Adam(q_critic.parameters(), lr)

means, stds, start_epoch = [], [], 0
if os.path.exists('iql_checkpoint.pth') and LOAD_FROM_CKPT:
    checkpoint = torch.load('iql_checkpoint.pth')

    iql_policy.load_state_dict(checkpoint['iql_policy'])
    policy_optimizer.load_state_dict(checkpoint['policy_optimizer'])
    v_critic.load_state_dict(checkpoint['v_critic'])
    v_optimizer.load_state_dict(checkpoint['v_optimizer'])
    q_critic.load_state_dict(checkpoint['q_critic'])
    q_critic_target.load_state_dict(checkpoint['q_critic_target'])
    q_optimizer.load_state_dict(checkpoint['q_optimizer'])
    
    start_epoch = checkpoint['epoch']
    means = checkpoint['means']
    stds = checkpoint['stds']
    
    print(f'Resuming run from epoch {start_epoch}')

for epoch in range(start_epoch, EPOCHS):
    total_q_loss = total_v_loss = total_policy_loss = count = 0
    policy_losses = []
    for batch in tqdm(sampler):
        state = batch['state'].to(DEVICE)
        next_state = batch['next_state'].to(DEVICE)
        action = batch['action'].to(DEVICE)
        reward = einops.rearrange(batch['reward'], 'b -> b 1').to(DEVICE)
        reward = reward / (max_rew - min_rew) * 1000
        not_done = einops.rearrange(batch['not_done'], 'b -> b 1').to(DEVICE)

        # TODO: update the state value function (V)
        v_loss = torch.tensor(0)

        # TODO: update the state-action value function (Q) and the target
        q_loss = torch.tensor(0)

        # TODO: update the policy
        policy_loss = torch.tensor(0)

        policy_lr_schedule.step()
        total_v_loss += v_loss.item()
        total_q_loss += q_loss.item()
        total_policy_loss += policy_loss.item()
        count += 1
        
    if (epoch + 1) % EVAL_FREQ == 0:
        rew_mean, rew_std = eval_policy(iql_policy, environment_name=ENVIRONMENT_NAME, eval_episodes=50)
        print(f'Epoch: {epoch + 1}. Q Loss: {total_q_loss / count:.4f}. V Loss: {total_v_loss / count:.4f}. P Loss: {total_policy_loss / count:.4f}. Reward: {rew_mean:.4f} +/- {rew_std:.4f}')
        means.append(rew_mean)
        stds.append(rew_std)

    # Save a checkpoint so that you can resume training if it crashes
    checkpoint = {
        'iql_policy': iql_policy.state_dict(),
        'policy_optimizer': policy_optimizer.state_dict(),
        'v_critic': v_critic.state_dict(),
        'v_optimizer': v_optimizer.state_dict(),
        'q_critic': q_critic.state_dict(),
        'q_critic_target': q_critic_target.state_dict(),
        'q_optimizer': q_optimizer.state_dict(),
        'epoch': epoch + 1,
        'means': means,
        'stds': stds
    }
    torch.save(checkpoint, 'iql_checkpoint.pth')

epochs = np.arange(EVAL_FREQ, EPOCHS + EVAL_FREQ, step=EVAL_FREQ)
plot_returns(means, stds, 'Implicit Q Learning', goal=0.4, epochs=epochs)

Now that we've finished training our IQL policy we can visualize it in the block below.

In [None]:
iql_policy.load_state_dict(torch.load('iql_checkpoint.pth')['iql_policy'])
frames, total_reward = demo_policy(iql_policy, environment_name=ENVIRONMENT_NAME, steps=200)
gif_path = save_frames_as_gif(frames, method_name='iql')
Image(open(gif_path,'rb').read())

Congrats on finishing the offline portion of Assignment 2! We hope you enjoyed it. Make sure that the visualizations are displayed and that the artifacts folder contains these four outputs from this notebook:
 * `bc_policy.gif`
 * `iql_policy.gif`
 * `Behavior Cloning_returns.png`
 * `Implicit Q Learning_results.png`

When you're done, export this notebook as an **HTML file** for final submission.