In [None]:
import composuite
from diffusion.utils import *
from corl.algorithms.offline.td3_bc import *
from corl.shared.buffer import *
from corl.shared.logger import *

from diffusion.utils import *
# Explicit imports are kept after the wildcard imports so that these names
# cannot be shadowed by whatever the wildcard imports export.
import os
from collections import defaultdict
import composuite
import numpy as np
from sklearn.preprocessing import StandardScaler
import umap
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns

def transitions_dataset_with_timesteps(dataset):
    """
    Turn a flat, D4RL-style dataset into (s, a, s', r, done) transitions and tag
    every transition with its step index inside the episode.

    Adapted from D4RL's ``qlearning_dataset``:
    https://github.com/Farama-Foundation/D4RL/blob/89141a689b0353b0dac3da5cba60da4b1b16254d/d4rl/__init__.py#L69

    Args:
        dataset: mapping with array-likes under 'observations', 'actions',
            'rewards', 'terminals' and 'timeouts', all aligned on the first axis.
            The number of rows N is taken from ``dataset['rewards']``.

    Returns:
        dict of arrays with N - 1 rows each (row i pairs input row i with row i + 1):
            'observations', 'next_observations', 'actions', 'rewards': float32
            'terminals': bool, True where 'terminals' OR 'timeouts' is set
            'timesteps': int, 0 at the start of the data and right after every
                terminal row, otherwise previous value + 1

    Notes:
        Unlike the D4RL reference, rows that end an episode are kept. For those
        rows 'next_observations' is simply input row i + 1, i.e. it comes from
        whatever follows in the dataset rather than from the same episode.
    """
    N = dataset['rewards'].shape[0]
    # The last row has no successor, so it only ever appears as a next observation.
    n = max(N - 1, 0)

    observations = np.asarray(dataset['observations'])

    # Normalise both flags to bool before combining so the dtype of the result
    # does not depend on how the source file stored them.
    ended = np.asarray(dataset['terminals']).astype(bool).reshape(-1)[:n]
    timed_out = np.asarray(dataset['timeouts']).astype(bool).reshape(-1)[:n]
    done = ended | timed_out

    # Step counter: for each row find the index of the first row of its episode
    # (the row right after the most recent terminal row) and subtract it.
    idx = np.arange(n)
    episode_start = np.zeros(n, dtype=idx.dtype)
    if n > 1:
        episode_start[1:] = np.where(done[:-1], idx[1:], 0)
    timesteps = idx - np.maximum.accumulate(episode_start) if n else idx

    # astype() copies, so the returned arrays never alias the input or each other.
    return {
        'observations': observations[:n].astype(np.float32),
        'actions': np.asarray(dataset['actions'])[:n].astype(np.float32),
        'next_observations': observations[1:n + 1].astype(np.float32),
        'rewards': np.asarray(dataset['rewards'])[:n].astype(np.float32),
        'terminals': done,
        'timesteps': timesteps,
    }

In [None]:
# --- Task selection -----------------------------------------------------------
# The task is identified by (robot, obj, obst, subtask). Other combinations that
# were previously inspected with this notebook:
#   Kinova3 / Hollowbox / None       / Trashcan
#   Jaco    / Plate     / GoalWall   / Shelf
#   Panda   / Hollowbox / ObjectDoor / Trashcan
robot = 'IIWA'
obj = 'Dumbbell'
obst = 'ObjectDoor'
subtask = 'Trashcan'

# Environment built with task-id observations enabled; only its `modality_dims`
# is used here, as the layout handed to remove_indicator_vectors below.
representative_indicators_env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=True, ignore_done=False)
modality_dims = representative_indicators_env.modality_dims

# --- Data locations -----------------------------------------------------------
# The defaults are the original local paths. Set the environment variables to
# point at your own copies instead of editing this cell.
base_agent_data_path = os.environ.get(
    'COMPOSUITE_AGENT_DATA_PATH',
    '/Users/shubhankar/Developer/compositional-rl-synth-data/data',
)
base_synthetic_data_path = os.environ.get(
    'COMPOSUITE_SYNTHETIC_DATA_PATH',
    '/Users/shubhankar/Developer/compositional-rl-synth-data/cluster_results/diffusion/cond_diff_20/train/',
)

# --- Agent (expert) data ------------------------------------------------------
dataset = load_single_composuite_dataset(base_path=base_agent_data_path,
                                         dataset_type='expert',
                                         robot=robot, obj=obj,
                                         obst=obst, task=subtask)
agent_dataset = transitions_dataset_with_timesteps(dataset)
agent_dataset, _ = remove_indicator_vectors(modality_dims, agent_dataset)
agent_obs = agent_dataset['observations']
agent_actions = agent_dataset['actions']
agent_next_obs = agent_dataset['next_observations']
agent_rewards = agent_dataset['rewards']
agent_terminals = agent_dataset['terminals']
agent_timesteps = agent_dataset['timesteps']
# agent_dataset = make_inputs(agent_dataset)
# From here on `agent_dataset` is the next-observation array (not the dict);
# that array is what the later cells sample, scale and embed.
agent_dataset = agent_next_obs

# --- Synthetic data -----------------------------------------------------------
synthetic_dataset = load_single_synthetic_dataset(base_path=base_synthetic_data_path,
                                                  robot=robot, obj=obj,
                                                  obst=obst, task=subtask)
synthetic_obs = synthetic_dataset['observations']
synthetic_actions = synthetic_dataset['actions']
synthetic_next_obs = synthetic_dataset['next_observations']
synthetic_rewards = synthetic_dataset['rewards']
synthetic_terminals = synthetic_dataset['terminals']
# synthetic_dataset = make_inputs(synthetic_dataset)
# Same rebinding as for the agent data: keep only the next-observation array.
synthetic_dataset = synthetic_next_obs

