In [1]:
import hydra
from omegaconf import OmegaConf, DictConfig
import pickle
import os

import gym
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from ppo.model import PPO
from hca.model import HCAModel
from hca.buffer import calculate_mc_returns
from dualdice.dd_model import DualDICE
from dualdice.return_model import ReturnPredictor
from utils import get_env

D4RL is not installed!
Lorl env is not installed!


Load Checkpoint

In [3]:
checkpoint_folder = "../checkpoints/hca-dualdice_delayed_GridWorld:test_v4_2023-06-14_01:00:19"
checkpoint_ep = 100000
ppo_checkpoint = torch.load(f"{checkpoint_folder}/ppo_{checkpoint_ep}.pt")
args = ppo_checkpoint["args"]
hca_checkpoint = torch.load(f"{checkpoint_folder}/hca_{checkpoint_ep}.pt")
dd_checkpoint = torch.load(f"{checkpoint_folder}/dd_{checkpoint_ep}.pt")
ret_checkpoint = torch.load(f"{checkpoint_folder}/ret_{checkpoint_ep}.pt")

Set up Environment

In [4]:
env = get_env(args)

if isinstance(env.action_space, gym.spaces.Box):
    continuous = True
else:
    continuous = False

if continuous:
    action_dim = env.action_space.shape[0]
else:
    action_dim = env.action_space.n

if args.env.type == "gridworld":
    # gridworld env
    input_dim = env.observation_space["map"].shape[0] + 1
else:
    input_dim = env.observation_space.shape[0]

env.reset()
env.render()

╒═══╤═══╤═══╤═══╤═══╕
│ . │ * │ . │ G │ * │
├───┼───┼───┼───┼───┤
│ . │ F │ . │ * │ . │
├───┼───┼───┼───┼───┤
│ . │ F │ . │ . │ . │
├───┼───┼───┼───┼───┤
│ . │ F │ * │ F │ . │
├───┼───┼───┼───┼───┤
│ A │ . │ . │ * │ . │
╘═══╧═══╧═══╧═══╧═══╛


Set up device

In [5]:
device = torch.device(args.training.device)

Initialize PPO policy

In [6]:
ppo_cnn = None # just leave this for now
agent = PPO(input_dim, action_dim, args.agent.lr, continuous, device, args, ppo_cnn)
agent.load(ppo_checkpoint)

Initialize HCA model

In [7]:
hca_cnn = None # just leave this for now
h_model = HCAModel(
            input_dim + 1,  # +1 is for return-conditioned
            action_dim,
            continuous=continuous,
            cnn_base=hca_cnn,
            n_layers=args.agent.hca_n_layers,
            hidden_size=args.agent.hca_hidden_size,
            activation_fn=args.agent.hca_activation,
            dropout_p=args.agent.hca_dropout,
            batch_size=args.agent.hca_batchsize,
            lr=args.agent.hca_lr,
            device=args.training.device,
            normalize_inputs=args.agent.hca_normalize_inputs,
            normalize_return_inputs_only=args.agent.hca_normalize_return_inputs_only,
            max_grad_norm=args.agent.hca_max_grad_norm,
            weight_training_samples=args.agent.hca_weight_training_samples,
            noise_std=args.agent.hca_noise_std,
        )
h_model.load(hca_checkpoint)

Load DualDICE model

In [8]:
dd_act_dim = action_dim if continuous else 1
dd_cnn = None # just leave this for now
dd_model = DualDICE(
            input_dim,
            action_dim=dd_act_dim,
            cnn_base=dd_cnn,  # using different CNNs here not worried about compute
            f=args.agent.dd_f,
            c=1, # args.agent.dd_c when available
            n_layers=args.agent.hca_n_layers,
            hidden_size=args.agent.hca_hidden_size,
            activation_fn=args.agent.hca_activation,
            dropout_p=args.agent.hca_dropout,
            batch_size=args.agent.hca_batchsize,
            lr=args.agent.hca_lr,
            device=args.training.device,
            normalize_inputs=args.agent.hca_normalize_inputs,
            normalize_return_inputs_only=args.agent.hca_normalize_return_inputs_only,
            max_grad_norm=args.agent.dd_max_grad_norm,
        )
dd_model.load(dd_checkpoint)

Load Return Model

