In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from diffusion.utils import *
from corl.algorithms.offline.td3_bc import *
from corl.shared.buffer import *
from corl.shared.logger import *

# Root of the compositional-rl-synth-data checkout. Override it with the
# COMPOSITIONAL_RL_ROOT environment variable to run this notebook on another
# machine; the default reproduces the original author's local layout.
project_root = os.environ.get(
    'COMPOSITIONAL_RL_ROOT',
    '/Users/shubhankar/Developer/compositional-rl-synth-data',
)
base_agent_data_path = os.path.join(project_root, 'data')
base_synthetic_data_path = os.path.join(project_root, 'cluster_results', 'diffusion')

base_results_folder = os.path.join(project_root, 'local_results', 'offline_learning')

def identify_special_dimensions(data):
    """Classify the columns of a 2-D array as integer-valued or constant.

    Args:
        data: 2-D array indexed as ``data[:, dim]``.

    Returns:
        Tuple ``(integer_dims, constant_dims)`` of column-index lists.
        ``integer_dims`` holds every column whose entries all equal their rounded
        value. ``constant_dims`` holds the remaining columns whose entries all
        equal the first entry. An integer-valued constant column is reported only
        in ``integer_dims``.
    """
    columns = [data[:, dim] for dim in range(data.shape[1])]
    integer_flags = [np.all(np.equal(col, np.round(col))) for col in columns]

    integer_dims = [dim for dim, is_int in enumerate(integer_flags) if is_int]
    # The integer test wins, so constant-ness is only checked on non-integer columns.
    constant_dims = [
        dim
        for dim, (col, is_int) in enumerate(zip(columns, integer_flags))
        if not is_int and np.all(col == col[0])
    ]
    return integer_dims, constant_dims


def process_special_dimensions(synthetic_dataset, integer_dims, constant_dims):
    """Return a copy of ``synthetic_dataset`` with selected observation dims rounded.

    Every value in the dataset dict is ``.copy()``-ed, so the input is never
    modified. In both ``'observations'`` and ``'next_observations'``, the columns
    in ``integer_dims`` are rounded to whole numbers and the columns in
    ``constant_dims`` to 2 decimals. Rounding always reads the ORIGINAL
    (unrounded) values. An empty dims list is skipped.

    Args:
        synthetic_dataset: Dict of arrays; must contain ``'observations'`` and
            ``'next_observations'`` (2-D, indexed as ``[:, dims]``).
        integer_dims: Column indices to round to integers.
        constant_dims: Column indices to round to 2 decimals.

    Returns:
        New dict with the same keys.
    """
    processed_dataset = {}
    for name, values in synthetic_dataset.items():
        processed_dataset[name] = values.copy()

    # (columns, decimals) pairs, applied in this order for each observation key.
    rounding_rules = ((integer_dims, 0), (constant_dims, 2))
    for key in ('observations', 'next_observations'):
        source = synthetic_dataset[key]
        for dims, decimals in rounding_rules:
            if dims:
                processed_dataset[key][:, dims] = np.round(source[:, dims], decimals=decimals)

    return processed_dataset


def split_dataset(data_dict, train_ratio=0.8, seed=42):
    """Randomly split a dict of row-aligned arrays into train and test dicts.

    Side effect: seeds NumPy's GLOBAL RNG with ``seed`` (``np.random.seed``)
    before drawing the permutation, so the split is reproducible.

    Args:
        data_dict: Dict of arrays sharing their first dimension; the row count is
            taken from ``data_dict['observations']``.
        train_ratio: Fraction of rows assigned to the train split (truncated
            with ``int``).
        seed: Seed passed to ``np.random.seed``.

    Returns:
        ``(train_data, test_data)``: dicts with the same keys, holding the
        shuffled train rows and the remaining test rows.
    """
    np.random.seed(seed)
    num_points = len(data_dict['observations'])
    shuffled = np.random.permutation(num_points)
    cut = int(train_ratio * num_points)
    train_indices, test_indices = shuffled[:cut], shuffled[cut:]

    train_data = {key: values[train_indices] for key, values in data_dict.items()}
    test_data = {key: values[test_indices] for key, values in data_dict.items()}
    return train_data, test_data

