In [1]:
import warnings
warnings.filterwarnings("ignore")
import sys
sys.path.append("..")

import os
os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'
os.environ['CUDA_VISIBLE_DEVICES']='0,1'

import gym
import d4rl

import numpy as np
import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')

from matplotlib.colors import LinearSegmentedColormap, ListedColormap
from matplotlib import patches

import equinox as eqx
import jax
import jax.numpy as jnp
import functools

from tqdm.auto import tqdm
from jaxrl_m.common import TrainStateEQX
from src.agents.iql_equinox import GaussianPolicy, GaussianIntentPolicy

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from ott.tools import plot, sinkhorn_divergence
from ott.solvers.linear import implicit_differentiation as imp_diff

import optax

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))
def eval_ensemble_psi(ensemble, s):
    return eqx.filter_vmap(ensemble.psi_net)(s)

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))
def eval_ensemble_phi(ensemble, s):
    return eqx.filter_vmap(ensemble.phi_net)(s)

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None))
def eval_ensemble_icvf_viz(ensemble, s, g, z):
    return eqx.filter_vmap(ensemble.classic_icvf_initial)(s, g, z)

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None)) # V(s, g, z), g - dim 29, z - dim 256
def eval_ensemble_icvf_latent_z(ensemble, s, g, z):
    return eqx.filter_vmap(ensemble.classic_icvf)(s, g, z)

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None)) # V(s, g ,z ), g, z - dim 256
def eval_ensemble_icvf_latent_zz(ensemble, s, g, z):
    return eqx.filter_vmap(ensemble.icvf_zz)(s, g, z)
    
@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None))
def eval_ensemble_icvf_latent_zzz(ensemble, s, g, z):
    return eqx.filter_vmap(ensemble.icvf_zzz)(s, g, z)

@eqx.filter_vmap(in_axes=dict(ensemble=None, s=0, z=None))
def eqx_get_state_traj(ensemble, s, z):
    '''
    Function to compute pairwise distance between two trajectories
    '''
    s = jnp.tile(s, (z.shape[0], 1))
    return eval_ensemble_icvf_latent_zzz(icvf_model.value_learner.model, s, z, z)

@eqx.filter_jit
def get_gcvalue(agent, s, g, z):
    v_sgz_1, v_sgz_2 = eval_ensemble_icvf_viz(agent.value_learner.model, s, g, z)
    return (v_sgz_1 + v_sgz_2) / 2

def get_v_gz(agent, initial_state, target_goal, observations):
    initial_state = jnp.tile(initial_state, (observations.shape[0], 1))
    target_goal = jnp.tile(target_goal, (observations.shape[0], 1))
    return -1 * get_gcvalue(agent, initial_state, observations, target_goal)
    
def get_v_zz(agent, goal, observations):
    goal = jnp.tile(goal, (observations.shape[0], 1))
    return get_gcvalue(agent, observations, goal, goal)

@eqx.filter_vmap(in_axes=dict(agent=None, obs=None, goal=0))
def get_v_zz_heatmap(agent, obs, goal): # goal - traj
    goal = jnp.tile(goal, (obs.shape[0], 1))
    return get_gcvalue(agent, obs, goal, goal)

%matplotlib inline
%load_ext autoreload
%autoreload 2

pybullet build time: May 20 2022 19:45:31


In [3]:
from src.gc_dataset import GCSDataset
from utils.ds_builder import setup_datasets

env, expert_ds, agent_ds, agent_mean_states, agent_std_states = setup_datasets(expert_env_name="halfcheetah-medium-expert-v2",
                                          agent_env_name="halfcheetah-medium-expert-v2", expert_num=1,
                                          normalize_agent_states=False)

gcsds_params = GCSDataset.get_default_config()
gc_expert_dataset = GCSDataset(expert_ds, **gcsds_params)
gc_agent_dataset = GCSDataset(agent_ds, **gcsds_params)

expert_trajectory = gc_expert_dataset.get_expert_traj()['observations']

load datafile:   0%|          | 0/9 [00:00<?, ?it/s]

load datafile: 100%|██████████| 9/9 [00:02<00:00,  3.98it/s]


  0%|          | 0/1998000 [00:00<?, ?it/s]

Expert returns [11239.283746674657], mean 11239.283746674657


load datafile: 100%|██████████| 9/9 [00:02<00:00,  4.13it/s]


Number of terminal states: 1
Number of terminal states: 2001


