In [2]:
import warnings
warnings.filterwarnings("ignore")
import sys
sys.path.append("..")

import os
os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'
os.environ['CUDA_VISIBLE_DEVICES']='0,1'

import gym
import d4rl

import numpy as np
import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')

from matplotlib.colors import LinearSegmentedColormap, ListedColormap
from matplotlib import patches

import equinox as eqx
import jax
import jax.numpy as jnp
import functools

from tqdm.auto import tqdm
from jaxrl_m.common import TrainStateEQX
from src.agents.iql_equinox import GaussianPolicy, GaussianIntentPolicy

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from ott.tools import plot, sinkhorn_divergence
from ott.solvers.linear import implicit_differentiation as imp_diff

import optax

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))
def eval_ensemble_psi(ensemble, s):
    return eqx.filter_vmap(ensemble.psi_net)(s)

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))
def eval_ensemble_phi(ensemble, s):
    return eqx.filter_vmap(ensemble.phi_net)(s)

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None))
def eval_ensemble_icvf_viz(ensemble, s, g, z):
    return eqx.filter_vmap(ensemble.classic_icvf_initial)(s, g, z)

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None)) # V(s, g, z), g - dim 29, z - dim 256
def eval_ensemble_icvf_latent_z(ensemble, s, g, z):
    return eqx.filter_vmap(ensemble.classic_icvf)(s, g, z)

@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None)) # V(s, g ,z ), g, z - dim 256
def eval_ensemble_icvf_latent_zz(ensemble, s, g, z):
    return eqx.filter_vmap(ensemble.icvf_zz)(s, g, z)
    
@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None))
def eval_ensemble_icvf_latent_zzz(ensemble, s, g, z):
    return eqx.filter_vmap(ensemble.icvf_zzz)(s, g, z)

@eqx.filter_vmap(in_axes=dict(ensemble=None, s=0, z=None))
def eqx_get_state_traj(ensemble, s, z):
    '''
    Function to compute pairwise distance between two trajectories
    '''
    s = jnp.tile(s, (z.shape[0], 1))
    return eval_ensemble_icvf_latent_zzz(icvf_model.value_learner.model, s, z, z)

@eqx.filter_jit
def get_gcvalue(agent, s, g, z):
    v_sgz_1, v_sgz_2 = eval_ensemble_icvf_viz(agent.value_learner.model, s, g, z)
    return (v_sgz_1 + v_sgz_2) / 2

def get_v_gz(agent, initial_state, target_goal, observations):
    initial_state = jnp.tile(initial_state, (observations.shape[0], 1))
    target_goal = jnp.tile(target_goal, (observations.shape[0], 1))
    return -1 * get_gcvalue(agent, initial_state, observations, target_goal)
    
def get_v_zz(agent, goal, observations):
    goal = jnp.tile(goal, (observations.shape[0], 1))
    return get_gcvalue(agent, observations, goal, goal)

@eqx.filter_vmap(in_axes=dict(agent=None, obs=None, goal=0))
def get_v_zz_heatmap(agent, obs, goal): # goal - traj
    goal = jnp.tile(goal, (obs.shape[0], 1))
    return get_gcvalue(agent, obs, goal, goal)

%matplotlib inline
%load_ext autoreload
%autoreload 2

pybullet build time: May 20 2022 19:45:31


In [3]:
from src.gc_dataset import GCSDataset
from utils.ds_builder import setup_datasets

class EnvData:
    def __init__(self, env, agent_dataset, expert_dataset, expert_trajectory) -> None:
        self.env = env
        self.agent_dataset = agent_dataset
        self.expert_dataset = expert_dataset
        self.expert_trajectory = expert_trajectory

    @eqx.filter_jit
    def expert_intents(self, icvf_value_model):
        return eval_ensemble_psi(icvf_value_model, self.expert_trajectory).mean(axis=0)

def load_env_data(expert_env_name, agent_env_name, expert_num):

    env, expert_ds, agent_ds, agent_mean_states, agent_std_states = setup_datasets(
        expert_env_name=expert_env_name,
        agent_env_name=agent_env_name, 
        expert_num=expert_num,
        normalize_agent_states=False)

    gcsds_params = GCSDataset.get_default_config()
    gc_expert_dataset = GCSDataset(expert_ds, **gcsds_params)
    gc_agent_dataset = GCSDataset(agent_ds, **gcsds_params)

    expert_trajectory = gc_expert_dataset.get_expert_traj()['observations']

    return EnvData(env, gc_agent_dataset.dataset, gc_expert_dataset.dataset, expert_trajectory)