In [None]:
config = TrainConfig()

# Which synthetic-data run (and split) is compared against the agent data later on.
synthetic_run_id, mode = 'cond_diff_20', 'train'

# Training / evaluation overrides on top of the TrainConfig defaults.
config.max_timesteps, config.n_episodes, config.batch_size = 30000, 10, 1024

# CompoSuite task: robot, object, obstacle and subtask.
robot, obj, obst, subtask = 'IIWA', 'Dumbbell', 'ObjectDoor', 'Trashcan'

# This env (with task-ID observations) is used here only for env.modality_dims,
# which presumably tells remove_indicator_vectors which observation entries to strip.
env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=True, ignore_done=False)
agent_dataset = load_single_composuite_dataset(
    base_path=base_agent_data_path,
    dataset_type='expert',
    robot=robot, obj=obj, obst=obst, task=subtask,
)
agent_dataset, _ = remove_indicator_vectors(env.modality_dims, transitions_dataset(agent_dataset))
train_agent_dataset, test_agent_dataset = split_dataset(agent_dataset)

# Special dimensions are detected on the TRAIN split only; they are later applied to the synthetic data.
integer_dims, constant_dims = identify_special_dimensions(train_agent_dataset['observations'])
print('Integer dimensions:', integer_dims)
print('Constant dimensions:', constant_dims)

In [None]:
# Use the first unused "offline_learning_agent_<N>" folder (N starts at 1),
# so each run gets its own results / checkpoint directory.
base_results_path = pathlib.Path(base_results_folder)
idx = 1
while True:
    results_folder = base_results_path / f"offline_learning_agent_{idx}"
    if not results_folder.exists():
        break
    idx += 1
results_folder.mkdir(parents=True, exist_ok=True)

config.checkpoints_path = results_folder

In [None]:
# Evaluation env without task-ID observations, presumably matching the
# indicator-stripped agent data.
# NOTE(review): has_renderer=True presumably opens an on-screen viewer; confirm this works on headless machines.
env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=False, has_renderer=True, ignore_done=False)
state_dim, action_dim = env.observation_space.shape[0], env.action_space.shape[0]

# Observation statistics come from the TRAIN split and are reused by the env wrapper and the replay buffer.
state_mean, state_std = compute_mean_std(train_agent_dataset["observations"], eps=1e-3)
env = wrap_env(env, state_mean=state_mean, state_std=state_std)

In [None]:
# The replay buffer lives on the CPU; the training loop moves each sampled batch to config.device.
device = "cpu"
num_samples = len(train_agent_dataset['observations'])
print("Samples:", num_samples)

replay_buffer = prepare_replay_buffer(
    state_dim=state_dim, action_dim=action_dim,
    dataset=train_agent_dataset, num_samples=num_samples, device=device,
    # Reward normalization is optional and controlled by config.normalize_reward.
    reward_normalizer=(
        RewardNormalizer(train_agent_dataset, config.env)
        if config.normalize_reward else None
    ),
    state_normalizer=StateNormalizer(state_mean, state_std),
)

In [None]:
# TD3+BC takes a single scalar bound (passed to Actor and used below to scale
# policy_noise / noise_clip), so every action dimension must share the
# symmetric range [-max_action, max_action]. Fail loudly instead of silently
# reading only the first dimension's upper bound.
action_high = np.asarray(env.action_space.high, dtype=np.float64)
action_low = np.asarray(env.action_space.low, dtype=np.float64)
if not (np.allclose(action_high, action_high[0]) and np.allclose(action_low, -action_high)):
    raise ValueError(
        "TD3+BC expects identical symmetric bounds on every action dimension; "
        f"got low={action_low}, high={action_high}"
    )
max_action = float(action_high[0])

