In [17]:
from diffusion.utils import *
from collections import defaultdict
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

import os  # imported explicitly so the configuration does not depend on the wildcard import

# Which CompoSuite dataset variant to use for the agent (real) data.
dataset_type = 'expert'

# Root of the compositional-rl-synth-data checkout. Set the environment variable
# COMPOSITIONAL_RL_SYNTH_DATA_ROOT to run the notebook on another machine; when
# it is unset, the original author's local path is used so behaviour is unchanged.
project_root = os.environ.get(
    'COMPOSITIONAL_RL_SYNTH_DATA_ROOT',
    '/Users/shubhankar/Developer/compositional-rl-synth-data',
)

# Directory passed as base_path when loading the agent datasets.
base_agent_data_path = os.path.join(project_root, 'data')
# Directory that contains one sub-directory per diffusion run.
base_synthetic_data_path = os.path.join(project_root, 'cluster_results', 'diffusion')

import numpy as np

def process_data(datasets):
    """Stack several datasets into one array and record each row's source.

    Args:
        datasets: iterable of array-likes that can be concatenated along
            axis 0 (they must agree on every other dimension).

    Returns:
        A pair ``(combined_data, which_dataset)``:
        ``combined_data`` is the concatenation of all datasets along axis 0;
        ``which_dataset`` is an array with one entry per row of
        ``combined_data`` holding the position of the dataset that row came
        from in ``datasets``.

    Raises:
        ValueError: propagated from ``np.concatenate`` when ``datasets`` is
            empty or the shapes are incompatible.
    """
    arrays = list(datasets)

    # One label per row: dataset k contributes len(dataset k) copies of k.
    source_labels = [
        position
        for position, array in enumerate(arrays)
        for _ in range(len(array))
    ]

    stacked = np.concatenate(arrays, axis=0)
    return stacked, np.array(source_labels)

In [None]:
# Build a single CompoSuite environment. In this notebook only its
# `modality_dims` attribute is read, by remove_indicator_vectors when the agent
# datasets are loaded; the same environment is reused for every task.
# NOTE(review): presumably modality_dims is identical across tasks - confirm.
representative_env = composuite.make('IIWA', 'Plate', 'ObjectWall', 'Push', use_task_id_obs=True, ignore_done=False)

In [19]:
# Name of the diffusion run to visualise: the synthetic data is read from
# <base_synthetic_data_path>/<run>/train and <base_synthetic_data_path>/<run>/test.
run = 'cond_diff_20'

In [None]:
agent_datasets_train = []
synthetic_datasets_train = []

# Seed NumPy's global RNG before the first random draw so that a fresh
# Restart & Run All reproduces the same subsamples in this and later cells
# (the t-SNE cell already pins random_state=42; the subsampling was not seeded).
np.random.seed(42)

# Maximum number of rows kept per task and per data source.
samples_per_task = 1000

train_dir = os.path.join(base_synthetic_data_path, run, 'train')

# os.listdir returns names in arbitrary order, so sort before slicing;
# otherwise the [:3] below could select different tasks from run to run.
train_tasks = sorted(
    task for task in os.listdir(train_dir)
    if not task.startswith('.')  # skip hidden dot-entries
)

# Only the first three train tasks are visualised.
train_tasks = train_tasks[:3]

for task in tqdm(train_tasks, desc='Loading train task data'):
    # Task directory names are expected to have exactly four '_'-separated
    # parts; any other count raises ValueError on unpacking.
    robot, obj, obst, subtask = task.split('_')

    # Agent data: load, convert to transitions, strip the indicator vectors
    # (the second return value is not needed), then build the input array.
    agent_dataset = load_single_composuite_dataset(base_path=base_agent_data_path,
                                                   dataset_type=dataset_type,
                                                   robot=robot, obj=obj,
                                                   obst=obst, task=subtask)
    agent_dataset = transitions_dataset(agent_dataset)
    agent_dataset, _ = remove_indicator_vectors(representative_env.modality_dims, agent_dataset)
    agent_dataset = make_inputs(agent_dataset)
    # replace=False raises ValueError when more rows are requested than exist,
    # so cap the request at the dataset size.
    n_keep = min(samples_per_task, agent_dataset.shape[0])
    random_indices = np.random.choice(agent_dataset.shape[0], n_keep, replace=False)
    agent_dataset = agent_dataset[random_indices]
    agent_datasets_train.append(agent_dataset)

    # Synthetic data for the same task, subsampled the same way.
    synthetic_dataset = load_single_synthetic_dataset(base_path=train_dir,
                                                      robot=robot, obj=obj,
                                                      obst=obst, task=subtask)
    synthetic_dataset = make_inputs(synthetic_dataset)
    n_keep = min(samples_per_task, synthetic_dataset.shape[0])
    random_indices = np.random.choice(synthetic_dataset.shape[0], n_keep, replace=False)
    synthetic_dataset = synthetic_dataset[random_indices]
    synthetic_datasets_train.append(synthetic_dataset)

