In [None]:
import composuite
from diffusion.utils import *
from corl.algorithms.offline.td3_bc import *
from corl.shared.buffer import *
from corl.shared.logger import *

from diffusion.utils import *
from collections import defaultdict
import composuite
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns


def identify_special_dimensions(data):
    """Find the columns of a 2-D array whose values look discrete or fixed.

    A column is an "integer" dimension when every entry equals its own rounded
    value. A column is a "constant" dimension when every entry equals the
    column's first entry AND it is not already an integer dimension, so an
    integer-valued constant column is reported only as integer.

    Args:
        data: array indexed as ``data[row, column]``.

    Returns:
        ``(integer_dims, constant_dims)``: two lists of column indices, each in
        ascending order.
    """
    # Touch shape[1] first so non-2-D input fails exactly as before.
    column_ids = np.arange(data.shape[1])

    # Whole-matrix comparisons reduced per column instead of a Python loop.
    is_integer = np.all(data == np.round(data), axis=0)
    # data[:1] (rather than data[0]) stays valid for a 0-row array; in that
    # case every column is vacuously integer, so nothing is marked constant.
    is_constant = np.all(data == data[:1], axis=0) & ~is_integer

    return column_ids[is_integer].tolist(), column_ids[is_constant].tolist()


def process_special_dimensions(synthetic_dataset, integer_dims, constant_dims, decimals=1):
    """Round selected observation columns of a synthetic dataset.

    Every value in ``synthetic_dataset`` is copied (via ``.copy()``), so the
    input dict and its arrays are left untouched. Then, in both
    ``'observations'`` and ``'next_observations'``:

    * columns in ``integer_dims`` are rounded to whole numbers;
    * columns in ``constant_dims`` are rounded to ``decimals`` decimal places.

    Args:
        synthetic_dataset: dict of arrays. It must contain ``'observations'``
            and ``'next_observations'``, indexed as ``[row, column]``.
            Presumably NumPy arrays; other values only need ``.copy()``.
        integer_dims: column indices to round to integers (list or 1-D array;
            may be empty).
        constant_dims: column indices to round to ``decimals`` places (list or
            1-D array; may be empty).
        decimals: number of decimal places for ``constant_dims``. Defaults to
            1, the value this function has always used.

    Returns:
        A new dict with the same keys, holding the processed copies.

    Raises:
        KeyError: if either observation key is missing.
    """
    processed_dataset = {k: v.copy() for k, v in synthetic_dataset.items()}

    for key in ('observations', 'next_observations'):
        # len() instead of truthiness so NumPy index arrays are accepted too
        # (``if array:`` raises for arrays with more than one element).
        if len(integer_dims) > 0:
            processed_dataset[key][:, integer_dims] = np.round(
                synthetic_dataset[key][:, integer_dims]
            )

        if len(constant_dims) > 0:
            processed_dataset[key][:, constant_dims] = np.round(
                synthetic_dataset[key][:, constant_dims],
                decimals=decimals,
            )

    return processed_dataset

In [None]:
import os

# Root of the compositional-rl-synth-data checkout. Set COMPOSITIONAL_RL_ROOT to
# run this notebook on another machine; the fallback is the original author's
# local checkout, so behaviour is unchanged when the variable is unset.
REPO_ROOT = os.environ.get(
    'COMPOSITIONAL_RL_ROOT',
    '/Users/shubhankar/Developer/compositional-rl-synth-data',
)

# base_path handed to load_single_composuite_dataset for the agent data.
base_agent_data_path = os.path.join(REPO_ROOT, 'data')
# base_path handed to load_single_synthetic_dataset for the diffusion samples.
base_synthetic_data_path = os.path.join(
    REPO_ROOT, 'cluster_results', 'diffusion', 'cond_diff_25', 'test'
)

In [None]:
# 208 tasks
# Pick one CompoSuite task and load both the expert agent transitions and the
# synthetic transitions generated for that same task.
robot, obj, obst, subtask = 'Panda', 'Box', 'ObjectDoor', 'Shelf'

# The env built with task-ID observations provides the modality layout used to
# strip indicator vectors from the agent data; the env without them provides
# the per-modality sizes used by the plotting cells.
representative_indicators_env = composuite.make(
    robot, obj, obst, subtask, use_task_id_obs=True, ignore_done=False
)
representative_env = composuite.make(
    robot, obj, obst, subtask, use_task_id_obs=False, ignore_done=False
)
modality_dims = representative_env.modality_dims

# Transition fields unpacked (in this order) from both datasets below.
transition_keys = ('observations', 'actions', 'next_observations', 'rewards', 'terminals')