In [None]:
# Project Logger writing into this run's results folder.
logger = Logger(results_folder, seed=config.seed)

# Set seeds
# set_seed presumably seeds the numeric libraries and `env` (confirm in its definition).
# It runs BEFORE the networks below are constructed, so keep this ordering if
# weight initialization is meant to be reproducible.
seed = config.seed
set_seed(seed, env)

# Actor plus twin critics, sized by config.network_width / config.network_depth and
# moved to config.device. All three optimizers use a hard-coded Adam lr of 3e-4.
actor = Actor(state_dim, action_dim, max_action, hidden_dim=config.network_width, n_hidden=config.network_depth).to(config.device)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4)

critic_1 = Critic(state_dim, action_dim, hidden_dim=config.network_width, n_hidden=config.network_depth).to(config.device)
critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=3e-4)

critic_2 = Critic(state_dim, action_dim, hidden_dim=config.network_width, n_hidden=config.network_depth).to(config.device)
critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=3e-4)

In [None]:
# Constructor arguments for TD3_BC. policy_noise and noise_clip are configured
# relative to max_action and converted to absolute action units here.
kwargs = dict(
    max_action=max_action,
    actor=actor,
    actor_optimizer=actor_optimizer,
    critic_1=critic_1,
    critic_1_optimizer=critic_1_optimizer,
    critic_2=critic_2,
    critic_2_optimizer=critic_2_optimizer,
    discount=config.discount,
    tau=config.tau,
    policy_noise=config.policy_noise * max_action,
    noise_clip=config.noise_clip * max_action,
    policy_freq=config.policy_freq,
    alpha=config.alpha,
)

# NOTE(review): config.env is not assigned in the visible cells, so this banner
# presumably shows the TrainConfig default rather than the CompoSuite task — confirm.
print("----------------------------------------------------")
print(f"Training TD3 + BC, Env: {config.env}, Seed: {seed}")
print("----------------------------------------------------")

trainer = TD3_BC(**kwargs)

In [None]:
# Weights & Biases run settings; the run is named after this run's local results folder.
wandb_project, wandb_entity, wandb_group = 'offline_rl_diffusion', '', 'corl_training'

# NOTE(review): an empty entity string presumably falls back to the default wandb entity — confirm.
wandb.init(project=wandb_project, entity=wandb_entity, group=wandb_group, name=results_folder.name)

In [None]:
# Sanity check: checkpoints are written into the results folder chosen above.
print(f"{config.checkpoints_path}")

In [None]:
# Offline training loop: one trainer update per iteration on a minibatch from
# the replay buffer, with periodic logging, evaluation and checkpointing.
# Mean evaluation scores are collected here but not used later in the visible notebook.
evaluations = []
for t in range(int(config.max_timesteps)):
    batch = replay_buffer.sample(config.batch_size)
    # The buffer was built on `device` ("cpu"); move every tensor to the training device.
    batch = [b.to(config.device) for b in batch]
    log_dict = trainer.train(batch)

    if t % config.log_every == 0:
        # Training metrics go to both wandb and the local logger, keyed by the trainer's iteration counter.
        wandb.log(log_dict, step=trainer.total_it)
        logger.log({'step': trainer.total_it, **log_dict}, mode='train')

    # Evaluate episode
    # Evaluate every `eval_freq` iterations (including t == 0) and after the final one.
    # `env` is the wrapped evaluation env; eval_actor presumably returns one score per episode.
    if t % config.eval_freq == 0 or t == config.max_timesteps - 1:
        print(f"Time steps: {t + 1}")
        eval_scores = eval_actor(
            env,
            actor,
            device=config.device,
            n_episodes=config.n_episodes,
            seed=config.seed,
        )
        eval_score = eval_scores.mean()
        evaluations.append(eval_score)
        print("------------------------------------------------")
        print(
            f"Evaluation over {config.n_episodes} episodes: "
            f"{eval_score:.3f}"
        )
        print("------------------------------------------------")
        # Checkpoints are written only when a path is set AND config.save_checkpoints is enabled.
        if config.checkpoints_path is not None and config.save_checkpoints:
            torch.save(
                trainer.state_dict(),
                os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
            )
        # `log_dict` is rebound here to hold only the evaluation score.
        log_dict = {"Score": eval_score}
        wandb.log(log_dict, step=trainer.total_it)
        logger.log({'step': trainer.total_it, **log_dict}, mode='eval')

