In [None]:
import composuite
from diffusion.utils import *
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, TensorDataset, random_split


def combine_datasets(dataset_1, dataset_2, keys=None):
    """Concatenate two transition datasets along the sample (first) axis.

    Args:
        dataset_1: dict mapping field name -> array; its rows come first in the result.
        dataset_2: dict with the same fields; its rows are appended after dataset_1's.
        keys: fields to combine. Defaults to the standard transition fields
            ('observations', 'actions', 'next_observations', 'rewards', 'terminals').

    Returns:
        A new dict containing only the combined ``keys``.

    Raises:
        ValueError: if a field's per-sample shape (all dims after the first) differs.
        KeyError: if a requested field is missing from either dataset.
    """
    if keys is None:
        keys = ['observations', 'actions', 'next_observations', 'rewards', 'terminals']
    combined_dataset = {}

    for key in keys:
        shape_1 = dataset_1[key].shape[1:]
        shape_2 = dataset_2[key].shape[1:]

        # Only the trailing (per-sample) dims must agree; sample counts may differ.
        # The message names the arguments generically because this helper is used for
        # arbitrary pairs (e.g. already-combined data + warmstart), not just expert/medium.
        if shape_1 != shape_2:
            raise ValueError(
                f"Shape mismatch for {key}: dataset_1 shape {shape_1} != dataset_2 shape {shape_2}"
            )

        combined_dataset[key] = np.concatenate([dataset_1[key], dataset_2[key]], axis=0)

        print(f"Combined shape for {key}: {combined_dataset[key].shape}")

    return combined_dataset


def identify_special_dimensions(data):
    """Classify the columns of a 2-D array.

    A column whose every entry equals its rounded value is reported as an
    integer dimension. Otherwise, if every entry equals the first entry, it is
    reported as a constant dimension. A column that is both integer-valued and
    constant is reported only as integer.

    Returns:
        (integer_dims, constant_dims): two lists of column indices.
    """
    integer_dims, constant_dims = [], []

    for idx, col in enumerate(data.T):
        if np.all(np.equal(col, np.round(col))):
            integer_dims.append(idx)
            continue
        if np.all(col == col[0]):
            constant_dims.append(idx)

    return integer_dims, constant_dims


def process_special_dimensions(synthetic_dataset, integer_dims, constant_dims):
    """Return a copy of ``synthetic_dataset`` with selected observation columns rounded.

    For both 'observations' and 'next_observations', columns in ``integer_dims``
    are rounded to whole numbers and columns in ``constant_dims`` are rounded to
    one decimal place. Every array is copied, so the input dict is not modified.
    """
    result = {name: arr.copy() for name, arr in synthetic_dataset.items()}
    # (columns, decimals) pairs: integer dims -> 0 decimals, constant dims -> 1 decimal.
    rounding_plan = ((integer_dims, 0), (constant_dims, 1))

    for obs_key in ('observations', 'next_observations'):
        source = synthetic_dataset[obs_key]
        for dims, ndigits in rounding_plan:
            if not dims:
                continue
            result[obs_key][:, dims] = np.round(source[:, dims], decimals=ndigits)

    return result


def create_dimension_labels(state_dims):
    """Map each flat observation index to a ``"<state_type>_<i>"`` label.

    ``state_dims`` maps a state-type name to a shape tuple whose first entry is
    that state's width. Indices are assigned consecutively in the dict's
    iteration order.
    """
    labels = {}
    offset = 0

    for name, shape in state_dims.items():
        width = shape[0]
        labels.update({offset + k: f"{name}_{k}" for k in range(width)})
        offset += width

    return labels


