In [None]:
import os

# NOTE: the original relative order of the lines below is kept on purpose:
# numpy / matplotlib / scipy are imported AFTER the star import so that
# `np`, `plt` and `stats` cannot be shadowed by names it exports.
import composuite
from diffusion.utils import *
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats


def compute_error(agent_dataset, synthetic_dataset, num_samples=1000):
    """Euclidean distance between the mean observations of two datasets.

    A random subset of rows (drawn without replacement) is taken from the
    ``'observations'`` array of each dataset, the per-dimension mean of each
    subset is computed, and the L2 norm of the difference of the two mean
    vectors is returned.

    Args:
        agent_dataset: Mapping with an ``'observations'`` array whose first
            axis indexes samples.
        synthetic_dataset: Mapping with an ``'observations'`` array whose
            first axis indexes samples.
        num_samples: Maximum number of rows to draw from each dataset.
            When a dataset has fewer rows, all of its rows are used instead
            of failing.

    Returns:
        The L2 norm of ``mean(agent sample) - mean(synthetic sample)``.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1 or either dataset
            has no observations.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    sample_means = []
    # Agent first, then synthetic: keeps the order of RNG draws stable.
    for observations in (agent_dataset['observations'], synthetic_dataset['observations']):
        num_available = observations.shape[0]
        if num_available == 0:
            raise ValueError("cannot compute an error from empty observations")
        # Sampling without replacement cannot draw more rows than exist.
        sample_size = min(num_samples, num_available)
        random_indices = np.random.choice(num_available, sample_size, replace=False)
        sample_means.append(np.mean(observations[random_indices], axis=0))

    mean_agent, mean_synthetic = sample_means
    return np.linalg.norm(mean_agent - mean_synthetic)


def wasserstein_distance(agent_dataset, synthetic_dataset, num_samples=1000):
    """Per-dimension 1-D Wasserstein distance between two observation sets.

    A random subset of rows (drawn without replacement) is taken from the
    ``'observations'`` array of each dataset. For every column of the agent
    sample, ``scipy.stats.wasserstein_distance`` is evaluated between that
    column and the same column of the synthetic sample.

    Args:
        agent_dataset: Mapping with a 2-D ``'observations'`` array
            (rows are samples, columns are observation dimensions).
        synthetic_dataset: Mapping with a 2-D ``'observations'`` array
            holding at least as many columns as the agent array.
        num_samples: Maximum number of rows to draw from each dataset.
            When a dataset has fewer rows, all of its rows are used instead
            of failing.

    Returns:
        A tuple ``(mean_distance, distances)`` where ``distances`` is the
        list of per-dimension distances and ``mean_distance`` their mean.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1 or either dataset
            has no observations.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    samples = []
    # Agent first, then synthetic: keeps the order of RNG draws stable.
    for observations in (agent_dataset['observations'], synthetic_dataset['observations']):
        num_available = observations.shape[0]
        if num_available == 0:
            raise ValueError("cannot compute a distance from empty observations")
        # Sampling without replacement cannot draw more rows than exist.
        sample_size = min(num_samples, num_available)
        random_indices = np.random.choice(num_available, sample_size, replace=False)
        samples.append(observations[random_indices])

    sampled_agent_observations, sampled_synthetic_observations = samples

    # The number of compared dimensions is dictated by the agent sample.
    n_dims = sampled_agent_observations.shape[1]
    w_distances = [
        stats.wasserstein_distance(
            sampled_agent_observations[:, dim],
            sampled_synthetic_observations[:, dim],
        )
        for dim in range(n_dims)
    ]

    return np.mean(w_distances), w_distances

In [None]:
# Which CompoSuite dataset variant the agent data is loaded from.
dataset_type = 'expert'

# Data locations. The defaults are the original machine-specific paths; set the
# environment variables below to run the notebook on another machine without
# editing this cell.
base_agent_data_path = os.environ.get(
    'COMPOSUITE_AGENT_DATA_PATH',
    '/Users/shubhankar/Developer/compositional-rl-synth-data/data',
)
base_synthetic_data_path = os.environ.get(
    'COMPOSUITE_SYNTHETIC_DATA_PATH',
    '/Users/shubhankar/Developer/compositional-rl-synth-data/cluster_results/diffusion',
)

In [None]:
# One reference environment for the whole notebook. In the cells shown here it
# is only queried for `modality_dims` when indicator vectors are removed from
# the agent datasets.
representative_env = composuite.make(
    'IIWA', 'Plate', 'ObjectWall', 'Push',
    use_task_id_obs=True,
    ignore_done=False,
)

In [None]:
# (run directory name, number of training tasks) pairs, in plotting order.
# The two lists derived below are consumed positionally: entry i of
# `num_train_tasks` is the x-value for the errors computed from `runs[i]`.
# NOTE(review): 244 is the only count that is not a multiple of 16 - confirm
# it is not meant to be 224.
run_task_counts = [
    ('cond_diff_17', 16),
    ('cond_diff_7', 32),
    ('cond_diff_8', 48),
    ('cond_diff_10', 64),
    ('cond_diff_15', 80),
    ('cond_diff_19', 96),
    ('cond_diff_22', 112),
    ('cond_diff_18', 128),
    ('cond_diff_23', 144),
    ('cond_diff_20', 176),
    ('cond_diff_24', 192),
    ('cond_diff_25', 208),
    ('cond_diff_21', 244),
]
runs = [run_name for run_name, _ in run_task_counts]
num_train_tasks = [task_count for _, task_count in run_task_counts]