print(agent_dataset.shape, synthetic_dataset.shape)

# The scaler and UMAP model are fitted on agent rows and then applied to the
# synthetic rows, so both arrays must have the same per-row shape.
assert tuple(agent_dataset.shape[1:]) == tuple(synthetic_dataset.shape[1:]), (
    f"Feature shape mismatch: agent {tuple(agent_dataset.shape[1:])} "
    f"vs synthetic {tuple(synthetic_dataset.shape[1:])}"
)

In [None]:
# Action component that the per-row argmax is compared against.
# NOTE(review): presumably the gripper command is action index 7 — confirm
# against the environment's action layout.
GRIPPER_ACTION_INDEX = 7

# Boolean mask per transition: True where the largest action component is the
# one at GRIPPER_ACTION_INDEX.
agent_argmax = np.argmax(agent_actions, axis=1)
synthetic_argmax = np.argmax(synthetic_actions, axis=1)
agent_gripper = agent_argmax == GRIPPER_ACTION_INDEX  # gripper action
synthetic_gripper = synthetic_argmax == GRIPPER_ACTION_INDEX
print(agent_gripper.shape, synthetic_gripper.shape)

In [None]:
# Build the selected task's environment with task-id observations disabled and
# print its `modality_dims` for reference.
env = composuite.make(
    robot, obj, obst, subtask,
    use_task_id_obs=False,
    ignore_done=False,
)
print(env.modality_dims)

In [None]:
# Number of rows drawn from each dataset and the seed for drawing them. A fixed
# seed makes the subsample (and therefore the figure) reproducible.
N_SAMPLES = 2500
SAMPLE_SEED = 0
rng = np.random.default_rng(SAMPLE_SEED)

# Sampling without replacement cannot return more rows than exist, so clamp the
# sample size to the dataset size instead of raising on small datasets.
agent_indices = rng.choice(agent_dataset.shape[0],
                           size=min(N_SAMPLES, agent_dataset.shape[0]),
                           replace=False)
sampled_agent_data = agent_dataset[agent_indices]
sampled_agent_gripper = agent_gripper[agent_indices]
sampled_agent_timesteps = agent_timesteps[agent_indices]

synthetic_indices = rng.choice(synthetic_dataset.shape[0],
                               size=min(N_SAMPLES, synthetic_dataset.shape[0]),
                               replace=False)
sampled_synthetic_data = synthetic_dataset[synthetic_indices]
sampled_synthetic_gripper = synthetic_gripper[synthetic_indices]

# Kept for backward compatibility: this name used to end up holding the
# synthetic indices after the cell ran.
random_indices = synthetic_indices

print(sampled_agent_data.shape, sampled_synthetic_data.shape)
print(sampled_agent_gripper.shape, sampled_synthetic_gripper.shape)

In [None]:
# Standardise features using statistics estimated on the agent sample only, then
# apply that same fitted transform to the synthetic sample.
scaler = StandardScaler()
scaled_agent_data = scaler.fit_transform(sampled_agent_data)
scaled_synthetic_data = scaler.transform(sampled_synthetic_data)

In [None]:
# Fixed seed so the 2-D embedding (and the figure saved from it) is the same on
# every run. Note: umap-learn gives up multi-threading when random_state is set.
UMAP_RANDOM_STATE = 42

# Fit the embedding on the agent data only, then project both datasets with the
# same fitted model so they share one coordinate system.
umap_model = umap.UMAP(n_components=2, random_state=UMAP_RANDOM_STATE).fit(scaled_agent_data)
embedded_agent_data = umap_model.transform(scaled_agent_data)
embedded_synthetic_data = umap_model.transform(scaled_synthetic_data)

In [None]:
def _finish_panel(ax, embedding, gripper_mask, title):
    """Overlay the gripper-flagged points on `ax`, then add title, axis labels and legend."""
    ax.scatter(embedding[gripper_mask, 0],
               embedding[gripper_mask, 1],
               c='#ff7f0e', marker='x', s=50, label='Gripper Closed')
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Dimension 1', fontsize=14)
    ax.set_ylabel('Dimension 2', fontsize=14)
    ax.legend()


fig, axes = plt.subplots(1, 2, figsize=(16, 6))
agent_ax, synthetic_ax = axes

# Left panel: agent embedding coloured by the per-transition timestep.
scatter = agent_ax.scatter(embedded_agent_data[:, 0], embedded_agent_data[:, 1],
                           c=sampled_agent_timesteps, cmap='viridis', alpha=0.8)
fig.colorbar(scatter, ax=agent_ax, label="Timestep")
_finish_panel(agent_ax, embedded_agent_data, sampled_agent_gripper,
              'Fit-Transformed Agent Data')

# Right panel: synthetic embedding in a single colour.
synthetic_ax.scatter(embedded_synthetic_data[:, 0], embedded_synthetic_data[:, 1],
                     c='#1f77b4', alpha=0.6, edgecolor='k', s=20)
_finish_panel(synthetic_ax, embedded_synthetic_data, sampled_synthetic_gripper,
              'Transformed Synthetic Data')

fig.suptitle(f"Train Task: {robot}_{obj}_{obst}_{subtask}", fontsize=16)
plt.tight_layout()

# Save before show() so the written file contains the rendered figure.
filename = f"umap_{robot}_{obj}_{obst}_{subtask}.png"
plt.savefig(filename, dpi=300, bbox_inches='tight')

plt.show()