In [9]:
r_cnn = None # just leave this for now
r_model = ReturnPredictor(
            input_dim,
            quantize=args.agent.r_quant,
            num_classes=args.agent.r_num_classes,
            cnn_base=r_cnn,  # using different CNNs here not worried about compute
            n_layers=args.agent.hca_n_layers,
            hidden_size=args.agent.hca_hidden_size,
            activation_fn=args.agent.hca_activation,
            dropout_p=args.agent.hca_dropout,
            batch_size=args.agent.hca_batchsize,
            lr=args.agent.hca_lr,
            device=args.training.device,
            normalize_inputs=args.agent.hca_normalize_inputs,
            normalize_targets=args.agent.r_normalize_targets,
            max_grad_norm=args.agent.r_max_grad_norm,
        )
r_model.load(ret_checkpoint)

Get hindsight ratios

In [10]:
def get_hindsight_ratios(state, action, returns):
    state = torch.from_numpy(state).reshape(1, -1).float().to(device)
    action = torch.from_numpy(action).reshape(1, -1).float().to(device)
    returns = torch.Tensor([returns]).reshape(1, -1).float().to(device)
    
    # Policy value
    pi_logprobs, _, _ = agent.policy.evaluate(state, action)
    pi_logprobs = pi_logprobs.detach().cpu().numpy()
    pi_prob = np.exp(pi_logprobs)

    # Hindsight model value
    h_logprobs = h_model.get_hindsight_logprobs(state, returns, action).cpu().numpy()
    h_prob = np.exp(h_logprobs)

    # H-DICE model value
    hdice_density_ratio = dd_model.get_density_ratios(state, action, returns)
    ret_prob = r_model.get_return_probs(state, returns)

    return {"pi": pi_prob, "h": h_prob, "naive_ratio": np.exp(pi_logprobs - h_logprobs), "dd": hdice_density_ratio, "ret": ret_prob, "hdice_ratio": hdice_density_ratio * ret_prob}

Collect an episode trajectory

In [22]:
def get_episode(env, ep_actions = []):
    state = env.reset()
        
    done = False

    states = []
    actions = []
    rewards = []
    terminals = []
    ep_len = 0

    gn_ep_actions = (ep_actions != [])

    while True:
        if gn_ep_actions:
            action = ep_actions.pop()
        else:
            # select action with policy
            action, _ = agent.select_action(state, greedy=True)
        if continuous:
            action = action.numpy().flatten()
            action = action.clip(
                env.action_space.low, env.action_space.high
            )
        else:
            action = action.item() if not isinstance(action, int) else action

        # Step in env
        state, reward, done, info = env.step(action)

        states.append(state)
        actions.append(action)
        
        # saving reward and terminals
        rewards.append(float(reward))
        terminals.append(done)

        ep_len += 1

        if done:
            break

        if (gn_ep_actions and not ep_actions):
            break

    assert ep_len == len(actions)
    
    return states, actions, rewards, terminals

GridWorld Actions

In [12]:
LEFT = 1
RIGHT = 0
UP = 3
DOWN = 2

In [13]:
ep1 = [RIGHT, RIGHT, UP, UP, UP]
env.reset()
for action in ep1:
    env.step(action)
env.render()

╒═══╤═══╤═══╤═══╤═══╕
│ . │ * │ . │ G │ * │
├───┼───┼───┼───┼───┤
│ . │ F │ A │ * │ . │
├───┼───┼───┼───┼───┤
│ . │ F │ . │ . │ . │
├───┼───┼───┼───┼───┤
│ . │ F │ . │ F │ . │
├───┼───┼───┼───┼───┤
│ S │ . │ . │ * │ . │
╘═══╧═══╧═══╧═══╧═══╛


In [14]:
curr_state = env.get_state()
print(get_hindsight_ratios(curr_state, np.array([LEFT]), -100))
print(get_hindsight_ratios(curr_state, np.array([LEFT]), 69))
print(get_hindsight_ratios(curr_state, np.array([RIGHT]), -100))
print(get_hindsight_ratios(curr_state, np.array([RIGHT]), 69))

