In [1]:
import sys
sys.path.append("..")
from dataclasses import dataclass
from src.run_dida import get_env, ReplayBuffer, prepare_buffers_for_il, compute_sar, compute_reward_from_disc, mix_anchor_imitator
from gc_datasets.dataset import Dataset
import numpy as np

@dataclass
class CFG:
    """Environment name and data-file paths for the cross-domain IL experiment.

    Every attribute carries a type annotation so ``@dataclass`` turns it into a real
    field. ``CFG()`` keeps the defaults below, and single values can be overridden,
    e.g. ``CFG(env_name='Walker2d')``. Without the annotations these were plain class
    attributes and such keyword overrides raised ``TypeError``.
    """
    env_name: str = 'Hopper'
    # Expert demonstrations in the source domain.
    # PointUMaze alternative: /home/m_bobrin/CrossDomainIL/prep_data/pointumaze/expert_source/trained_expert.npy
    path_to_expert: str = "/home/m_bobrin/CrossDomainIL/prep_data/hopper/expert_source/trained_expert.npy"
    # Random-policy rollouts in the source domain.
    # PointUMaze alternative: /home/m_bobrin/CrossDomainIL/prep_data/pointumaze/rand_source/random_policy.npy
    path_to_random_expert: str = "/home/m_bobrin/CrossDomainIL/prep_data/hopper/rand_source/random_policy.npy"
    # NOTE(review): for Hopper this is the same rand_source file as path_to_random_expert,
    # while the PointUMaze alternative uses a separate rand_target directory — confirm intended.
    # PointUMaze alternative: /home/m_bobrin/CrossDomainIL/prep_data/pointumaze/rand_target/random_policy.npy
    path_to_random_target: str = "/home/m_bobrin/CrossDomainIL/prep_data/hopper/rand_source/random_policy.npy"

def _load_transitions(path, clip_to_eps: bool = True, eps: float = 1e-5) -> dict:
    """Load one transition dict stored with ``np.save`` and normalise it.

    Marks the last transition as terminal, optionally clips actions into
    ``[-(1 - eps), 1 - eps]``, and casts observations / next_observations to float32.
    """
    # NOTE(review): allow_pickle=True can execute arbitrary code on load — only use trusted files.
    transitions = np.load(path, allow_pickle=True).item()
    # Force an episode boundary at the very end of the stored data.
    transitions['dones'][-1] = 1
    if clip_to_eps:
        lim = 1 - eps
        transitions['actions'] = np.clip(transitions['actions'], -lim, lim)
    for key in ('observations', 'next_observations'):
        transitions[key] = transitions[key].astype(np.float32)
    return transitions


def _dataset_fields(transitions: dict) -> dict:
    """Keyword arguments shared by ``Dataset(...)`` and ``Dataset.add_data(...)`` (all but ``size``)."""
    return dict(observations=transitions['observations'],
                actions=transitions['actions'],
                rewards=transitions['rewards'],
                dones_float=transitions['dones'],
                masks=1.0 - transitions['dones'],
                next_observations=transitions['next_observations'])


def _to_dataset(transitions: dict):
    """Wrap a transition dict in a new ``Dataset`` sized by its number of observations."""
    return Dataset(**_dataset_fields(transitions), size=transitions['observations'].shape[0])


def prepare_buffers_for_il(cfg, custom_npy: bool=True,
                           clip_to_eps:bool = True, eps:float=1e-5):
    """Load the source-domain ``.npy`` transition dicts and wrap them as Datasets.

    This notebook-local definition shadows the ``prepare_buffers_for_il`` imported
    from ``src.run_dida``.

    Args:
        cfg: object exposing ``path_to_expert``, ``path_to_random_expert`` and
            ``path_to_random_target``. Each file must hold a dict (read via ``.item()``)
            with keys observations, next_observations, actions, rewards, dones.
        custom_npy: only the ``.npy`` loading path is implemented.
        clip_to_eps: clip actions into ``[-(1 - eps), 1 - eps]``.
        eps: margin used when ``clip_to_eps`` is set.

    Returns:
        ``(expert_source_ds, non_expert_source_dataset, combined_source_ds)``:
        the expert Dataset, the random-policy Dataset, and the return value of
        ``add_data`` called on a second expert Dataset with the random transitions.

    Raises:
        NotImplementedError: if ``custom_npy`` is False, instead of silently
            returning None and failing later at tuple unpacking.
    """
    if not custom_npy:
        raise NotImplementedError("prepare_buffers_for_il only supports custom_npy=True")

    expert_source = _load_transitions(cfg.path_to_expert, clip_to_eps, eps)
    expert_random = _load_transitions(cfg.path_to_random_expert, clip_to_eps, eps)
    # NOTE(review): the target-domain random data is loaded and preprocessed but not returned.
    # It is kept so a missing or broken target file still fails here. TODO: return it or drop it.
    _target_random = _load_transitions(cfg.path_to_random_target, clip_to_eps, eps)

    expert_source_ds = _to_dataset(expert_source)
    non_expert_source_dataset = _to_dataset(expert_random)
    # add_data is called on a fresh expert Dataset, never on expert_source_ds itself.
    combined_source_ds = _to_dataset(expert_source).add_data(**_dataset_fields(expert_random))

    return expert_source_ds, non_expert_source_dataset, combined_source_ds
    