In [None]:
# run = 'cond_diff_17' 

# train_tasks = [
#     task for task in os.listdir(os.path.join(base_synthetic_data_path, run, 'train'))
#     if not task.startswith('.')
# ]

# task = train_tasks[0]

# robot, obj, obst, subtask = task.split('_')

# agent_dataset = load_single_composuite_dataset(base_path=base_agent_data_path, 
#                                                 dataset_type='expert', 
#                                                 robot=robot, obj=obj, 
#                                                 obst=obst, task=subtask)
# agent_dataset = transitions_dataset(agent_dataset)
# agent_dataset, _ = remove_indicator_vectors(representative_env.modality_dims, agent_dataset)

# synthetic_dataset = load_single_synthetic_dataset(base_path=os.path.join(base_synthetic_data_path, run, 'train'),
#                                                 robot=robot, obj=obj, 
#                                                 obst=obst, task=subtask)

In [None]:
# mean_dist, w_dists = wasserstein_distance(agent_dataset, synthetic_dataset, num_samples=1000)

In [None]:
def _task_wasserstein_distances(run, split):
    """Mean Wasserstein distance of every task found under one split of a run.

    Every non-hidden entry of ``<base_synthetic_data_path>/<run>/<split>`` is
    treated as a task whose name splits into exactly four ``'_'``-separated
    parts (robot, object, obstacle, subtask). For each task the agent dataset
    and the synthetic dataset are loaded and compared with
    ``wasserstein_distance``.

    Args:
        run: Name of the run directory.
        split: Sub-directory of the run to evaluate (``'train'`` or ``'test'``).

    Returns:
        A list with one mean Wasserstein distance per task.
    """
    split_path = os.path.join(base_synthetic_data_path, run, split)
    tasks = [task for task in os.listdir(split_path) if not task.startswith('.')]

    distances = []
    for task in tasks:
        robot, obj, obst, subtask = task.split('_')

        # Use the notebook-level `dataset_type` setting instead of repeating
        # the literal, so the config cell is the single source of truth.
        agent_dataset = load_single_composuite_dataset(base_path=base_agent_data_path,
                                                       dataset_type=dataset_type,
                                                       robot=robot, obj=obj,
                                                       obst=obst, task=subtask)
        agent_dataset = transitions_dataset(agent_dataset)
        # Presumably drops the task-indicator part of the observations so the
        # agent data is comparable to the synthetic data - verify in
        # diffusion.utils.
        agent_dataset, _ = remove_indicator_vectors(representative_env.modality_dims, agent_dataset)

        synthetic_dataset = load_single_synthetic_dataset(base_path=split_path,
                                                          robot=robot, obj=obj,
                                                          obst=obst, task=subtask)

        mean_wasserstein_distance, _ = wasserstein_distance(agent_dataset, synthetic_dataset)
        distances.append(mean_wasserstein_distance)

    return distances


# Per-run mean / standard deviation (across tasks) of the distance, kept in the
# same order as `runs`.
all_train_error_means = []
all_train_error_stds = []
all_test_error_means = []
all_test_error_stds = []

for run in tqdm(runs, desc='Run'):
    # Train tasks are evaluated before test tasks for every run.
    train_errors = _task_wasserstein_distances(run, 'train')
    all_train_error_means.append(np.mean(train_errors))
    all_train_error_stds.append(np.std(train_errors))

    test_errors = _task_wasserstein_distances(run, 'test')
    all_test_error_means.append(np.mean(test_errors))
    all_test_error_stds.append(np.std(test_errors))

In [None]:
# Plot mean +/- one standard deviation of the per-task Wasserstein distance
# against the number of training tasks, for train and test tasks.
train_means = np.array(all_train_error_means)
train_stds = np.array(all_train_error_stds)
test_means = np.array(all_test_error_means)
test_stds = np.array(all_test_error_stds)
num_tasks = np.array(num_train_tasks)

# x and y must have the same length for plot()/fill_between(). The previous
# fixed `num_tasks[:10]` slice only matched when exactly 10 runs had been
# evaluated; trim every series to the number of points all of them share.
num_points = min(len(num_tasks), len(train_means), len(test_means))
num_tasks = num_tasks[:num_points]
train_means, train_stds = train_means[:num_points], train_stds[:num_points]
test_means, test_stds = test_means[:num_points], test_stds[:num_points]

fig, ax = plt.subplots(figsize=(8, 6))

ax.plot(num_tasks, train_means, label='Train Error', marker='o', linestyle='-', color='blue')
ax.plot(num_tasks, test_means, label='Test Error', marker='s', linestyle='-', color='orange')

# Shaded bands: one standard deviation around each mean curve.
ax.fill_between(
    num_tasks,
    train_means - train_stds,
    train_means + train_stds,
    color='blue', alpha=0.4
)
ax.fill_between(
    num_tasks,
    test_means - test_stds,
    test_means + test_stds,
    color='orange', alpha=0.4
)

ax.set_xlabel('Number of Training Tasks', fontsize=14)
ax.set_ylabel('Wasserstein Distance', fontsize=14)
ax.set_title('Diffusion Model Generalization', fontsize=14)
ax.legend()
ax.grid(True)

fig.savefig('wasserstein_dist_diffusion_generalization.pdf', format='pdf', bbox_inches='tight')

plt.show()