{'pi': array([[0.00220553]], dtype=float32), 'h': array([0.34851924], dtype=float32), 'naive_ratio': array([[0.00632828]], dtype=float32), 'dd': tensor([[1.0000]]), 'ret': tensor([[0.1390]], dtype=torch.float64), 'hdice_ratio': tensor([[0.1390]], dtype=torch.float64)}
{'pi': array([[0.00220553]], dtype=float32), 'h': array([0.1630644], dtype=float32), 'naive_ratio': array([[0.0135255]], dtype=float32), 'dd': tensor([[0.0364]]), 'ret': tensor([[0.3982]], dtype=torch.float64), 'hdice_ratio': tensor([[0.0145]], dtype=torch.float64)}
{'pi': array([[0.84563416]], dtype=float32), 'h': array([0.21039516], dtype=float32), 'naive_ratio': array([[4.019266]], dtype=float32), 'dd': tensor([[1.0000]]), 'ret': tensor([[0.1390]], dtype=torch.float64), 'hdice_ratio': tensor([[0.1390]], dtype=torch.float64)}
{'pi': array([[0.84563416]], dtype=float32), 'h': array([0.4338784], dtype=float32), 'naive_ratio': array([[1.9490119]], dtype=float32), 'dd': tensor([[0.0355]]), 'ret': tensor([[0.3982]], dtype=to

In [15]:
ep2 = [RIGHT, RIGHT, UP, UP, RIGHT]
env.reset()
for action in ep2:
    env.step(action)
env.render()

╒═══╤═══╤═══╤═══╤═══╕
│ . │ * │ . │ G │ * │
├───┼───┼───┼───┼───┤
│ . │ F │ . │ * │ . │
├───┼───┼───┼───┼───┤
│ . │ F │ . │ A │ . │
├───┼───┼───┼───┼───┤
│ . │ F │ . │ F │ . │
├───┼───┼───┼───┼───┤
│ S │ . │ . │ * │ . │
╘═══╧═══╧═══╧═══╧═══╛


In [16]:
curr_state = env.get_state()
print(get_hindsight_ratios(curr_state, np.array([UP]), 69))
print(get_hindsight_ratios(curr_state, np.array([UP]), -100))
print(get_hindsight_ratios(curr_state, np.array([DOWN]), 69))
print(get_hindsight_ratios(curr_state, np.array([DOWN]), -100))

{'pi': array([[0.6149289]], dtype=float32), 'h': array([0.37544358], dtype=float32), 'naive_ratio': array([[1.6378729]], dtype=float32), 'dd': tensor([[0.0438]]), 'ret': tensor([[0.3971]], dtype=torch.float64), 'hdice_ratio': tensor([[0.0174]], dtype=torch.float64)}
{'pi': array([[0.6149289]], dtype=float32), 'h': array([0.13141593], dtype=float32), 'naive_ratio': array([[4.679257]], dtype=float32), 'dd': tensor([[1.0000]]), 'ret': tensor([[0.1465]], dtype=torch.float64), 'hdice_ratio': tensor([[0.1465]], dtype=torch.float64)}
{'pi': array([[0.00899117]], dtype=float32), 'h': array([0.09239677], dtype=float32), 'naive_ratio': array([[0.09731042]], dtype=float32), 'dd': tensor([[0.0427]]), 'ret': tensor([[0.3971]], dtype=torch.float64), 'hdice_ratio': tensor([[0.0170]], dtype=torch.float64)}
{'pi': array([[0.00899117]], dtype=float32), 'h': array([0.28745288], dtype=float32), 'naive_ratio': array([[0.03127876]], dtype=float32), 'dd': tensor([[1.0000]]), 'ret': tensor([[0.1465]], dtype=t

In [23]:
ep1 = [RIGHT, RIGHT, UP, UP, UP]
states, actions, rewards, terminals = get_episode(env, ep1)
returns = calculate_mc_returns(rewards, terminals, args.env.gamma)[-1]

for s, a in zip(states, actions):
    info = get_hindsight_ratios(s, np.array([a]), returns)
    print(info)


{'pi': array([[0.11634082]], dtype=float32), 'h': array([0.2207028], dtype=float32), 'naive_ratio': array([[0.52713794]], dtype=float32), 'dd': tensor([[0.7403]]), 'ret': tensor([[0.3983]], dtype=torch.float64), 'hdice_ratio': tensor([[0.2949]], dtype=torch.float64)}
{'pi': array([[0.04175927]], dtype=float32), 'h': array([0.20499682], dtype=float32), 'naive_ratio': array([[0.20370694]], dtype=float32), 'dd': tensor([[0.6888]]), 'ret': tensor([[0.3681]], dtype=torch.float64), 'hdice_ratio': tensor([[0.2535]], dtype=torch.float64)}
{'pi': array([[0.14581066]], dtype=float32), 'h': array([0.2365405], dtype=float32), 'naive_ratio': array([[0.61643004]], dtype=float32), 'dd': tensor([[0.7417]]), 'ret': tensor([[0.3713]], dtype=torch.float64), 'hdice_ratio': tensor([[0.2754]], dtype=torch.float64)}
{'pi': array([[0.43238544]], dtype=float32), 'h': array([0.27540377], dtype=float32), 'naive_ratio': array([[1.5700057]], dtype=float32), 'dd': tensor([[0.7354]]), 'ret': tensor([[0.3628]], dtype