# Counterpol on LunarLander-v2

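This notebook shows how to use Counterpol to find counterfactual policies for a pretrained A2C agent on LunarLander-v2. A counterfactual policy stays close to the original policy while achieving a chosen target return.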

In [None]:
# Importing required libraries
import argparse
import os
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3 import A2C
from torch.distributions import Categorical
from stable_baselines3.common.policies import obs_as_tensor
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import CheckpointCallback

import gym
import numpy as np

from stable_baselines3.common.policies import obs_as_tensor

import wandb
from tqdm import tqdm


### Helper functions

In [None]:
# Function for loading A2C policy
def load_saved_model(env, model_id, save_path, name_prefix):
    '''
    Load an A2C checkpoint from ``save_path`` and attach it to ``env``.

    The checkpoint file name is "<name_prefix>_<model_id>_steps".
    Presumably this matches the files written by SB3's CheckpointCallback;
    verify against how the checkpoints were produced.
    '''
    checkpoint_name = "_".join([name_prefix, str(model_id), "steps"])
    checkpoint_path = os.path.join(save_path, checkpoint_name)
    return A2C.load(checkpoint_path, env=env)


def get_action(policy_network, state, deterministic=True):
    '''
    Select an action from the policy network at a given state.

    By default the action is chosen greedily: the argmax of the action
    probabilities. This is what callers in this notebook rely on. Pass
    ``deterministic=False`` to sample from the categorical distribution
    instead.

    Args:
        policy_network: object exposing ``get_distribution(obs)``, whose
            ``.distribution.probs`` holds per-action probabilities
            (e.g. the ``policy`` of an SB3 A2C model with discrete actions).
        state: observation tensor; it is flattened into a batch of one.
        deterministic: if True (default), take the most probable action;
            otherwise sample one.

    Returns:
        (action, prob, log_prob): the action as a Python int, plus its
        probability and log-probability as 0-dim tensors. The tensors stay
        attached to ``policy_network``'s autograd graph.
    '''
    # Treat the single observation as a batch of size one.
    state = state.view(1, -1)
    probs = policy_network.get_distribution(state).distribution.probs
    probs = probs.squeeze(0)
    if deterministic:
        action = probs.argmax()
    else:
        action = Categorical(probs=probs).sample()
    action_prob = probs[action]
    return action.item(), action_prob, torch.log(action_prob)


def evaluate_performance(collected_episodes):
    '''
    Estimate J_pi as the mean undiscounted return of the given episodes.

    Each episode is a sequence of transitions whose element at index 3 is
    the reward. The result is ``np.mean`` of the per-episode reward sums.
    NumPy returns nan, with a warning, when there are no episodes.
    '''
    returns_per_episode = [
        sum([step[3] for step in episode]) for episode in collected_episodes
    ]
    return np.mean(returns_per_episode)


def compute_stepwise_returns(collected_episodes):
    '''
    Return Monte Carlo estimates of Q_pi(s, a) for every step of every episode.

    For each episode, the value at step t is the undiscounted sum of rewards
    (transition element at index 3) from t to the end of the episode.
    The nested list has the same shape as ``collected_episodes``.
    '''
    def _reward_to_go(episode):
        # Walk the episode backwards, keeping a running tail sum.
        running = 0
        tail_sums = []
        for transition in reversed(episode):
            running = transition[3] + running
            tail_sums.append(running)
        tail_sums.reverse()
        return tail_sums

    return [_reward_to_go(episode) for episode in collected_episodes]


def rollout_policy(env, policy_network, device, num_rollouts=10, max_rollout_length=100):
    '''
    Collect trajectories by running ``policy_network`` in ``env``.

    This uses the old Gym API: ``env.reset()`` returns only an observation,
    and ``env.step`` returns ``(obs, reward, done, info)``.

    Args:
        env: environment to roll out in.
        policy_network: policy passed to ``get_action`` (greedy selection).
        device: torch device that observations are moved to.
        num_rollouts: number of episodes to collect.
        max_rollout_length: episodes are cut after this many steps.
            At least one step is always taken.

    Returns:
        (collected_episodes, total_steps):
        - ``collected_episodes`` is a list of episodes. Each episode is a
          list of ``(state, action, log_action_prob, reward, next_state,
          done)`` tuples, where ``state`` is a tensor on ``device`` and
          ``next_state`` is the raw observation returned by ``env.step``.
        - ``total_steps`` is the total number of collected transitions.
    '''
    collected_episodes = []
    total_steps = 0

    for _ in range(num_rollouts):
        episode = []
        state = env.reset()
        step_id = 0

        while True:
            # Discrete observation spaces yield a bare integer, which may be
            # a NumPy integer. torch.from_numpy needs an ndarray.
            if isinstance(state, (int, np.integer)):
                state = np.array([state])
            state = torch.from_numpy(state).to(device)
            action, _, log_action_prob = get_action(policy_network, state)
            next_state, reward, done, _ = env.step(action)
            episode.append((state, action, log_action_prob, reward, next_state, done))
            state = next_state
            step_id += 1

            if done or step_id >= max_rollout_length:
                collected_episodes.append(episode)
                total_steps += step_id
                break

    return collected_episodes, total_steps


def calculate_counterfactual_objective(env, policy_network_orig, policy_network_cf, target_return, device, kl_weight=0.1, num_rollouts=10, max_rollout_length=100, reward_range=200):
    '''
    Compute the Counterpol surrogate objective from fresh rollouts of the
    counterfactual policy.

    The surrogate mixes two per-step terms, weighted by ``kl_weight`` and
    ``1 - kl_weight``:
    - a KL-proxy term that keeps ``policy_network_cf`` close to
      ``policy_network_orig``;
    - a score-function term whose gradient pushes J_pi toward
      ``target_return``.
    Both terms are divided by the total number of collected steps.

    Returns:
        objective: differentiable surrogate to call ``.backward()`` on
            (a 0-dim tensor, or 0 if no transitions were collected).
        kl_objective: the KL-proxy term summed over all transitions, i.e.
            its mean per step. It is detached and intended for logging.
        performance_objective_orig:
            (J_pi - target_return)^2 / reward_range^2.
        J_pi: mean episodic return of the collected rollouts.
    '''

    collected_episodes, total_steps = rollout_policy(env, policy_network_cf, device, num_rollouts, max_rollout_length)
    per_step_returns = compute_stepwise_returns(collected_episodes)
    J_pi = evaluate_performance(collected_episodes)
    objective = 0
    kl_objective = 0
    performance_objective_orig = (J_pi - target_return)**2 / (reward_range**2)

    for episode_id, episode in enumerate(collected_episodes):
        for step_id, (state, _, log_action_prob, _, _, _) in enumerate(episode):
            # The original (pivot) policy is a fixed reference, so no
            # autograd graph is needed for it.
            with torch.no_grad():
                _, pi_orig_action, _ = get_action(policy_network_orig, state)
            # NOTE(review): this pairs pi_orig's probability of ITS greedy
            # action with log pi_cf of the counterfactual policy's action;
            # confirm this is the intended KL proxy.
            kl_step = (-1.0 * pi_orig_action.detach() * log_action_prob) / total_steps
            # Score-function gradient of (J_pi - target)^2, using
            # target_return as a constant baseline on the per-step returns.
            performance_objective = 2 * (J_pi - target_return) * (per_step_returns[episode_id][step_id] - target_return) * log_action_prob / (reward_range**2 * total_steps)
            objective += (kl_weight * kl_step + (1 - kl_weight) * performance_objective)
            # Accumulate the reported KL term over all steps. Detach it so the
            # logged value does not keep an extra graph alive.
            kl_objective += kl_step.detach()

    return objective, kl_objective, performance_objective_orig, J_pi



### Perform Counterpol Optimization 

In [None]:
target_returns = [100, 150, 0, -50]

# Run one Counterpol search (and one wandb run) per target return.
for target_return in target_returns:
    config = {
        'env': 'LunarLander-v2',
        'learning_rate': 0.0005,
        'kl_weight': 0.9,
        'num_rollouts': 10,
        'model_id': 100000,
        'num_pol_iters': 10,
        'max_cf_iters': 250,
        'max_rollout_length': 500,
        'reward_range': 500,
        'target_return': target_return,
        'delta_return': 5,
        'grad_clip_value': 0.5,
        'pretrained_save_path': './orig_ckpts/',
        'seed': 0
    }

    # Set seeds
    # NOTE(review): only NumPy and torch are seeded here; the gym env itself
    # is not seeded, so rollouts are not fully reproducible.
    seed = config['seed']
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create environment
    env = gym.make(config['env'])

    save_path = config['pretrained_save_path']
    name_prefix = config['env'].lower()

    # Logging
    wandb.init(
        project="RL-Factuals",
        config=config,
    )

    # Load pretrained model
    model_id = config['model_id']
    model = load_saved_model(env, model_id, save_path, name_prefix)

    # Counterpol starts -->

    # Create copy of original policy network.
    # policy_network_cf_0 is the KL pivot; it is refreshed after each outer iteration.
    policy_network_orig = model.policy
    policy_network_cf_0 = deepcopy(policy_network_orig)
    policy_network_cf = deepcopy(policy_network_orig)

    # Set up optimizer
    optimizer = torch.optim.Adam(
        policy_network_cf.parameters(), lr=config['learning_rate'])

    # Number of policy iterations
    num_pol_iters = config['num_pol_iters']
    max_cf_iters = config['max_cf_iters']

    found_counterfactual = False

    # Outerloop
    for _ in tqdm(range(num_pol_iters)):

        # Innerloop
        for _ in tqdm(range(max_cf_iters)):
            optimizer.zero_grad()

            cf_objective, kl_objective, performance_objective_orig, J_pi =\
                calculate_counterfactual_objective(env,
                                                    policy_network_cf_0,
                                                    policy_network_cf,
                                                    target_return=config['target_return'],
                                                    device=device,
                                                    kl_weight=config['kl_weight'],
                                                    num_rollouts=config['num_rollouts'],
                                                    max_rollout_length=config['max_rollout_length'],
                                                    reward_range=config['reward_range'])

            # Log plain floats rather than tensors that may carry autograd graphs.
            wandb.log({'counterfactual_objective': float(kl_objective + performance_objective_orig),
                        'J_pi': float(J_pi),
                        'kl_objective': float(kl_objective),
                        'performance_objective_orig': float(performance_objective_orig)})

            # Check if counterfactual found
            if abs(J_pi - config['target_return']) < config['delta_return']:
                print("Found counterfactual")
                found_counterfactual = True
                # Break innerloop
                break
            else:
                cf_objective.backward()
                torch.nn.utils.clip_grad_norm_(
                    policy_network_cf.parameters(), config['grad_clip_value'])

                optimizer.step()

        # Break outerloop
        if found_counterfactual:
            break

        # Change KL-pivot
        with torch.no_grad():
            policy_network_cf_0.load_state_dict(policy_network_cf.state_dict())

    # Evaluate the counterfactual policy
    model.policy = policy_network_cf
    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20)
    print(f"Counterfactual policy J_pi:{mean_reward:.2f} +/- {std_reward}")

    # Save counterfactual policy
    model.save(f"./logs/lunarlander/cf_orig_id_{model_id}_target_return_{target_return}")

    # Close this target's wandb run so the next iteration starts a fresh run.
    wandb.finish()

### Compare Original Policy with Counterfactual Policy

In [None]:
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv

video_folder = './logs/videos/'
video_length = 1000

# Load original model
model_id = config['model_id']
orig_model = load_saved_model(env, model_id, save_path, name_prefix)

env = DummyVecEnv([lambda: gym.make(config['env'])])

# Record the video starting at the first step
env = VecVideoRecorder(env, video_folder,
                       record_video_trigger=lambda x: x == 0, video_length=video_length,
                       name_prefix="orig-{}".format(config['env']))

# Take the first observation from the recorder's own reset. An observation
# from an earlier reset would not match the state actually being recorded.
obs = env.reset()
for _ in range(video_length + 1):
    action = orig_model.predict(obs)[0]
    # VecEnv.step returns (obs, rewards, dones, infos).
    obs, _, done, _ = env.step(action)
    if done:
        break
# Save the video
env.close()


# Note the counterfactual policy is saved as model in the previous cell
env = DummyVecEnv([lambda: gym.make(config['env'])])

# Record the video starting at the first step
env = VecVideoRecorder(env, video_folder,
                       record_video_trigger=lambda x: x == 0, video_length=video_length,
                       name_prefix=f"counterfactual-target-return-{config['target_return']}-{config['env']}")

obs = env.reset()
for _ in range(video_length + 1):
    action = model.predict(obs)[0]
    # `done` is the third element. Unpacking the fourth (infos, a non-empty
    # list) as `done` would stop the recording after one step.
    obs, _, done, _ = env.step(action)
    if done:
        break
# Save the video
env.close()