cfg = CFG()
env, eval_env = get_env(cfg.env_name, eval_video_interval=10, num_last_eps_info=100)
# Calls the notebook-local prepare_buffers_for_il defined above (it shadows the src.run_dida import).
source_expert_ds, source_random_ds, combined_source_ds = prepare_buffers_for_il(cfg=cfg)


def _empty_buffer(reference_env, capacity=1_000_000):
    """Build an empty ReplayBuffer using reference_env's observation and action spaces."""
    return ReplayBuffer(
        observation_space=reference_env.observation_space,
        action_space=reference_env.action_space,
        capacity=capacity,
    )


# Imitator transitions: nothing is written into this buffer in this cell.
imitator_buffer = _empty_buffer(eval_env)

# Expert transitions, filled from the source expert Dataset and then passed through apply_noise().
noisy_expert_buffer = _empty_buffer(eval_env)
noisy_expert_buffer.initialize_with_dataset(source_expert_ds)
noisy_expert_buffer.apply_noise()

# Anchor buffer: initialised from the noisy expert buffer, then shuffled.
anchor_buffer = _empty_buffer(eval_env)
anchor_buffer.initialize_with_dataset(noisy_expert_buffer)
anchor_buffer.random_shuffle()

  logger.deprecation(


In [2]:
from agents.disc import Discriminator
from agents.dida_agent import DIDA
import jax.numpy as jnp
from jaxrl_m.networks import RelativeRepresentation
import flax.linen as nn
import jax

class Identity(nn.Module):
    """Flax module that returns its input unchanged.

    Passed to ``DIDA.create`` as the ``'target'`` encoder, so target-side inputs
    go through no learned transformation.
    """
    @nn.compact
    def __call__(self, x):
        # No parameters or submodules: the output is the input object itself.
        return x
    
rep_dim = 256
BATCH_SIZE = 2048  # transitions sampled from each buffer for this smoke test

# Positional hyper-parameters are passed through unchanged; their meaning is defined in agents.disc.
D_noisy = Discriminator.create(jnp.ones((rep_dim, )), 2e-5, 5, 10000, 1e-5, hidden_dims=[128, 1])
D_policy = Discriminator.create(jnp.ones((2 * rep_dim, )), 2e-5, 1, 10000, 1e-5, hidden_dims=[128, 1])

dida_agent = DIDA.create(noisy_discr=D_noisy, policy_discr=D_policy,
                         encoders={'source': RelativeRepresentation(ensemble=False),
                                   'target': Identity()},
                         observation_dim=eval_env.observation_space.sample())

# imitator_buffer has not been filled yet, so the anchor buffer stands in as the imitator data.
# The second value returned by `.sample` appears to be the sampled row indices. These are only
# meaningful for the buffer that produced them, so the indexed re-sample below must use the
# SAME buffer. Previously it indexed the empty imitator_buffer with anchor_buffer indices.
# TODO: point this at imitator_buffer once it is populated with policy rollouts.
imitator_source_buffer = anchor_buffer

noisy_expert_batch, _ = noisy_expert_buffer.sample(BATCH_SIZE)
anchor_batch, _ = anchor_buffer.sample(BATCH_SIZE)
imitator_batch, imitator_indx = imitator_source_buffer.sample(BATCH_SIZE)

dida_agent = dida_agent.update_noise_discr(imitator_batch, noisy_expert_batch, anchor_batch)
alpha = compute_sar(dida_agent, imitator_batch.observations)
n_mix = int(alpha.item() * BATCH_SIZE)
# TODO (undefined here: lmbd, i, cfg.max_steps):
# domain_weight = lmbd * (2 / (1 + np.exp(-10 * (i - 1000) / cfg.max_steps)) - 1)

# Score imitator samples with the noisy discriminator. A softmax over the batch is strictly
# monotonic, so ranking the raw scores gives the same order without float saturation ties.
noisy_scores = dida_agent.noisy_disc.state(
    dida_agent.encoder(imitator_batch.observations, method='encode_source'))
# With hidden_dims=[128, 1] the scores may be (BATCH_SIZE, 1). jnp.argsort sorts along axis=-1 by
# default, which would sort each length-1 row and return all zeros. Flatten to rank across the batch.
noisy_scores = jnp.ravel(noisy_scores)
assert noisy_scores.shape == (BATCH_SIZE,), noisy_scores.shape
das_chose = jnp.argsort(noisy_scores)[:n_mix]  # ascending: lowest-scoring samples first
das_chosen_imitator, _ = imitator_source_buffer.sample(indx=imitator_indx[np.asarray(das_chose)])
mix_data_batch = mix_anchor_imitator(n_mix, das_chosen_imitator, anchor_batch)

2024-07-23 16:57:48.720833: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [5]:
import plotly.graph_objects as go
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler
    
# Fixed seed so the t-SNE embedding, and therefore the figure, is reproducible across runs.
# `tnse` and `scaler` are reused by the later t-SNE cells.
tnse = TSNE(random_state=0)
scaler = MinMaxScaler()

# First 200 rows of the noisy expert buffer vs. the first 200 clean expert observations.
noisy_obs = noisy_expert_buffer.observations[:200]
clean_obs = source_expert_ds.observations[:200]
viz_ds = np.concatenate([noisy_obs, clean_obs], axis=0)
tnse_gaussian = scaler.fit_transform(tnse.fit_transform(viz_ds))

# Colour counts come from the actual slice lengths so labels always line up with points.
point_colors = ['red'] * len(noisy_obs) + ['blue'] * len(clean_obs)
fig = go.Figure()
fig.add_trace(
    go.Scatter(x=tnse_gaussian[:, 0], y=tnse_gaussian[:, 1], mode='markers', marker=dict(color=point_colors))
)
fig.update_layout(title='t-SNE (min-max scaled): noisy expert buffer (red) vs. clean expert dataset (blue)')
fig.show()

In [12]:
# NOTE(review): the saved output below (4877) cannot come from this expression. `// 44877` is 0
# unless there are at least 44877 rows, and 4877 would need ~219M rows. The cell was probably
# edited after it last ran — re-run and confirm the intended divisor.
len(source_expert_ds.observations) // 44877

4877

In [14]:
# Noisy-buffer rows 4877:5000 (start index taken from an earlier probe) vs. the first 200 clean rows.
noisy_obs = noisy_expert_buffer.observations[4877:5000]
clean_obs = source_expert_ds.observations[:200]
viz_ds = np.concatenate([noisy_obs, clean_obs], axis=0)
tnse_gaussian = scaler.fit_transform(tnse.fit_transform(viz_ds))

# Derived from the slices, not hard-coded (123), so a shorter buffer cannot mis-label points.
point_colors = ['red'] * len(noisy_obs) + ['blue'] * len(clean_obs)
fig = go.Figure()
fig.add_trace(
    go.Scatter(x=tnse_gaussian[:, 0], y=tnse_gaussian[:, 1], mode='markers', marker=dict(color=point_colors))
)
fig.update_layout(title='t-SNE (min-max scaled): noisy buffer rows 4877:5000 (red) vs. clean expert rows :200 (blue)')
fig.show()

In [25]:
# NOTE(review): 85 is the window length used to slice the noisy buffer; its meaning
# (episode length?) is not visible here — confirm.
segment_len = 85
noisy_obs = noisy_expert_buffer.observations[segment_len * 2:segment_len * 3]
clean_obs = source_expert_ds.observations
viz_ds = np.concatenate([noisy_obs, clean_obs], axis=0)
tnse_gaussian = scaler.fit_transform(tnse.fit_transform(viz_ds))

# Colour counts follow the actual slice lengths instead of repeating the hard-coded 85.
point_colors = ['red'] * len(noisy_obs) + ['blue'] * len(clean_obs)
fig = go.Figure()
fig.add_trace(
    go.Scatter(x=tnse_gaussian[:, 0], y=tnse_gaussian[:, 1], mode='markers', marker=dict(color=point_colors))
)
fig.update_layout(title='t-SNE (min-max scaled): noisy buffer segment 3 (red) vs. all clean expert observations (blue)')
fig.show()

In [26]:
segment_len = 85
# The open-ended slice [85*3:] has many more than 85 rows, so the hard-coded 85 red labels did not
# match the points. If the ReplayBuffer's observation array is preallocated to its capacity, an open
# slice would also include unfilled rows. Stop at the number of expert transitions loaded into it.
# NOTE(review): assumes initialize_with_dataset writes the dataset into rows [0, N) — confirm.
n_filled = len(source_expert_ds.observations)
noisy_obs = noisy_expert_buffer.observations[segment_len * 3:n_filled]
clean_obs = source_expert_ds.observations
viz_ds = np.concatenate([noisy_obs, clean_obs], axis=0)
tnse_gaussian = scaler.fit_transform(tnse.fit_transform(viz_ds))

point_colors = ['red'] * len(noisy_obs) + ['blue'] * len(clean_obs)
fig = go.Figure()
fig.add_trace(
    go.Scatter(x=tnse_gaussian[:, 0], y=tnse_gaussian[:, 1], mode='markers', marker=dict(color=point_colors))
)
fig.update_layout(title='t-SNE (min-max scaled): noisy buffer from row 255 on (red) vs. all clean expert observations (blue)')
fig.show()