In [4]:
%cd ..
from src.agents import icvf
icvf_model = icvf.create_eqx_learner(seed=228,
                                     observations=expert_ds.dataset_dict['observations'][0],
                                     hidden_dims=[256, 256],
                                     pretrained_folder="halfcheetah-medium-expert",
                                     load_pretrained_icvf=True)

/home/nazar/projects/AILOT
Extra kwargs: {}


In [41]:
from abc import ABC, abstractmethod
from jax.numpy import ndarray
from ott.geometry import costs
from ott.math import utils as mu


class OTRewardsExpert:

    def __init__(
        self, expert_traj = None, expert_z = None
    ):
        self.expert_states = expert_traj
        self.expert_z = expert_z
        if expert_z is None:
            self.expert_z = eval_ensemble_psi(icvf_model.value_learner.model, expert_traj).mean(axis=0)
        self.sub_steps = 5

    def make_subs(self, z, sub_steps):
        sub_indx = jnp.minimum(jnp.arange(0, z.shape[0]) + sub_steps, z.shape[0] - 1)
        return jax.tree_map(lambda arr: arr[sub_indx], z)
    

    @eqx.filter_jit
    def get_z_and_start_index(self, obs):
        # obs - trajectory
        z = eval_ensemble_psi(icvf_model.value_learner.model, obs).mean(axis=0)
        diff = jnp.linalg.norm(z[0][jnp.newaxis,] - self.expert_z, axis=-1) #eqx_get_state_traj(icvf_model.value_learner.model, z[0][None], self.expert_z).mean(1)#z[0][jnp.newaxis,] - self.expert_z
        i_min = jnp.argmin(diff)#jnp.argmin((diff**2).sum(-1)).squeeze()
        return z, i_min, diff

    def compute_rewards(
        self,
        dataset
    ):
        i0 = 0
        rewards = []
        observations = dataset.dataset_dict['observations']
        episode_starts, episode_ends, episode_returns = gc_agent_dataset.dataset._trajectory_boundaries_and_returns()
        
        for i1 in tqdm(range(len(episode_starts))):
            zi, start_index, diff = self.get_z_and_start_index(observations[episode_starts[i1]:episode_ends[i1]])
            ri = self.compute_rewards_one_episode(zi, self.expert_z[start_index:])
            #print(eval_ensemble_icvf_latent_zzz(icvf_model.value_learner.model, zi[0][None], self.expert_z[start_index][5][None], self.expert_z[start_index][5][None]).mean(0))
            rewards.append(jax.device_get(ri))
                  
        return np.concatenate(rewards)#, selected_index

    @eqx.filter_jit
    def compute_rewards_one_episode(
        self, episode_obs, expert_obs
    ):

        za_1 = episode_obs
        za_2 = self.make_subs(za_1, self.sub_steps)
        x = jnp.concatenate([za_1, za_2], axis=1)

        ze_1 = expert_obs
        ze_2 = self.make_subs(ze_1, self.sub_steps)
        y = jnp.concatenate([ze_1, ze_2], axis=1)
        
        geom = pointcloud.PointCloud(x, y, epsilon=0.001)
        ot_prob = linear_problem.LinearProblem(geom)
        solver = sinkhorn.Sinkhorn(max_iterations=300, use_danskin=True)

        ot_sink = solver(ot_prob)
        transp_cost = jnp.sum(ot_sink.matrix * geom.cost_matrix, axis=1)
        rewards = -transp_cost * 100

        return rewards
        




In [8]:
expert = OTRewardsExpert(expert_trajectory)
rewards = expert.compute_rewards(gc_agent_dataset.dataset)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [43]:
from src.dataset import Dataset

class ExpRewardsScaler:
    def init(self, rewards: np.ndarray):
        self.min = np.quantile(np.abs(rewards).reshape(-1), 0.0)
        self.max = np.quantile(np.abs(rewards).reshape(-1), 0.95)

    def scale(self, rewards: np.ndarray):
        # From paper
        return 2 * np.exp(3 * rewards)


def get_subs(dataset: GCSDataset, add_steps: int):
    terminal_locs = dataset.terminal_locs
    indx = np.arange(dataset.dataset.dataset_dict['observations'].shape[0])
    final_state_indx = terminal_locs[np.searchsorted(terminal_locs, indx)] 
    way_indx = np.minimum(indx + add_steps, final_state_indx)
    subs = jax.tree_map(lambda arr: arr[way_indx], dataset.dataset.dataset_dict['observations'])
    return subs

sq_rewards = -jnp.sqrt(-rewards+0.0001)
scaler = ExpRewardsScaler()
#scaler.init(sq_rewards)
scaled_rewards = scaler.scale(rewards).astype(np.float32)