# --- Agent (expert) data ---
dataset = load_single_composuite_dataset(
    base_path=base_agent_data_path,
    dataset_type='expert',
    robot=robot, obj=obj, obst=obst, task=subtask,
)
agent_dataset, _ = remove_indicator_vectors(
    representative_indicators_env.modality_dims, transitions_dataset(dataset)
)
agent_obs, agent_actions, agent_next_obs, agent_rewards, agent_terminals = (
    agent_dataset[key] for key in transition_keys
)
# Integer-valued / constant observation columns of the agent data; the same
# columns are rounded in the synthetic data below.
integer_dims, constant_dims = identify_special_dimensions(agent_obs)
agent_dataset = make_inputs(agent_dataset)

# --- Synthetic data ---
synthetic_dataset = process_special_dimensions(
    load_single_synthetic_dataset(
        base_path=base_synthetic_data_path,
        robot=robot, obj=obj, obst=obst, task=subtask,
    ),
    integer_dims,
    constant_dims,
)
synthetic_obs, synthetic_actions, synthetic_next_obs, synthetic_rewards, synthetic_terminals = (
    synthetic_dataset[key] for key in transition_keys
)
synthetic_dataset = make_inputs(synthetic_dataset)

print(agent_dataset.shape, synthetic_dataset.shape)

In [None]:
# Overlay agent vs. synthetic histograms for every observation dimension,
# one figure per observation modality.
dataset1 = agent_obs
dataset2 = synthetic_obs

n_bins = 30
n_cols = 3

# Map each flat observation column to a readable name and to its modality.
# Modalities occupy consecutive column ranges in modality_dims iteration order.
dim_names = []
dim_to_modality = {}
modality_columns = {}
start_idx = 0

for modality, (size,) in modality_dims.items():
    modality_columns[modality] = list(range(start_idx, start_idx + size))
    for i in range(size):
        dim_names.append(f"{modality} {i}")
        dim_to_modality[start_idx + i] = modality
    start_idx += size

for modality, modality_indices in modality_columns.items():
    if not modality_indices:
        continue  # plt.subplots cannot build a grid with 0 rows
    n_rows = (len(modality_indices) + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
    axes = axes.ravel()

    for i, (ax, dim_idx) in enumerate(zip(axes, modality_indices)):
        agent_col = dataset1[:, dim_idx]
        synth_col = dataset2[:, dim_idx]
        # One set of edges for both series. Letting each hist() choose its own
        # 30 bins gives the two overlays different bin widths, which makes
        # their density bars not directly comparable.
        bins = np.histogram_bin_edges(np.concatenate([agent_col, synth_col]), bins=n_bins)

        ax.hist(agent_col, bins=bins, alpha=0.5, label='Agent', color='blue', density=True)
        ax.hist(synth_col, bins=bins, alpha=0.5, label='Synthetic', color='orange', density=True)
        ax.set_title(f"{dim_names[dim_idx]}", fontsize=10)
        ax.grid(True, alpha=0.3)

        if i == 0:
            ax.legend()

    # The last grid row may be only partly used; drop the empty axes.
    for ax in axes[len(modality_indices):]:
        ax.remove()

    fig.suptitle(f"{modality} distributions", fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

In [None]:
# Overlay agent vs. synthetic histograms for every action dimension.
dataset1 = agent_actions
dataset2 = synthetic_actions

num_action_dims = dataset1.shape[1]
assert num_action_dims > 0, "no action dimensions to plot"

n_bins = 30
n_cols = 3
n_rows = (num_action_dims + n_cols - 1) // n_cols

fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
axes = axes.ravel()

for i, ax in enumerate(axes[:num_action_dims]):
    agent_col = dataset1[:, i]
    synth_col = dataset2[:, i]
    # Shared bin edges so the two overlaid densities use identical bin widths
    # (separate bins=30 calls would each pick edges from their own range).
    bins = np.histogram_bin_edges(np.concatenate([agent_col, synth_col]), bins=n_bins)

    ax.hist(agent_col, bins=bins, alpha=0.5, label='Agent', color='blue', density=True)
    ax.hist(synth_col, bins=bins, alpha=0.5, label='Synthetic', color='orange', density=True)
    ax.set_title(f"Action {i}", fontsize=10)
    ax.grid(True, alpha=0.3)

    if i == 0:
        ax.legend()

# The last grid row may be only partly used; drop the empty axes.
for ax in axes[num_action_dims:]:
    ax.remove()

fig.suptitle("Action Distributions", fontsize=16)
fig.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()