In [None]:
# Load the diffusion-generated transitions for the same task, then snap the
# observation dimensions found to be integer-valued / constant in the train agent data.
synthetic_dataset = process_special_dimensions(
    load_single_synthetic_dataset(
        base_path=os.path.join(base_synthetic_data_path, synthetic_run_id, mode),
        robot=robot, obj=obj,
        obst=obst, task=subtask,
    ),
    integer_dims,
    constant_dims,
)

In [None]:
def get_actions(observations, actor, batch_size, device=None):
    """Run ``actor`` over ``observations`` in mini-batches and return its actions.

    Args:
        observations: Array indexed along its first axis (e.g. already
            normalized states). Each slice is converted to a float32 tensor.
        actor: Torch module mapping a batch of observations to a batch of actions.
        batch_size: Number of observations per forward pass (must be >= 1).
        device: Device that batches are placed on. Defaults to the device of the
            actor's first parameter (CPU if it has none), so inputs always sit
            where the network does.

    Returns:
        np.ndarray with one row of predicted actions per observation, in input
        order. Empty input yields a ``(0, action_dim)`` array.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if device is None:
        first_param = next(actor.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    num_obs = observations.shape[0]
    # Empty input still gets one (empty) forward pass, so the result keeps the
    # actor's output width instead of np.vstack failing on an empty list.
    batch_starts = range(0, num_obs, batch_size) if num_obs > 0 else [0]

    was_training = actor.training
    actor.eval()
    predicted_batches = []
    try:
        with torch.no_grad():
            for start in batch_starts:
                # Bound each slice by THIS array's length and the requested batch size.
                stop = min(start + batch_size, num_obs)
                batch = torch.as_tensor(observations[start:stop], dtype=torch.float32, device=device)
                predicted_batches.append(actor(batch).cpu().numpy())
    finally:
        # Leave the actor in the train/eval mode it had before this call.
        actor.train(was_training)

    return np.vstack(predicted_batches)

In [None]:
# Normalize every observation set with the train-split statistics (the same
# state_mean / state_std given to the replay buffer's StateNormalizer), then query the actor.
(normalized_train_agent_observations,
 normalized_test_agent_observations,
 normalized_synthetic_observations) = [
    (dataset['observations'] - state_mean) / state_std
    for dataset in (train_agent_dataset, test_agent_dataset, synthetic_dataset)
]

(predicted_train_agent_actions,
 predicted_test_agent_actions,
 predicted_synthetic_actions) = [
    get_actions(normalized, actor, config.batch_size)
    for normalized in (normalized_train_agent_observations,
                       normalized_test_agent_observations,
                       normalized_synthetic_observations)
]

In [None]:
# Shapes of the train / test / synthetic predictions, printed on one line.
print(*(arr.shape for arr in (predicted_train_agent_actions,
                              predicted_test_agent_actions,
                              predicted_synthetic_actions)))

In [None]:
def calculate_outlier_fraction(data):
    """Return the fraction of values outside Tukey's 1.5 * IQR fences.

    A value is an outlier when it is below ``Q1 - 1.5 * IQR`` or above
    ``Q3 + 1.5 * IQR``, where Q1/Q3 are the 25th/75th percentiles of all values.

    Args:
        data: Array-like of numbers (lists are accepted). Multi-dimensional input
            is treated as a flat collection of values.

    Returns:
        Number of outliers divided by the total number of values, so the result
        lies in [0, 1]. Returns NaN for empty input, where the fraction is undefined.
    """
    values = np.asarray(data, dtype=float).ravel()
    if values.size == 0:
        return float('nan')

    q1, q3 = np.percentile(values, [25, 75])
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    n_outliers = np.count_nonzero((values < lower_bound) | (values > upper_bound))
    # Divide by the number of VALUES (not rows) so the result is a true fraction for any shape.
    return n_outliers / values.size


def compare_action_predictions(train_agent_actions, predicted_train_agent_actions, 
                               test_agent_actions, predicted_test_agent_actions,
                               synthetic_actions, predicted_synthetic_actions,
                               n_samples=1000):
    """Box-plot per-sample action prediction errors for three data sources.

    For the train agent, test agent and synthetic sets (drawn in that order from
    NumPy's global RNG), at most ``n_samples`` rows are sampled without
    replacement. Each row's error is its squared error averaged over action
    dimensions. The IQR outlier fraction of each source's errors is printed, and
    each box is annotated with its mean error.

    Args:
        train_agent_actions, test_agent_actions, synthetic_actions: 2-D action
            arrays, one row per transition.
        predicted_train_agent_actions, predicted_test_agent_actions,
        predicted_synthetic_actions: Actor predictions aligned row-for-row with
            the matching actions array.
        n_samples: Maximum number of rows sampled from each source.

    Returns:
        The matplotlib Figure containing the box plot.
    """
    # (box label, printed name, actions, predictions) for each source.
    sources = (
        ('Train Agent Data', 'Train Agent', train_agent_actions, predicted_train_agent_actions),
        ('Test Agent Data', 'Test Agent', test_agent_actions, predicted_test_agent_actions),
        ('Synthetic Data', 'Synthetic Data', synthetic_actions, predicted_synthetic_actions),
    )

    picked_rows = [
        np.random.choice(len(actual), min(n_samples, len(actual)), replace=False)
        for _, _, actual, _ in sources
    ]
    errors = [
        np.mean((actual[rows] - predicted[rows])**2, axis=1)
        for (_, _, actual, predicted), rows in zip(sources, picked_rows)
    ]
    outlier_fractions = [calculate_outlier_fraction(err) for err in errors]

    for (_, short_name, _, _), fraction in zip(sources, outlier_fractions):
        print(f"{short_name} Outliers: {fraction:.4f}")

    labels = [label for label, _, _, _ in sources]
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.boxplot(data=errors, ax=ax)
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels)
    ax.set_ylabel('Mean Squared Error')
    ax.set_title('Action Prediction Error Comparison')

    for position, err in enumerate(errors):
        ax.text(position, np.mean(err), f'Mean: {np.mean(err):.4f}',
                horizontalalignment='center', verticalalignment='bottom')

    return fig

def compare_action_predictions_per_dim(train_agent_actions, predicted_train_agent_actions, 
                                       test_agent_actions, predicted_test_agent_actions,
                                       synthetic_actions, predicted_synthetic_actions,
                                       n_samples=1000):
    """Plot per-action-dimension prediction errors and rank dimensions by error.

    For each source (train agent, test agent, synthetic), at most ``n_samples``
    rows are drawn ONCE without replacement from NumPy's global RNG. The same
    rows are reused for every action dimension, so dimensions are compared on
    identical transitions. For each dimension, this prints the IQR outlier
    fraction of each source's squared errors and shows a box plot of them.

    Args:
        train_agent_actions, test_agent_actions, synthetic_actions: 2-D action
            arrays, one row per transition.
        predicted_train_agent_actions, predicted_test_agent_actions,
        predicted_synthetic_actions: Actor predictions aligned row-for-row with
            the matching actions array.
        n_samples: Maximum number of rows sampled from each source.

    Returns:
        List of 0-based action-dimension indices, sorted from worst to best by the
        mean squared error pooled over the sampled rows of all three sources.

    Raises:
        ValueError: if an actions array and its predictions differ in length.
    """
    # (box label, printed name, actions, predictions) for each source.
    sources = (
        ('Train Agent Data', 'Train Agent', train_agent_actions, predicted_train_agent_actions),
        ('Test Agent Data', 'Test Agent', test_agent_actions, predicted_test_agent_actions),
        ('Synthetic Data', 'Synthetic Data', synthetic_actions, predicted_synthetic_actions),
    )
    for label, _, actual, predicted in sources:
        if len(actual) != len(predicted):
            raise ValueError(
                f"{label}: {len(actual)} actions but {len(predicted)} predictions"
            )

    # Sample each source once so every dimension is scored on the same rows.
    squared_errors = []
    for _, _, actual, predicted in sources:
        picked = np.random.choice(len(actual), min(n_samples, len(actual)), replace=False)
        squared_errors.append((actual[picked] - predicted[picked])**2)

    labels = [label for label, _, _, _ in sources]
    n_dimensions = train_agent_actions.shape[1]
    dimension_mse = []

    for dim in range(n_dimensions):
        dim_errors = [errors[:, dim] for errors in squared_errors]

        print(f"Dimension {dim + 1}:")
        for (_, short_name, _, _), errors in zip(sources, dim_errors):
            print(f"{short_name} Outliers: {calculate_outlier_fraction(errors):.4f}")

        # Pool by concatenation: np.mean over a list of arrays fails (ragged input)
        # when the sources contributed different numbers of rows.
        dimension_mse.append((dim, np.mean(np.concatenate(dim_errors))))

        fig, ax = plt.subplots(figsize=(8, 6))
        sns.boxplot(data=dim_errors, ax=ax)
        ax.set_xticks(range(len(labels)))
        ax.set_xticklabels(labels)
        ax.set_ylabel(f'Mean Squared Error (Dimension {dim+1})')
        ax.set_title(f'Action Prediction Error Comparison (Dimension {dim+1})')

        for position, errors in enumerate(dim_errors):
            ax.text(position, np.mean(errors), f'Mean: {np.mean(errors):.4f}',
                    horizontalalignment='center', verticalalignment='bottom')

        plt.show()
        # Close explicitly: otherwise one figure per dimension accumulates under non-inline backends.
        plt.close(fig)

    # Worst (highest pooled MSE) first.
    dimension_mse.sort(key=lambda item: item[1], reverse=True)
    return [dim for dim, _ in dimension_mse]

In [None]:
# Per-source sample size for the comparison plots: capped at 1000 and never larger
# than the smallest source, so every box is built from the same number of points.
# (Previously the min() was computed and then immediately overwritten with 1000.)
n_samples = min(
    1000,
    train_agent_dataset['actions'].shape[0],
    test_agent_dataset['actions'].shape[0],
    synthetic_dataset['actions'].shape[0],
)

In [None]:
# Pooled comparison: per-sample error averaged over all action dimensions.
fig = compare_action_predictions(
    train_agent_actions=train_agent_dataset['actions'],
    predicted_train_agent_actions=predicted_train_agent_actions,
    test_agent_actions=test_agent_dataset['actions'],
    predicted_test_agent_actions=predicted_test_agent_actions,
    synthetic_actions=synthetic_dataset['actions'],
    predicted_synthetic_actions=predicted_synthetic_actions,
    n_samples=n_samples,
)

In [None]:
# Per-dimension comparison; returns 0-based action dimensions ordered worst to best.
sorted_dimensions = compare_action_predictions_per_dim(
    train_agent_actions=train_agent_dataset['actions'],
    predicted_train_agent_actions=predicted_train_agent_actions,
    test_agent_actions=test_agent_dataset['actions'],
    predicted_test_agent_actions=predicted_test_agent_actions,
    synthetic_actions=synthetic_dataset['actions'],
    predicted_synthetic_actions=predicted_synthetic_actions,
    n_samples=n_samples,
)

In [None]:
# Action dimensions ranked from highest to lowest pooled prediction error.
print(f"{sorted_dimensions}")