def compute_likelihoods(model, dataloader, per_sample=True):
    """Compute Gaussian log-likelihoods of next observations under ``model``.

    The value is the negative of ``nn.GaussianNLLLoss`` (PyTorch's default
    ``full=False``, i.e. without the constant log(2*pi) term), averaged over
    observation dimensions.

    Args:
        model: module mapping ``(obs, acts)`` to ``(mean, log_var)``.
        dataloader: yields ``(obs, acts, next_obs)`` batches.
        per_sample: if True (default), return one value per transition, so the
            result describes the per-sample distribution. If False, return one
            value per batch, averaged over the whole batch; this was the
            previous behavior.

    Returns:
        np.ndarray of shape (n_samples,) when ``per_sample`` is True, otherwise
        (n_batches,).
    """
    # reduction='none' keeps one loss per element. The default 'mean' collapses each
    # batch to a single scalar, which hides the per-sample distribution.
    criterion = nn.GaussianNLLLoss(reduction='none')
    likelihoods = []
    model.eval()
    with torch.no_grad():
        for obs, acts, next_obs in dataloader:
            mean, log_var = model(obs, acts)
            var = torch.exp(log_var)
            elementwise_nll = criterion(mean, next_obs, var)
            if per_sample:
                nll = elementwise_nll.reshape(elementwise_nll.shape[0], -1).mean(dim=1)
                likelihoods.append(-nll.cpu().numpy())
            else:
                likelihoods.append(-elementwise_nll.mean().cpu().numpy())

    if per_sample:
        return np.concatenate(likelihoods) if likelihoods else np.empty(0)
    return np.array(likelihoods)


def compute_dimensional_likelihoods(model, dataloader):
    """Compute the Gaussian negative log-likelihood of every next-observation dimension.

    Per element: NLL = 0.5 * (log(2*pi) + log(sigma^2) + (x - mu)^2 / sigma^2),
    including the constant term.

    Returns:
        all_dim_nlls: array of shape (n_samples, n_dims), one row per transition,
        in dataloader order.
    """
    model.eval()
    chunks = []

    with torch.no_grad():
        for batch_obs, batch_acts, batch_next in dataloader:
            mu, log_sigma2 = model(batch_obs, batch_acts)
            scaled_sq_err = (batch_next - mu).pow(2) / log_sigma2.exp()
            per_dim = 0.5 * (np.log(2 * np.pi) + log_sigma2 + scaled_sq_err)
            chunks.append(per_dim.cpu().numpy())

    return np.concatenate(chunks, axis=0)


def analyze_dimensional_differences(train_dim_nlls, test_dim_nlls, model, dim_to_label):
    """Rank dimensions by how much higher the test mean NLL is than the train mean NLL.

    ``model`` is accepted for signature parity with the plotting helpers and is not used.

    Returns:
        A list of dicts, largest (test - train) gap first, each with keys
        'dimension' (label), 'dim_index', 'train_nll', 'test_nll' and 'difference'.
    """
    mean_train = np.mean(train_dim_nlls, axis=0)
    mean_test = np.mean(test_dim_nlls, axis=0)
    gap = mean_test - mean_train

    ranking = np.argsort(gap)[::-1]

    return [
        {
            'dimension': dim_to_label[d],
            'dim_index': d,
            'train_nll': mean_train[d],
            'test_nll': mean_test[d],
            'difference': gap[d],
        }
        for d in ranking
    ]