halfcheetah_data = load_env_data("halfcheetah-expert-v2", "halfcheetah-medium-v2", expert_num=1)
hopper_data = load_env_data("hopper-expert-v2", "hopper-medium-v2", expert_num=1)

load datafile: 100%|██████████| 21/21 [00:01<00:00, 17.96it/s]


  0%|          | 0/999000 [00:00<?, ?it/s]

Expert returns [11239.283746674657], mean 11239.283746674657


load datafile: 100%|██████████| 21/21 [00:01<00:00, 18.50it/s]


Number of terminal states: 1
Number of terminal states: 1001


load datafile: 100%|██████████| 21/21 [00:00<00:00, 29.22it/s]


  0%|          | 0/999061 [00:00<?, ?it/s]

Expert returns [3753.886583685875], mean 3753.886583685875


load datafile: 100%|██████████| 21/21 [00:00<00:00, 29.63it/s]


Number of terminal states: 1
Number of terminal states: 2188


In [4]:
%cd /home/nazar/projects/AILOT
from src.agents import icvf
halfcheetah_icvf_model = icvf.create_eqx_learner(seed=228,
                                     observations=halfcheetah_data.expert_dataset.dataset_dict['observations'][0],
                                     hidden_dims=[256, 256],
                                     pretrained_folder="halfcheetah-medium",
                                     load_pretrained_icvf=True)

hopper_icvf_model = icvf.create_eqx_learner(seed=228,
                                     observations=hopper_data.expert_dataset.dataset_dict['observations'][0],
                                     hidden_dims=[256, 256],
                                     pretrained_folder="hopper-medium",
                                     load_pretrained_icvf=True)

/home/nazar/projects/AILOT
Extra kwargs: {}
Extra kwargs: {}


In [5]:
from abc import ABC, abstractmethod
from jax.numpy import ndarray
from ott.geometry import costs
from ott.math import utils as mu


class OTRewardsExpert:

    def __init__(
        self, icvf_model, expert_z
    ):
        self.icvf_model = icvf_model
        self.expert_z = expert_z
        self.sub_steps = 1

    def make_subs(self, z, sub_steps):
        sub_indx = jnp.minimum(jnp.arange(0, z.shape[0]) + sub_steps, z.shape[0] - 1)
        return jax.tree_map(lambda arr: arr[sub_indx], z)
    

    @eqx.filter_jit
    def get_z_and_start_index(self, obs):
        # obs - trajectory
        z = eval_ensemble_psi(self.icvf_model.value_learner.model, obs).mean(axis=0)
        diff = jnp.linalg.norm(z[0][jnp.newaxis,] - self.expert_z, axis=-1) 
        i_min = jnp.argmin(diff)
        return z, i_min, diff

    def compute_rewards(
        self,
        dataset
    ):
        i0 = 0
        rewards = []
        observations = dataset.dataset_dict['observations']
        episode_starts, episode_ends, _ = dataset._trajectory_boundaries_and_returns()
        
        for i1 in tqdm(range(len(episode_starts))):
            zi, start_index, diff = self.get_z_and_start_index(observations[episode_starts[i1]:episode_ends[i1]])
            ri = self.compute_rewards_one_episode(zi, self.expert_z[start_index:])
            rewards.append(jax.device_get(ri))
                  
        return np.concatenate(rewards)

    @eqx.filter_jit
    def compute_rewards_one_episode(
        self, episode_obs, expert_obs
    ):

        za_1 = episode_obs
        za_2 = self.make_subs(za_1, self.sub_steps)
        x = jnp.concatenate([za_1, za_2], axis=1)

        ze_1 = expert_obs
        ze_2 = self.make_subs(ze_1, self.sub_steps)
        y = jnp.concatenate([ze_1, ze_2], axis=1)
        
        geom = pointcloud.PointCloud(x, y, epsilon=0.001)
        ot_prob = linear_problem.LinearProblem(geom)
        solver = sinkhorn.Sinkhorn(max_iterations=500, use_danskin=True)

        ot_sink = solver(ot_prob)
        transp_cost = jnp.sum(ot_sink.matrix * geom.cost_matrix, axis=1)
        rewards = -transp_cost * 100

        return rewards
        


In [6]:
from src.dataset import Dataset

class ExpRewardsScaler:

    def __init__(self, a=2.0, b=1.0):
        self.a = a
        self.b = b

    def scale(self, rewards: np.ndarray):
        # From paper
        return self.a * np.exp(rewards * self.b)