## Apply iql scaling
from utils.ds_builder import load_trajectories
    
offline_traj = load_trajectories(env.spec.id, scaled_rewards)
    
def compute_iql_reward_scale(trajs):
    """Rescale rewards based on max/min from the dataset.
    This is also used in the original IQL implementation.
    """
    trajs = trajs.copy()
    
    def compute_returns(tr):
        return sum([step[2] for step in tr])
    
    trajs.sort(key=compute_returns)
    reward_scale = 1000.0 / (
      compute_returns(trajs[-1]) - compute_returns(trajs[0]))
    return reward_scale
    

ds = gc_agent_dataset.dataset.dataset_dict
episode_starts, episode_ends, episode_returns = gc_agent_dataset.dataset._trajectory_boundaries_and_returns()
data_with_ot_rewards = Dataset(
    {'observations': np.concatenate([ds['observations'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],
    'next_observations': np.concatenate([ds['next_observations'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],
    'actions': np.concatenate([ds['actions'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],
    'rewards':scaled_rewards * compute_iql_reward_scale(offline_traj),
    'masks': 1.0 - np.concatenate([ds['dones'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],
    })#[scaled_rewards > r_min]})

load datafile:   0%|          | 0/9 [00:00<?, ?it/s]

load datafile: 100%|██████████| 9/9 [00:02<00:00,  3.99it/s]


  0%|          | 0/1998000 [00:00<?, ?it/s]

In [44]:
from src.agents.iql_flax.common import Batch
from src.agents.iql_flax.learner import Learner
from src.agents.iql_flax.evaluation import evaluate

config={'max_steps': 1_000_000, 'seed': 10, 'expectile':0.7, 'discount': 0.99, 'temperature': 3}

iql_agent_ot = Learner(
        config["seed"],
        env.observation_space.sample()[np.newaxis],
        env.action_space.sample()[np.newaxis],
        max_steps=config["max_steps"],
        expectile=config["expectile"],
        discount=config["discount"],
        temperature=config["temperature"])

pbar = tqdm(range(1_000_000 + 1))
# expert = OTRewardsExpert(expert_trajectory)

for i in pbar:
    sample = data_with_ot_rewards.sample(256)
    batch = Batch(
        observations=sample["observations"],
        next_observations=sample["next_observations"],
        actions = sample['actions'],
        rewards= sample["rewards"],
        masks= sample["masks"]
    )
    update_info = iql_agent_ot.update(batch)
    update_info['adv'] = None
    if i % 50_000 == 0 and i > 0:
        eval_stats = evaluate(iql_agent_ot, env, num_episodes=10)[0]
        print(eval_stats)
        eval_stats['return'] = env.get_normalized_score(eval_stats['return'])*100
        # wandb.log({f"Eval/{key}": value for key, value in eval_stats.items()})
        pbar.set_postfix(update_info)
    if i % 2000 == 0:
        # wandb.log({f"Training/{key}": value for key, value in update_info.items()})
        pbar.set_postfix(update_info)

  0%|          | 0/1000001 [00:00<?, ?it/s]

{'return': 4970.7735141067915, 'length': 1000.0}
{'return': 5246.734164696242, 'length': 1000.0}
{'return': 6973.7251702966005, 'length': 1000.0}
{'return': 8053.332608825057, 'length': 1000.0}
{'return': 5760.203850621122, 'length': 1000.0}
{'return': 6794.748980632254, 'length': 1000.0}
{'return': 6682.4577897749505, 'length': 1000.0}
{'return': 5404.03986462656, 'length': 1000.0}
{'return': 7782.59329851099, 'length': 1000.0}
{'return': 7187.214399819459, 'length': 1000.0}
{'return': 7103.503169974261, 'length': 1000.0}
{'return': 6518.621234749199, 'length': 1000.0}
{'return': 7400.920529688854, 'length': 1000.0}
{'return': 7193.68911741444, 'length': 1000.0}
{'return': 6669.739124283704, 'length': 1000.0}
{'return': 7310.7417039440725, 'length': 1000.0}
{'return': 8207.397912104232, 'length': 1000.0}
{'return': 7619.996945590482, 'length': 1000.0}
{'return': 7062.8304628408205, 'length': 1000.0}
{'return': 9379.701922264492, 'length': 1000.0}


In [14]:
eval_stats = evaluate(iql_agent_ot, env, num_episodes=50)[0]
eval_stats['return'] = env.get_normalized_score(eval_stats['return'])*100
print(eval_stats)

{'return': 92.64012068963244, 'length': 1000.0}


In [45]:
from gymnasium.utils import save_video

frames=[]
i = 0
num_episodes = 1
all_reward = []
key = jax.random.PRNGKey(42)

for i in range(num_episodes):
    episode_reward = 0
    key, sample_key = jax.random.split(key, 2)
    obs = env.reset()
    done = False
    while not done:
        key, sample_key = jax.random.split(sample_key, 2)
        action = jax.device_get(iql_agent_ot.sample_actions(obs, temperature=0.0))
        obs, reward, done ,_ = env.step(action)
        os.environ['CUDA_VISIBLE_DEVICES']='1'
        frames.append(env.render(mode='rgb_array'))
        os.environ['CUDA_VISIBLE_DEVICES']='0,1'
        episode_reward += reward
    all_reward.append(episode_reward)
    print(episode_reward)
save_video.save_video(frames, video_folder='.', fps=20)

4841.356928073171
Moviepy - Building video /home/nazar/projects/AILOT/rl-video-episode-0.mp4.
Moviepy - Writing video /home/nazar/projects/AILOT/rl-video-episode-0.mp4



                                                                

Moviepy - Done !
Moviepy - video ready /home/nazar/projects/AILOT/rl-video-episode-0.mp4


In [32]:
hopper_env, hopper_expert_ds, hopper_agent_ds, _, _ = setup_datasets(
    expert_env_name="hopper-medium-expert-v2",
    agent_env_name="hopper-medium-expert-v2", expert_num=1,
    normalize_agent_states=False)

hopper_gcsds_params = GCSDataset.get_default_config()
hopper_gc_expert_dataset = GCSDataset(hopper_expert_ds, **hopper_gcsds_params)
hopper_gc_agent_dataset = GCSDataset(hopper_agent_ds, **hopper_gcsds_params)

hopper_expert_trajectory = hopper_gc_expert_dataset.get_expert_traj()['observations']

load datafile: 100%|██████████| 9/9 [00:01<00:00,  6.36it/s]


  0%|          | 0/1998966 [00:00<?, ?it/s]

Expert returns [3753.886583685875], mean 3753.886583685875


load datafile: 100%|██████████| 9/9 [00:01<00:00,  6.35it/s]


Number of terminal states: 1
Number of terminal states: 3215


In [19]:
%cd /home/nazar/projects/AILOT
from src.agents import icvf
hopper_icvf_model = icvf.create_eqx_learner(seed=228,
                                     observations=hopper_expert_ds.dataset_dict['observations'][0],
                                     hidden_dims=[256, 256],
                                     pretrained_folder="hopper-medium-expert",
                                     load_pretrained_icvf=True)

/home/nazar/projects/AILOT
Extra kwargs: {}


In [27]:
@eqx.filter_jit
def crossenv_expert_projection(
    expert_obs, agent_obs, other_agent_obs, icvf_model_1, icvf_model_2 
):
    
    Ne = expert_obs.shape[0]
    cat_obs = jnp.concatenate([expert_obs, agent_obs], axis=0)
    z_1 = eval_ensemble_psi(icvf_model_1, cat_obs).mean(axis=0)
    z_2 = eval_ensemble_psi(icvf_model_2, other_agent_obs).mean(axis=0)

    geom = pointcloud.PointCloud(z_1, z_2, epsilon=0.001)
    ot_prob = linear_problem.LinearProblem(geom)
    solver = sinkhorn.Sinkhorn(max_iterations=500, use_danskin=True)

    ot_sink = solver(ot_prob)
    P = ot_sink.matrix[:Ne, :] 
    
    expert_proj = P @ z_2

    return expert_proj

In [37]:
cheetah_obs = gc_agent_dataset.dataset.dataset_dict["observations"][:30_000]
hopper_obs = hopper_gc_agent_dataset.dataset.dataset_dict["observations"][:30_000]

In [33]:
hopper_expert_trajectory.shape

(999, 11)

In [38]:
hopper_expert_projection = crossenv_expert_projection(
    hopper_expert_trajectory, 
    hopper_obs, 
    cheetah_obs, 
    hopper_icvf_model.value_learner.model, 
    icvf_model.value_learner.model)

In [39]:
hopper_expert_projection.shape

(999, 256)

In [42]:
expert = OTRewardsExpert(expert_z=hopper_expert_projection)
rewards = expert.compute_rewards(gc_agent_dataset.dataset)

  0%|          | 0/2000 [00:00<?, ?it/s]