def visualize_dimensional_differences(train_dim_nlls, test_dim_nlls, model, dim_to_label, top_k=10):
    """Visualize per-dimension likelihood differences between train and test data.

    Args:
        train_dim_nlls: per-dimension NLLs on the training data, 2-D (samples x dims).
        test_dim_nlls: per-dimension NLLs on the test data, 2-D (samples x dims).
        model: unused; kept for signature parity with the other analysis helpers.
        dim_to_label: mapping from dimension index to a human-readable label.
        top_k: number of largest-gap dimensions to plot (default 10).

    Returns:
        (labels, diffs): labels of the plotted dimensions, worst first, and their
        mean-NLL differences (test - train).
    """
    train_means = np.mean(train_dim_nlls, axis=0)
    test_means = np.mean(test_dim_nlls, axis=0)
    differences = test_means - train_means

    # Dimension indices ordered from largest to smallest (test - train) gap.
    worst_dims_idx = np.argsort(differences)[::-1]
    k = min(top_k, len(worst_dims_idx))
    worst_dims = worst_dims_idx[:k]

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), height_ratios=[1.5, 1])

    # 1. Side-by-side train/test NLL distributions for each selected dimension.
    plot_data = []
    labels = []
    for dim in worst_dims:
        plot_data.extend([train_dim_nlls[:, dim], test_dim_nlls[:, dim]])
        labels.extend([f'Train {dim_to_label[dim]}', f'Test {dim_to_label[dim]}'])

    sns.boxplot(data=plot_data, ax=ax1)
    # Pin tick positions to the boxes before relabelling them.
    ax1.set_xticks(np.arange(len(labels)))
    ax1.set_xticklabels(labels, rotation=45, ha='right')
    # Report the actual number shown; it is smaller than top_k when there are fewer dims.
    ax1.set_title(f'Distribution of Negative Log-Likelihoods by Dimension\n(Top {k} Most Different Dimensions)')
    ax1.set_ylabel('Negative Log-Likelihood')

    # 2. Mean-NLL gap for the same dimensions.
    dims = np.arange(len(worst_dims))
    diffs = differences[worst_dims]

    bars = ax2.bar(dims, diffs)

    # RdYlBu_r runs blue (0) -> red (1). Bars are ordered worst-first, so sample it from
    # 1 down to 0 to make the largest (worst) gap red, matching "Larger Values = Worse".
    colors = plt.cm.RdYlBu_r(np.linspace(1, 0, len(bars)))
    for bar, color in zip(bars, colors):
        bar.set_color(color)

    ax2.set_xticks(dims)
    ax2.set_xticklabels([dim_to_label[d] for d in worst_dims], rotation=45, ha='right')
    ax2.set_title('Difference in Mean NLL (Test - Train)\nLarger Values = Worse Synthetic Data')
    ax2.set_ylabel('Difference in NLL')

    # Summary statistics are computed over all dimensions, not only the plotted ones.
    textstr = '\n'.join([
        'Summary Statistics:',
        f'Max Difference: {np.max(differences):.3f}',
        f'Mean Difference: {np.mean(differences):.3f}',
        f'Median Difference: {np.median(differences):.3f}'
    ])

    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax2.text(0.02, 0.98, textstr, transform=ax2.transAxes, fontsize=9,
             verticalalignment='top', bbox=props)

    plt.tight_layout()
    plt.show()

    return [dim_to_label[dim] for dim in worst_dims], differences[worst_dims]