In [None]:
# Sanity check: number of agent and synthetic train datasets that were loaded.
print(f"{len(agent_datasets_train)} {len(synthetic_datasets_train)}")

In [None]:
agent_datasets_test = []
synthetic_datasets_test = []

# Maximum number of rows kept per task and per data source.
samples_per_task = 1000

test_dir = os.path.join(base_synthetic_data_path, run, 'test')

# os.listdir returns names in arbitrary order, so sort before slicing;
# otherwise the [:1] below could select a different task from run to run.
test_tasks = sorted(
    task for task in os.listdir(test_dir)
    if not task.startswith('.')  # skip hidden dot-entries
)

# Only the first test task is visualised.
test_tasks = test_tasks[:1]

for task in tqdm(test_tasks, desc='Loading test task data'):
    # Task directory names are expected to have exactly four '_'-separated
    # parts; any other count raises ValueError on unpacking.
    robot, obj, obst, subtask = task.split('_')

    # Agent data: load, convert to transitions, strip the indicator vectors
    # (the second return value is not needed), then build the input array.
    agent_dataset = load_single_composuite_dataset(base_path=base_agent_data_path,
                                                   dataset_type=dataset_type,
                                                   robot=robot, obj=obj,
                                                   obst=obst, task=subtask)
    agent_dataset = transitions_dataset(agent_dataset)
    agent_dataset, _ = remove_indicator_vectors(representative_env.modality_dims, agent_dataset)
    agent_dataset = make_inputs(agent_dataset)
    # replace=False raises ValueError when more rows are requested than exist,
    # so cap the request at the dataset size.
    n_keep = min(samples_per_task, agent_dataset.shape[0])
    random_indices = np.random.choice(agent_dataset.shape[0], n_keep, replace=False)
    agent_dataset = agent_dataset[random_indices]
    agent_datasets_test.append(agent_dataset)

    # Synthetic data for the same task, subsampled the same way.
    synthetic_dataset = load_single_synthetic_dataset(base_path=test_dir,
                                                      robot=robot, obj=obj,
                                                      obst=obst, task=subtask)
    synthetic_dataset = make_inputs(synthetic_dataset)
    n_keep = min(samples_per_task, synthetic_dataset.shape[0])
    random_indices = np.random.choice(synthetic_dataset.shape[0], n_keep, replace=False)
    synthetic_dataset = synthetic_dataset[random_indices]
    synthetic_datasets_test.append(synthetic_dataset)

In [None]:
# Sanity check: number of agent and synthetic test datasets that were loaded.
print(f"{len(agent_datasets_test)} {len(synthetic_datasets_test)}")

In [24]:
# Train datasets first, then test datasets. A dataset's position in these lists
# becomes its id once process_data is applied to them.
agent_datasets = [*agent_datasets_train, *agent_datasets_test]
synthetic_datasets = [*synthetic_datasets_train, *synthetic_datasets_test]

In [None]:
# Sanity check: total number of agent and synthetic datasets (train + test).
print(f"{len(agent_datasets)} {len(synthetic_datasets)}")

In [26]:
# Stack the per-task arrays into one array per data source. The *_ids arrays
# hold, for every row, the position of its task in agent_datasets /
# synthetic_datasets (train tasks first, then test tasks).
flattened_agent_data, agent_data_ids = process_data(agent_datasets)
flattened_synthetic_data, synthetic_data_ids = process_data(synthetic_datasets)

In [None]:
# Sanity check: shapes of the stacked data and of the matching task-id arrays.
print(f"{flattened_agent_data.shape} {agent_data_ids.shape}")
print(f"{flattened_synthetic_data.shape} {synthetic_data_ids.shape}")

In [28]:
# Number of points drawn from each data source for the t-SNE plot.
n_tsne_samples = 1000