In [7]:

def compute_iql_reward_scale(dataset, rewards):

    episode_starts, episode_ends, _ = dataset._trajectory_boundaries_and_returns()
    episode_lens = [i2 - i1 for i1, i2 in zip(episode_starts, episode_ends)]
    episode_sep = np.cumsum([0] + episode_lens)
    
    r_sum = [rewards[i1:i2].sum() for i1, i2 in zip(episode_sep[:-1],  episode_sep[1:])]
    reward_scale = 1000.0 / (max(r_sum) - min(r_sum))
    return reward_scale

def make_train_data(env_data: EnvData, rewards, scaler: ExpRewardsScaler):

    scaled_rewards = scaler.scale(rewards).astype(np.float32)
    print(rewards.min(), rewards.mean(), rewards.max())
    print(scaled_rewards.min(), scaled_rewards.mean(), scaled_rewards.max())

    ds = env_data.agent_dataset.dataset_dict
    episode_starts, episode_ends, _ = env_data.agent_dataset._trajectory_boundaries_and_returns()

    data_with_ot_rewards = Dataset(
        {'observations': np.concatenate([ds['observations'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],
        'next_observations': np.concatenate([ds['next_observations'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],
        'actions': np.concatenate([ds['actions'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],
        'rewards':scaled_rewards * compute_iql_reward_scale(env_data.agent_dataset, scaled_rewards),
        'masks': 1.0 - np.concatenate([ds['dones'][episode_starts[i]:episode_ends[i]] for i, j in enumerate(episode_starts)]).astype(np.float32),#[scaled_rewards > r_min],
        })
    return data_with_ot_rewards

In [8]:
from src.agents.iql_flax.common import Batch
from src.agents.iql_flax.learner import Learner
from src.agents.iql_flax.evaluation import evaluate

def training(env_data: EnvData, data_with_ot_rewards: Dataset, 
             max_steps=1_000_000, seed=10, expectile=0.7, discount=0.99, temperature=3):

    iql_agent_ot = Learner(
            seed,
            env_data.env.observation_space.sample()[np.newaxis],
            env_data.env.action_space.sample()[np.newaxis],
            max_steps=max_steps,
            expectile=expectile,
            discount=discount,
            temperature=temperature)

    pbar = tqdm(range(max_steps + 1))

    for i in pbar:
        sample = data_with_ot_rewards.sample(256)
        batch = Batch(
            observations=sample["observations"],
            next_observations=sample["next_observations"],
            actions = sample['actions'],
            rewards= sample["rewards"],
            masks= sample["masks"]
        )
        update_info = iql_agent_ot.update(batch)
        update_info['adv'] = None

        if i % 50_000 == 0 and i > 0:
            eval_stats = evaluate(iql_agent_ot, env_data.env, num_episodes=10)[0]
            eval_stats['return'] = env_data.env.get_normalized_score(eval_stats['return'])*100
            print(eval_stats)
            pbar.set_postfix(update_info)
        if i % 2000 == 0:
            pbar.set_postfix(update_info)

    return iql_agent_ot 

In [9]:
from gymnasium.utils import save_video

def eval_agent(iql_agent_ot, env):

    eval_stats = evaluate(iql_agent_ot, env, num_episodes=50)[0]
    eval_stats['return'] = env.get_normalized_score(eval_stats['return'])*100
    print(eval_stats)

    frames=[]
    i = 0
    num_episodes = 1
    all_reward = []
    key = jax.random.PRNGKey(42)

    for i in range(num_episodes):
        episode_reward = 0
        key, sample_key = jax.random.split(key, 2)
        obs = env.reset()
        done = False
        k = 0
        while k < 2000:
            key, sample_key = jax.random.split(sample_key, 2)
            action = jax.device_get(iql_agent_ot.sample_actions(obs, temperature=0.0))
            obs, reward, done ,_ = env.step(action)
            os.environ['CUDA_VISIBLE_DEVICES']='1'
            frames.append(env.render(mode='rgb_array'))
            os.environ['CUDA_VISIBLE_DEVICES']='0,1'
            episode_reward += reward
            k+=1
        all_reward.append(episode_reward)
        print(episode_reward)
    save_video.save_video(frames, video_folder='.', fps=50)

In [10]:
@eqx.filter_jit
def crossenv_expert_projection(
    expert_obs, agent_obs, other_agent_obs, icvf_model_1, icvf_model_2 
):
    
    Ne = expert_obs.shape[0]
    cat_obs = jnp.concatenate([expert_obs, agent_obs], axis=0)
    z_1 = eval_ensemble_psi(icvf_model_1, cat_obs).mean(axis=0)
    z_2 = eval_ensemble_psi(icvf_model_2, other_agent_obs).mean(axis=0)

    geom = pointcloud.PointCloud(z_1, z_2, epsilon=0.001)
    ot_prob = linear_problem.LinearProblem(geom)
    solver = sinkhorn.Sinkhorn(max_iterations=500, use_danskin=True)

    ot_sink = solver(ot_prob)
    P = ot_sink.matrix[:Ne, :] 
    # P = P ** 2
    P = P / P.sum(1, keepdims=True)
    
    expert_proj = P @ z_2

    return P, expert_proj

expert: halfcheetah, agent: hopper

In [11]:
_, ch_z = crossenv_expert_projection(
    halfcheetah_data.expert_trajectory, 
    halfcheetah_data.agent_dataset.dataset_dict["observations"][1000:10000],
    hopper_data.agent_dataset.dataset_dict["observations"][:10000],
    halfcheetah_icvf_model.value_learner.model,
    hopper_icvf_model.value_learner.model
)

expert = OTRewardsExpert(hopper_icvf_model, expert_z=ch_z)
rewards = expert.compute_rewards(hopper_data.agent_dataset) 


  0%|          | 0/2187 [00:00<?, ?it/s]

In [13]:
scaler = ExpRewardsScaler(a=2, b=3)
train_data = make_train_data(hopper_data, rewards, scaler)
iql_agent_ot = training(hopper_data, train_data)
print("eval")
eval_agent(iql_agent_ot, hopper_data.env)

-178.75755 -13.217977 -0.011129665
0.0 0.014802219 1.9343246


  0%|          | 0/1000001 [00:00<?, ?it/s]

{'return': 38.82416872275892, 'length': 392.8}
{'return': 49.755367148302575, 'length': 505.7}
{'return': 46.706282574446696, 'length': 463.2}
{'return': 70.35090313633712, 'length': 717.0}
{'return': 60.658142989542185, 'length': 607.2}
{'return': 81.90848792791941, 'length': 815.8}
{'return': 71.75185532385508, 'length': 707.8}
{'return': 75.36070078332473, 'length': 751.5}
{'return': 65.40694374382271, 'length': 646.6}
{'return': 59.86725287751574, 'length': 592.1}
{'return': 61.54296942389733, 'length': 604.8}
{'return': 69.85132003600141, 'length': 688.2}
{'return': 59.981328877361726, 'length': 594.9}
{'return': 64.85076510196555, 'length': 640.6}
{'return': 63.28583581025741, 'length': 623.3}
{'return': 57.05454035963865, 'length': 564.4}
{'return': 60.55414480132802, 'length': 600.2}
{'return': 61.95514110484371, 'length': 613.1}
{'return': 70.81534563766094, 'length': 696.8}
{'return': 61.910937713339834, 'length': 610.8}
eval
{'return': 63.25658803436569, 'length': 624.84}
Fo

                                                                 

Moviepy - Done !
Moviepy - video ready /home/nazar/projects/AILOT/rl-video-episode-0.mp4




In [14]:
eval_agent(iql_agent_ot, hopper_data.env)

{'return': 64.29731543022103, 'length': 634.28}
3322.3044449708113
Moviepy - Building video /home/nazar/projects/AILOT/rl-video-episode-0.mp4.
Moviepy - Writing video /home/nazar/projects/AILOT/rl-video-episode-0.mp4



                                                                 

Moviepy - Done !
Moviepy - video ready /home/nazar/projects/AILOT/rl-video-episode-0.mp4


expert: hopper, agent: halfcheetah

In [166]:
_, hopper_z = crossenv_expert_projection(
    hopper_data.expert_trajectory, 
    hopper_data.agent_dataset.dataset_dict["observations"][1000:10000],
    halfcheetah_data.agent_dataset.dataset_dict["observations"][:10000],
    hopper_icvf_model.value_learner.model,
    halfcheetah_icvf_model.value_learner.model
)

expert = OTRewardsExpert(halfcheetah_icvf_model, expert_z=hopper_z)
rewards = expert.compute_rewards(halfcheetah_data.agent_dataset) 


  0%|          | 0/1000 [00:00<?, ?it/s]

In [169]:
scaler = ExpRewardsScaler(a=2, b=1)
train_data = make_train_data(halfcheetah_data, rewards, scaler)
iql_agent_ot = training(halfcheetah_data, train_data)
print("eval")
eval_agent(iql_agent_ot, halfcheetah_data.env)

-61.382805 -3.7067518 -0.11133348
4.39356e-27 0.14845993 1.7892808


  0%|          | 0/1000001 [00:00<?, ?it/s]

{'return': 37.34827159823951, 'length': 1000.0}
{'return': 40.48865091324573, 'length': 1000.0}
{'return': 41.7990364147947, 'length': 1000.0}
{'return': 41.543602254839406, 'length': 1000.0}
{'return': 42.651300590614674, 'length': 1000.0}
{'return': 41.61175193042423, 'length': 1000.0}
{'return': 43.4711161827794, 'length': 1000.0}
{'return': 42.20855050684148, 'length': 1000.0}
{'return': 39.29940227644949, 'length': 1000.0}
{'return': 42.41720829914906, 'length': 1000.0}
{'return': 43.93239964049964, 'length': 1000.0}
{'return': 39.323755298910555, 'length': 1000.0}
{'return': 42.69523902462256, 'length': 1000.0}
{'return': 42.62258536438095, 'length': 1000.0}
{'return': 42.952596216218026, 'length': 1000.0}
{'return': 43.2527679022389, 'length': 1000.0}
{'return': 43.63621280565453, 'length': 1000.0}
{'return': 43.68034333090674, 'length': 1000.0}
{'return': 43.047491575085346, 'length': 1000.0}
{'return': 42.76063363612712, 'length': 1000.0}
eval
{'return': 42.91562946183661, 'le

                                                                 

Moviepy - Done !
Moviepy - video ready /home/nazar/projects/AILOT/rl-video-episode-0.mp4


In [180]:
walker_data = load_env_data("walker2d-expert-v2", "walker2d-medium-v2", expert_num=1)

walker_icvf_model = icvf.create_eqx_learner(seed=228,
                                     observations=walker_data.expert_dataset.dataset_dict['observations'][0],
                                     hidden_dims=[256, 256],
                                     pretrained_folder="walker2d-medium",
                                     load_pretrained_icvf=True)

load datafile:   0%|          | 0/21 [00:00<?, ?it/s]

load datafile: 100%|██████████| 21/21 [00:01<00:00, 17.75it/s]


  0%|          | 0/999000 [00:00<?, ?it/s]

Expert returns [5006.127595229074], mean 5006.127595229074


load datafile: 100%|██████████| 21/21 [00:01<00:00, 17.68it/s]


Number of terminal states: 1
Number of terminal states: 1192
Extra kwargs: {}


In [1]:
_, hopper_z = crossenv_expert_projection(
    hopper_data.expert_trajectory, 
    hopper_data.agent_dataset.dataset_dict["observations"][1000:10000],
    walker_data.agent_dataset.dataset_dict["observations"][:10000],
    hopper_icvf_model.value_learner.model,
    walker_icvf_model.value_learner.model
)

# w_z = walker_data.expert_intents(walker_icvf_model.value_learner.model)

expert = OTRewardsExpert(walker_icvf_model, expert_z=hopper_z)
rewards = expert.compute_rewards(walker_data.agent_dataset) 

NameError: name 'crossenv_expert_projection' is not defined

In [188]:
scaler = ExpRewardsScaler(a=2, b=3)
train_data = make_train_data(walker_data, rewards, scaler)
iql_agent_ot = training(walker_data, train_data)
print("eval")
eval_agent(iql_agent_ot, walker_data.env)

-4977.6567 -6.78997 -0.047423262
0.0 0.02343731 1.7347745


  0%|          | 0/1000001 [00:00<?, ?it/s]

{'return': 61.57563718817133, 'length': 754.8}
{'return': 67.32047210715608, 'length': 820.3}
{'return': 70.70063222285066, 'length': 860.6}
{'return': 83.32679048477006, 'length': 1000.0}
{'return': 82.1758230550337, 'length': 977.6}
{'return': 76.11041875571046, 'length': 910.6}
{'return': 83.15975493791889, 'length': 971.0}
{'return': 78.42602295410624, 'length': 913.2}
{'return': 82.86533033002696, 'length': 988.4}
{'return': 79.04452918536423, 'length': 917.2}
{'return': 84.36580700590643, 'length': 1000.0}
{'return': 76.57034049739816, 'length': 880.6}
{'return': 81.02505961178433, 'length': 948.6}
{'return': 71.38968155521458, 'length': 845.0}
{'return': 78.95277800020418, 'length': 910.1}


KeyboardInterrupt: 