def visualize_state_space_differences(train_dim_nlls, test_dim_nlls, model, dim_to_label, top_k=15):
    """Visualize per-dimension likelihood differences, grouped by state type.

    Args:
        train_dim_nlls: per-dimension NLLs on the training data, 2-D (samples x dims).
        test_dim_nlls: per-dimension NLLs on the test data, 2-D (samples x dims).
        model: unused; kept for signature parity with the other analysis helpers.
        dim_to_label: mapping from dimension index to a ``"<state_type>_<i>"`` label.
        top_k: number of largest-gap dimensions shown in the box plot (default 15).

    Returns:
        dict mapping state type -> {'mean', 'max', 'worst_dim', 'all_diffs'}, where
        'worst_dim' is the dimension index with the largest (test - train) gap in
        that group.
    """
    train_means = np.mean(train_dim_nlls, axis=0)
    test_means = np.mean(test_dim_nlls, axis=0)
    differences = test_means - train_means

    worst_dims_idx = np.argsort(differences)[::-1]
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12), height_ratios=[1.5, 1])

    # 1. Box plot for the k worst dimensions.
    k = min(top_k, len(worst_dims_idx))
    worst_dims = worst_dims_idx[:k]

    plot_data = []
    labels = []
    for dim in worst_dims:
        plot_data.extend([train_dim_nlls[:, dim], test_dim_nlls[:, dim]])
        labels.extend([f'Train\n{dim_to_label[dim]}', f'Test\n{dim_to_label[dim]}'])

    sns.boxplot(data=plot_data, ax=ax1)
    ax1.set_xticks(np.arange(len(labels)))
    ax1.set_xticklabels(labels, rotation=45, ha='right')
    ax1.set_title(f'Distribution of Negative Log-Likelihoods by State Dimension\n(Top {k} Most Different Dimensions)')
    ax1.set_ylabel('Negative Log-Likelihood')

    # 2. Group dimensions by state type. Labels have the form f"{state_type}_{i}", and
    # state-type names may themselves contain underscores, so strip only the trailing
    # index. Splitting on the first '_' would truncate such names.
    state_type_dims = {}
    for dim, label in dim_to_label.items():
        state_type = label.rsplit('_', 1)[0]
        state_type_dims.setdefault(state_type, []).append(dim)

    # Per-state-type statistics of the (test - train) gap.
    state_type_diffs = {}
    for state_type, dims in state_type_dims.items():
        state_diffs = differences[dims]
        state_type_diffs[state_type] = {
            'mean': np.mean(state_diffs),
            'max': np.max(state_diffs),
            'worst_dim': dims[np.argmax(state_diffs)],
            'all_diffs': state_diffs
        }

    state_types = list(state_type_dims.keys())
    means = [state_type_diffs[st]['mean'] for st in state_types]
    maxes = [state_type_diffs[st]['max'] for st in state_types]

    x = np.arange(len(state_types))
    width = 0.35

    ax2.bar(x - width/2, means, width, label='Mean Difference')
    ax2.bar(x + width/2, maxes, width, label='Max Difference')

    ax2.set_xticks(x)
    ax2.set_xticklabels([st.replace('-state', '') for st in state_types], rotation=45, ha='right')
    ax2.set_title('NLL Differences by State Type (Test - Train)')
    ax2.set_ylabel('Difference in NLL')
    ax2.legend()

    textstr = ['State Space Analysis:']
    for state_type in state_types:
        stats = state_type_diffs[state_type]
        worst_dim_label = dim_to_label[stats['worst_dim']]
        textstr.extend([
            f'\n{state_type}:',
            f'  Mean Diff: {stats["mean"]:.3f}',
            f'  Max Diff: {stats["max"]:.3f}',
            f'  Worst Dim: {worst_dim_label}'
        ])

    # The text box sits just right of the axes (x = 1.02 in axes coordinates).
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax2.text(1.02, 0.98, '\n'.join(textstr), transform=ax2.transAxes, fontsize=9,
             verticalalignment='top', bbox=props)

    plt.tight_layout()
    plt.show()

    return state_type_diffs

In [None]:
import os

# Repository root. Set COMPOSITIONAL_RL_SYNTH_DATA_ROOT to run this notebook on another
# machine; the default is the original author's checkout location.
REPO_ROOT = os.environ.get('COMPOSITIONAL_RL_SYNTH_DATA_ROOT',
                           '/Users/shubhankar/Developer/compositional-rl-synth-data')
base_agent_data_path = os.path.join(REPO_ROOT, 'data')
# Trailing separator kept, as in the original path, in case the loader concatenates strings.
base_synthetic_data_path = os.path.join(REPO_ROOT, 'cluster_results', 'diffusion',
                                        'cond_diff_20', 'train', '')

In [None]:
# Task specification shared by every dataset loaded below.
robot, obj, obst, subtask = 'IIWA', 'Dumbbell', 'ObjectDoor', 'Trashcan'

# Environment built with task-ID observations enabled. Only its modality layout is used
# (passed to remove_indicator_vectors when the agent datasets are loaded).
representative_indicators_env = composuite.make(
    robot, obj, obst, subtask,
    use_task_id_obs=True,
    ignore_done=False,
)
modality_dims = representative_indicators_env.modality_dims