# np.random.choice(..., replace=False) raises ValueError when asked for more
# rows than exist, so cap each request at the number of rows available.
n_agent = min(n_tsne_samples, flattened_agent_data.shape[0])
random_indices = np.random.choice(flattened_agent_data.shape[0], n_agent, replace=False)
sampled_agent_data = flattened_agent_data[random_indices]
# The spelling "samples_" (not "sampled_") is kept: later cells read this name.
samples_agent_data_ids = agent_data_ids[random_indices]

# Independent draw for the synthetic data; the same indices select both the
# rows and their task ids so the two stay aligned.
n_synthetic = min(n_tsne_samples, flattened_synthetic_data.shape[0])
random_indices = np.random.choice(flattened_synthetic_data.shape[0], n_synthetic, replace=False)
sampled_synthetic_data = flattened_synthetic_data[random_indices]
sampled_synthetic_data_ids = synthetic_data_ids[random_indices]

In [29]:
# Origin label per sampled row: 0 marks CompoSuite (agent) data, 1 marks synthetic data.
agent_origin = np.zeros(len(samples_agent_data_ids), dtype=int)
synthetic_origin = np.ones(len(sampled_synthetic_data_ids), dtype=int)

# Agent rows first, synthetic rows second - the same order in all three arrays,
# so row i of combined_data is described by combined_ids[i] and combined_origins[i].
combined_data = np.concatenate([sampled_agent_data, sampled_synthetic_data])
combined_ids = np.concatenate([samples_agent_data_ids, sampled_synthetic_data_ids])
combined_origins = np.concatenate([agent_origin, synthetic_origin])

In [30]:
# Standardise the features using statistics fitted on the combined agent +
# synthetic sample, so both sources are scaled identically.
normalized_data = StandardScaler().fit_transform(combined_data)
# Embed all rows together into 2-D; random_state is fixed so that t-SNE itself
# is repeatable for a given input.
tsne = TSNE(n_components=2, random_state=42)
embeddings = tsne.fit_transform(normalized_data)

In [31]:
# Task names in the same order as the datasets were concatenated (train first,
# then test), so the integer ids from process_data index directly into this list.
tasks = [*train_tasks, *test_tasks]
id_to_name = dict(enumerate(tasks))

In [None]:
# Give every task id its own colour, evenly spaced along the viridis colormap.
cmap = matplotlib.colormaps['viridis']
unique_ids = list(id_to_name)
id_to_color = {
    task_id: cmap(rank / len(unique_ids))
    for rank, task_id in enumerate(unique_ids)
}

fig, ax = plt.subplots(figsize=(10, 8))

# Row masks for the two data sources (0 = CompoSuite, 1 = synthetic).
is_composuite = combined_origins == 0
is_synthetic = combined_origins == 1

# CompoSuite points: small, mostly opaque circles coloured by task.
scatter_composuite = ax.scatter(
    embeddings[is_composuite, 0],
    embeddings[is_composuite, 1],
    c=[id_to_color[task_id] for task_id in combined_ids[is_composuite]],
    alpha=0.8,
    marker='o',
    s=10,
)

# Synthetic points: large, faint crosses coloured by task.
scatter_synthetic = ax.scatter(
    embeddings[is_synthetic, 0],
    embeddings[is_synthetic, 1],
    c=[id_to_color[task_id] for task_id in combined_ids[is_synthetic]],
    alpha=0.25,
    marker='x',
    s=100,
)

# Legend 1: marker shape encodes the data origin.
origin_legend_elements = [
    Line2D([0], [0], marker='o', color='w', markeredgecolor='black', markersize=5, label='CompoSuite'),
    Line2D([0], [0], marker='x', color='w', markeredgecolor='black', markersize=5, label='Diffusion'),
]
origin_legend = ax.legend(
    handles=origin_legend_elements,
    loc='upper right',
    title='Data Origin',
    fontsize=10,
    title_fontsize=10,
)

# Legend 2: colour encodes the environment (task).
environment_legend_elements = [
    Line2D([0], [0], marker='o', color=id_to_color[task_id], linestyle='None',
           markersize=10, label=f'{id_to_name[task_id]}')
    for task_id in unique_ids
]
ax.legend(
    handles=environment_legend_elements,
    loc='lower right',
    title='Environment',
    fontsize=10,
    title_fontsize=10,
    bbox_to_anchor=(1, 0),
)

# The second ax.legend call replaced the first as the axes' legend, so the
# origin legend is re-attached as a plain artist to show both.
ax.add_artist(origin_legend)

ax.set_title('t-SNE on Agent and Synthetic Data', fontsize=14)
fig.savefig('tSNE_generalization_176.pdf', format='pdf', bbox_inches='tight')

plt.show()