In [None]:
def _load_agent_dataset(dataset_type):
    """Load one agent dataset for the current task as transitions, with indicator vectors removed."""
    dataset = load_single_composuite_dataset(base_path=base_agent_data_path,
                                             dataset_type=dataset_type,
                                             robot=robot, obj=obj,
                                             obst=obst, task=subtask)
    dataset = transitions_dataset(dataset)
    dataset, _ = remove_indicator_vectors(modality_dims, dataset)
    return dataset


# Same three-step pipeline for each quality level (previously copy-pasted three times).
medium_agent_dataset = _load_agent_dataset('medium')
warmstart_agent_dataset = _load_agent_dataset('warmstart')
expert_agent_dataset = _load_agent_dataset('expert')

print(medium_agent_dataset['observations'].shape,
      warmstart_agent_dataset['observations'].shape,
      expert_agent_dataset['observations'].shape)

In [None]:
# Pool all real agent data; row order is medium, then expert, then warmstart.
agent_dataset = combine_datasets(
    combine_datasets(medium_agent_dataset, expert_agent_dataset),
    warmstart_agent_dataset,
)

In [None]:
# Synthetic (diffusion-generated) transitions for the same robot / object / obstacle / subtask.
synthetic_dataset = load_single_synthetic_dataset(
    base_path=base_synthetic_data_path,
    robot=robot,
    obj=obj,
    obst=obst,
    task=subtask,
)

In [None]:
# Observation columns that are integer-valued or constant in the real agent data.
integer_dims, constant_dims = identify_special_dimensions(agent_dataset['observations'])
print(f'Integer dimensions: {integer_dims}')
print(f'Constant dimensions: {constant_dims}')

In [None]:
# Round the synthetic observation columns found above (returns a rounded copy).
synthetic_dataset = process_special_dimensions(synthetic_dataset,
                                               integer_dims=integer_dims,
                                               constant_dims=constant_dims)

In [None]:
# Real vs. synthetic observation array shapes.
print(*(ds['observations'].shape for ds in (agent_dataset, synthetic_dataset)))

In [None]:
class Normalizer:
    """Per-feature z-score normalizer with statistics computed over axis 0 of ``data``.

    Features whose standard deviation is below ``eps`` (e.g. constant columns) get a
    scale of 1.0 instead of ``eps``. Dividing by roughly 1e-8 turns any tiny deviation
    in such a column into a value of order 1e6, and such deviations occur, for
    example, when synthetic values are rounded to one decimal. Those values would
    swamp every downstream model input and likelihood. Other features keep
    ``std + eps`` as before.
    """

    def __init__(self, data, eps=1e-8):
        self.mean = np.mean(data, axis=0)
        std = np.std(data, axis=0)
        self.std = np.where(std < eps, 1.0, std + eps)

    def normalize(self, data):
        """Map raw values to z-scores: (data - mean) / std."""
        return (data - self.mean) / self.std

    def denormalize(self, normalized_data):
        """Inverse of ``normalize``: normalized_data * std + mean."""
        return normalized_data * self.std + self.mean


def prepare_train_data(agent_dataset, split_ratio=0.9, batch_size=512, seed=None):
    """Normalize the agent dataset and split it into train/validation DataLoaders.

    The normalizers are fitted on the full agent dataset before the split.
    Next observations share the observation normalizer, so model targets live in
    the same space as model inputs.

    Args:
        agent_dataset: dict with 'observations', 'actions' and 'next_observations' arrays.
        split_ratio: fraction of transitions assigned to the training split.
        batch_size: batch size for both loaders (default 512, as before).
        seed: optional int. When given, the random train/val split is reproducible
            and independent of the global torch RNG state. ``None`` keeps the
            previous behavior of drawing from the global generator.

    Returns:
        (train_loader, val_loader, obs_normalizer, act_normalizer). The train loader
        shuffles; the validation loader does not.
    """
    obs_normalizer = Normalizer(agent_dataset['observations'])
    act_normalizer = Normalizer(agent_dataset['actions'])

    norm_obs = obs_normalizer.normalize(agent_dataset['observations'])
    norm_acts = act_normalizer.normalize(agent_dataset['actions'])
    norm_next_obs = obs_normalizer.normalize(agent_dataset['next_observations'])

    obs = torch.tensor(norm_obs, dtype=torch.float32)
    acts = torch.tensor(norm_acts, dtype=torch.float32)
    next_obs = torch.tensor(norm_next_obs, dtype=torch.float32)

    dataset = TensorDataset(obs, acts, next_obs)
    train_size = int(split_ratio * len(dataset))
    val_size = len(dataset) - train_size
    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
    train_data, val_data = random_split(dataset, [train_size, val_size], **split_kwargs)

    return (DataLoader(train_data, batch_size=batch_size, shuffle=True),
            DataLoader(val_data, batch_size=batch_size, shuffle=False),
            obs_normalizer, act_normalizer)


def prepare_test_data(synthetic_dataset, obs_normalizer, act_normalizer):
    """Wrap a transition dataset in an unshuffled DataLoader (batch size 512).

    Observations and next observations are normalized with ``obs_normalizer``, and
    actions with ``act_normalizer``. In this notebook these are the normalizers fitted
    on the agent data. All arrays are then converted to float32 tensors.
    """
    def _as_tensor(array):
        return torch.tensor(array, dtype=torch.float32)

    tensors = (
        _as_tensor(obs_normalizer.normalize(synthetic_dataset['observations'])),
        _as_tensor(act_normalizer.normalize(synthetic_dataset['actions'])),
        _as_tensor(obs_normalizer.normalize(synthetic_dataset['next_observations'])),
    )
    return DataLoader(TensorDataset(*tensors), batch_size=512, shuffle=False)


class ProbabilisticDynamicsModel(nn.Module):
    """Gaussian dynamics model.

    Predicts a per-dimension mean and log-variance of the next observation from a
    normalized (observation, action) pair. The log-variance is bounded as
    ``log_var_scale * tanh(raw)``. ``log_var_scale`` is a learned per-dimension
    parameter initialized to 1.0.
    """

    def __init__(self, obs_dim, act_dim, obs_normalizer, act_normalizer, dropout_rate=0.1):
        super().__init__()

        # Used only by ``predict``; ``forward`` expects already-normalized inputs.
        self.obs_normalizer = obs_normalizer
        self.act_normalizer = act_normalizer

        self.shared_net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )
        self.mean_head = nn.Linear(256, obs_dim)
        self.log_var_head = nn.Linear(256, obs_dim)
        self.log_var_scale = nn.Parameter(torch.ones(obs_dim) * 1.0)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Orthogonal weights (gain sqrt(2)) and zero bias for every nn.Linear layer."""
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
            module.bias.data.zero_()

    def forward(self, obs, act):
        """Return ``(mean, log_var)`` for normalized ``obs``/``act`` tensors."""
        x = torch.cat([obs, act], dim=-1)
        x = self.shared_net(x)
        mean = self.mean_head(x)
        log_var = self.log_var_scale * torch.tanh(self.log_var_head(x))
        return mean, log_var

    def predict(self, obs, act):
        """Make predictions in the original (unnormalized) observation space.

        Args:
            obs: raw observations (anything ``obs_normalizer.normalize`` accepts).
            act: raw actions.

        Returns:
            (mean, log_var) as numpy arrays in unnormalized units. Since
            x = z * std + mean, Var[x] = Var[z] * std**2, which gives the
            ``2 * log(std)`` shift of the log-variance.
        """
        norm_obs = torch.tensor(self.obs_normalizer.normalize(obs), dtype=torch.float32)
        norm_acts = torch.tensor(self.act_normalizer.normalize(act), dtype=torch.float32)

        # Dropout must be off for deterministic predictions. Switch to eval for this call
        # only, and restore the caller's train/eval mode afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                norm_mean, norm_log_var = self(norm_obs, norm_acts)
        finally:
            self.train(was_training)

        mean = self.obs_normalizer.denormalize(norm_mean.cpu().numpy())
        log_var = norm_log_var.cpu().numpy() + 2 * np.log(self.obs_normalizer.std)

        return mean, log_var
    

def train_model(model, train_loader, val_loader, epochs=20, lr=3e-4, patience=5):
    """Train ``model`` with Gaussian NLL loss and early stopping on validation loss.

    Args:
        model: module mapping ``(obs, acts)`` to ``(mean, log_var)``.
        train_loader / val_loader: yield ``(obs, acts, next_obs)`` batches.
        epochs: maximum number of epochs.
        lr: Adam learning rate (weight decay is fixed at 1e-5).
        patience: stop after this many consecutive epochs without improvement.

    Returns:
        (train_losses, val_losses): per-epoch averages of the batch losses.

    Side effects:
        Trains ``model`` in place. On return it holds the weights of the epoch
        with the lowest validation loss.
    """
    criterion = nn.GaussianNLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    train_losses, val_losses = [], []
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    for epoch in range(epochs):

        model.train()
        train_loss = 0
        for obs, acts, next_obs in train_loader:
            optimizer.zero_grad()
            mean, log_var = model(obs, acts)
            var = torch.exp(log_var)
            loss = criterion(mean, next_obs, var)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)
        train_losses.append(train_loss)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for obs, acts, next_obs in val_loader:
                mean, log_var = model(obs, acts)
                var = torch.exp(log_var)
                loss = criterion(mean, next_obs, var)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        val_losses.append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() tensors share storage with the live parameters. A shallow dict
            # copy would keep changing as the optimizer updates them in place, and the
            # "best" restore would then load the last epoch's weights. Clone a snapshot.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")

        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    # Restore the best model.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return train_losses, val_losses

In [None]:
SEED = 0
# Seed torch before the train/val split, weight init and dropout so Restart & Run All
# reproduces the same model.
torch.manual_seed(SEED)

# Environment without task-ID observations. Its modality layout labels the
# (indicator-stripped) observation dimensions.
env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=False, ignore_done=False)
dim_to_label = create_dimension_labels(env.modality_dims)

# Fail fast if the data does not match the dimensions the model is sized with.
assert agent_dataset['observations'].shape[1] == env.obs_dim, (
    f"agent obs dim {agent_dataset['observations'].shape[1]} != env.obs_dim {env.obs_dim}")
assert synthetic_dataset['observations'].shape[1] == env.obs_dim, (
    f"synthetic obs dim {synthetic_dataset['observations'].shape[1]} != env.obs_dim {env.obs_dim}")
assert agent_dataset['actions'].shape[1] == env.action_dim, (
    f"agent action dim {agent_dataset['actions'].shape[1]} != env.action_dim {env.action_dim}")
assert len(dim_to_label) == env.obs_dim, (
    f"{len(dim_to_label)} dimension labels for env.obs_dim {env.obs_dim}")

train_loader, val_loader, obs_normalizer, act_normalizer = prepare_train_data(agent_dataset)
test_loader = prepare_test_data(synthetic_dataset, obs_normalizer, act_normalizer)

model = ProbabilisticDynamicsModel(
    obs_dim=env.obs_dim,
    act_dim=env.action_dim,
    obs_normalizer=obs_normalizer,
    act_normalizer=act_normalizer
)

train_losses, val_losses = train_model(model, train_loader, val_loader)

In [None]:
# Per-epoch training vs. validation Gaussian NLL.
fig, ax = plt.subplots()
ax.plot(train_losses, label='Train Loss')
ax.plot(val_losses, label='Val Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss Curves')
ax.legend()
plt.show()

In [None]:
# Log-likelihoods under the trained model: real train/val data and synthetic (test) data.
train_likelihoods, val_likelihoods, test_likelihoods = (
    compute_likelihoods(model, loader) for loader in (train_loader, val_loader, test_loader)
)

In [None]:
# Real-data sanity check: train vs. held-out validation log-likelihoods.
# density=True because the sets differ in size (90/10 split), so raw counts are not comparable.
hist_sets = {'Train Set': train_likelihoods, 'Val. Set': val_likelihoods}
pooled = np.concatenate([np.ravel(v) for v in hist_sets.values()])
bins = np.linspace(pooled.min(), pooled.max(), 15)

fig, ax = plt.subplots()
for label, values in hist_sets.items():
    ax.hist(values, bins=bins, alpha=0.5, label=label, density=True)
ax.set_xlabel('Log Likelihood')
ax.set_ylabel('Density')
ax.set_title('Log-Likelihood under the Dynamics Model: Train vs. Validation')
ax.legend()
plt.show()

In [None]:
# Real training data vs. synthetic (test) data.
# density=True because the two sets differ in size, so raw counts are not comparable.
hist_sets = {'Train Set': train_likelihoods, 'Test Set': test_likelihoods}
pooled = np.concatenate([np.ravel(v) for v in hist_sets.values()])
bins = np.linspace(pooled.min(), pooled.max(), 15)

fig, ax = plt.subplots()
for label, values in hist_sets.items():
    ax.hist(values, bins=bins, alpha=0.5, label=label, density=True)
ax.set_xlabel('Log Likelihood')
ax.set_ylabel('Density')
ax.set_title('Log-Likelihood under the Dynamics Model: Train vs. Synthetic (Test)')
ax.legend()
plt.show()

In [None]:
# All three sets on shared bins.
# density=True because the sets differ in size, so raw counts are not comparable.
hist_sets = {'Train Set': train_likelihoods,
             'Val. Set': val_likelihoods,
             'Test Set': test_likelihoods}
pooled = np.concatenate([np.ravel(v) for v in hist_sets.values()])
bins = np.linspace(pooled.min(), pooled.max(), 15)

fig, ax = plt.subplots()
for label, values in hist_sets.items():
    ax.hist(values, bins=bins, alpha=0.5, label=label, density=True)
ax.set_xlabel('Log Likelihood')
ax.set_ylabel('Density')
ax.set_title('Log-Likelihood under the Dynamics Model: Train vs. Validation vs. Synthetic (Test)')
ax.legend()
plt.show()

In [None]:
# Per-dimension NLL matrices (samples x dims) for real training data and synthetic data.
train_dim_nlls, test_dim_nlls = (
    compute_dimensional_likelihoods(model, loader) for loader in (train_loader, test_loader)
)

In [None]:
summary = analyze_dimensional_differences(train_dim_nlls, test_dim_nlls, model, dim_to_label)

print("\nDimensions sorted by likelihood difference (worst to best):")
for row in summary:
    dimension, train_nll, test_nll, difference = (
        row[field] for field in ('dimension', 'train_nll', 'test_nll', 'difference')
    )
    print(f"Dimension {dimension}: "
          f"Train NLL = {train_nll:6.3f}, "
          f"Test NLL = {test_nll:6.3f}, "
          f"Diff = {difference:6.3f}")

In [None]:
worst_dims, diffs = visualize_dimensional_differences(train_dim_nlls, test_dim_nlls, model, dim_to_label)

print("\nTop differences by dimension:")
for label, gap in zip(worst_dims, diffs):
    print(f"Dimension {label}: {gap:6.3f}")

In [None]:
state_diffs = visualize_state_space_differences(train_dim_nlls, test_dim_nlls, model, dim_to_label)

print("\nAnalysis by State Type:")
for state_type, stats in state_diffs.items():
    print(f"\n{state_type}:")
    print(f"  Mean Difference: {stats['mean']:.3f}")
    print(f"  Max Difference: {stats['max']:.3f}")
    # 'worst_dim' is a raw index. Show its label too, matching the plot's text box.
    print(f"  Worst Dimension: {dim_to_label[stats['worst_dim']]} (index {stats['worst_dim']})")