In [None]:
import composuite
from diffusion.utils import *
from diffusion.utils import *
from collections import defaultdict
import composuite
from sklearn.preprocessing import StandardScaler
import umap
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.preprocessing import StandardScaler


def split_dict(data, split_ratio=0.7):
    """Randomly split every array in ``data`` into a train part and a test part.

    All arrays that share the same length are shuffled with the *same*
    permutation, so row ``i`` of one key still corresponds to row ``i`` of
    every other key after the split (e.g. an observation stays paired with
    its action and next observation). Drawing a separate permutation per key
    would silently break that pairing.

    Args:
        data: Mapping from key to an array-like indexed along its first axis.
        split_ratio: Fraction of rows placed in the train part; the number of
            train rows is ``int(len(values) * split_ratio)``.

    Returns:
        ``(train_dict, test_dict)``: two dicts with the same keys as ``data``
        holding the selected rows as NumPy arrays.
    """
    train_dict, test_dict = {}, {}
    # One permutation per distinct length: equal-length arrays stay row-aligned,
    # while arrays of a different length are still split on their own.
    permutation_by_length = {}
    for key, values in data.items():
        values = np.array(values)  # Ensure it's an array
        num_rows = len(values)
        if num_rows not in permutation_by_length:
            permutation_by_length[num_rows] = np.random.permutation(num_rows)
        indices = permutation_by_length[num_rows]
        split_idx = int(num_rows * split_ratio)
        train_dict[key] = values[indices[:split_idx]]
        test_dict[key] = values[indices[split_idx:]]

    return train_dict, test_dict

In [None]:
# Task specification passed to composuite and to the dataset loaders.
robot, obj, obst, subtask = 'IIWA', 'Dumbbell', 'ObjectDoor', 'Trashcan'

# Environment created with task-id observations enabled; only its
# `modality_dims` attribute is read here.
representative_indicators_env = composuite.make(
    robot, obj, obst, subtask,
    use_task_id_obs=True,
    ignore_done=False,
)
modality_dims = representative_indicators_env.modality_dims


# Expert agent data for this task, converted with `transitions_dataset` and
# then passed through `remove_indicator_vectors`.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
base_agent_data_path = '/Users/shubhankar/Developer/compositional-rl-synth-data/data'
dataset = load_single_composuite_dataset(
    base_path=base_agent_data_path,
    dataset_type='expert',
    robot=robot,
    obj=obj,
    obst=obst,
    task=subtask,
)
agent_dataset, _ = remove_indicator_vectors(modality_dims, transitions_dataset(dataset))

# Synthetic data for the same task.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
base_synthetic_data_path = '/Users/shubhankar/Developer/compositional-rl-synth-data/cluster_results/diffusion/cond_diff_20/train/'
synthetic_dataset = load_single_synthetic_dataset(
    base_path=base_synthetic_data_path,
    robot=robot,
    obj=obj,
    obst=obst,
    task=subtask,
)

print(
    agent_dataset['observations'].shape,
    synthetic_dataset['observations'].shape,
)

In [None]:
# 90% of the agent data goes to the train part, the remainder to the test part.
train_agent_dataset, test_agent_dataset = split_dict(
    agent_dataset,
    split_ratio=0.9,
)

In [None]:
# Observation array shapes: agent train part, agent test part, synthetic data.
print(
    train_agent_dataset['observations'].shape,
    test_agent_dataset['observations'].shape,
    synthetic_dataset['observations'].shape,
)

In [None]:
class DynamicsModel(nn.Module):
    """MLP that maps a (state, action) pair to a mean and a std of size ``obs_dim``.

    A two-layer ReLU trunk feeds two linear heads: one produces the mean, the
    other a log-std that is clamped to ``[-5, 1]`` and exponentiated.

    Args:
        obs_dim: Size of the state vector and of each output head.
        action_dim: Size of the action vector.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim=256):
        super().__init__()
        input_dim = obs_dim + action_dim
        trunk_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.fc = nn.Sequential(*trunk_layers)
        self.mean_head = nn.Linear(hidden_dim, obs_dim)
        self.log_std_head = nn.Linear(hidden_dim, obs_dim)

    def forward(self, state, action):
        """Return ``(mean, std)`` for the concatenated ``state`` and ``action``.

        The two inputs are concatenated along the last dimension.
        """
        features = self.fc(torch.cat((state, action), dim=-1))
        # Clamping the log-std keeps the exponentiated std within [e^-5, e^1];
        # the exponential guarantees it is positive.
        bounded_log_std = torch.clamp(self.log_std_head(features), min=-5, max=1)
        return self.mean_head(features), bounded_log_std.exp()

def load_data(offline_data, batch_size=512, val_ratio=0.25):
    """Standardise the transitions and wrap them in train/validation loaders.

    A separate ``StandardScaler`` is fitted on each of ``'observations'``,
    ``'actions'`` and ``'next_observations'`` (on the full data, before the
    split). The scaled arrays become float32 tensors, are bundled into a
    ``TensorDataset`` and randomly split.

    Args:
        offline_data: Mapping with the three keys named above.
        batch_size: Batch size used by both loaders.
        val_ratio: Fraction of samples held out for validation; the validation
            size is ``int(len(dataset) * val_ratio)``.

    Returns:
        ``(train_loader, val_loader, scalers)`` where ``scalers`` is the tuple
        ``(state_scaler, action_scaler, next_state_scaler)`` of fitted scalers.
        Only the train loader shuffles.
    """
    field_names = ('observations', 'actions', 'next_observations')

    fitted_scalers = []
    standardized = []
    for field in field_names:
        scaler = StandardScaler()
        scaled = scaler.fit_transform(offline_data[field])
        fitted_scalers.append(scaler)
        standardized.append(torch.tensor(scaled, dtype=torch.float32))

    full_set = TensorDataset(*standardized)
    num_total = len(full_set)
    num_val = int(num_total * val_ratio)
    num_train = num_total - num_val
    train_part, val_part = random_split(full_set, [num_train, num_val])

    train_loader = DataLoader(train_part, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_part, batch_size=batch_size, shuffle=False)

    print('Train:', len(train_part), 'Val.:', len(val_part))

    return train_loader, val_loader, tuple(fitted_scalers)

def _gaussian_nll(mean, std, target):
    """Mean Gaussian negative log-likelihood of ``target`` under N(mean, std^2).

    The additive constant ``0.5 * log(2 * pi)`` is dropped because it does not
    affect the gradients.
    """
    z = (target - mean) / std
    return (0.5 * z.pow(2) + torch.log(std)).mean()


def train_dynamics_model(offline_data, obs_dim, action_dim, epochs=20, lr=3e-4):
    """Fit a ``DynamicsModel`` to the transitions in ``offline_data``.

    The model is trained with Adam on the Gaussian negative log-likelihood of
    the standardised next state given the standardised state and action. The
    same loss is evaluated on the validation loader after every epoch.

    Args:
        offline_data: Transition data forwarded to ``load_data``.
        obs_dim: State dimensionality passed to ``DynamicsModel``.
        action_dim: Action dimensionality passed to ``DynamicsModel``.
        epochs: Number of passes over the training loader.
        lr: Adam learning rate.

    Returns:
        ``(model, train_losses, val_losses, scalers)``: the trained model, the
        per-epoch average train and validation losses, and the scalers
        returned by ``load_data``. An average is ``nan`` when the
        corresponding loader has no batches.
    """
    model = DynamicsModel(obs_dim, action_dim)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_loader, val_loader, scalers = load_data(offline_data)
    train_losses, val_losses = [], []

    for epoch in range(epochs):
        total_train_loss = 0.0
        model.train()
        for state, action, next_state in train_loader:
            mean, std = model(state, action)
            loss = _gaussian_nll(mean, std, next_state)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for state, action, next_state in val_loader:
                mean, std = model(state, action)
                total_val_loss += _gaussian_nll(mean, std, next_state).item()

        # An empty loader (e.g. a validation split of size zero) has no batches;
        # report nan rather than dividing by zero.
        num_train_batches, num_val_batches = len(train_loader), len(val_loader)
        avg_train_loss = total_train_loss / num_train_batches if num_train_batches else float('nan')
        avg_val_loss = total_val_loss / num_val_batches if num_val_batches else float('nan')
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    return model, train_losses, val_losses, scalers

def evaluate_data(model, data, scalers):
    """Score each transition in ``data`` by its Gaussian log-likelihood under ``model``.

    Inputs are transformed with the given (already fitted) scalers, the model
    predicts a mean and std for the next state, and the per-dimension Gaussian
    log-density ``-0.5 * z**2 - log(std) - 0.5 * log(2 * pi)`` is averaged
    over dimension 1. Averaging rather than summing keeps ``exp`` of the
    result in a usable numeric range.

    Args:
        model: Callable as ``model(states, actions) -> (mean, std)`` and
            providing ``eval()``.
        data: Mapping with ``'observations'``, ``'actions'`` and
            ``'next_observations'``.
        scalers: Tuple ``(state_scaler, action_scaler, next_state_scaler)`` of
            objects with a ``transform`` method.

    Returns:
        NumPy array with one average log-likelihood per transition.
    """
    import math  # local import: used only for the normalisation constant below

    state_scaler, action_scaler, next_state_scaler = scalers

    states = state_scaler.transform(data['observations'])
    actions = action_scaler.transform(data['actions'])
    next_states = next_state_scaler.transform(data['next_observations'])

    states = torch.tensor(states, dtype=torch.float32)
    actions = torch.tensor(actions, dtype=torch.float32)
    next_states = torch.tensor(next_states, dtype=torch.float32)

    model.eval()
    with torch.no_grad():
        mean, std = model(states, actions)
        z = (next_states - mean) / std
        # Full Gaussian log-density, including the 0.5 factor on the squared
        # residual and the normalisation constant.
        log_density = -0.5 * z.pow(2) - torch.log(std) - 0.5 * math.log(2.0 * math.pi)
        log_likelihood = log_density.mean(dim=1)

    return log_likelihood.numpy()

In [None]:
# Environment created with task-id observations disabled; its obs/action
# dimensions size the dynamics model.
env = composuite.make(
    robot, obj, obst, subtask,
    use_task_id_obs=False,
    ignore_done=False,
)
model, train_losses, val_losses, scalers = train_dynamics_model(
    train_agent_dataset,
    env.obs_dim,
    env.action_dim,
)

In [None]:
# Per-epoch train and validation loss, with epochs numbered from 1.
plt.figure()
plt.plot(range(1, 1 + len(train_losses)), train_losses, label='Train Loss')
plt.plot(range(1, 1 + len(val_losses)), val_losses, label='Validation Loss')
plt.gca().set(
    xlabel='Epochs',
    ylabel='Loss',
    title='Training and Validation Loss',
)
plt.legend()
plt.show()

In [None]:
# Exponentiated scores for the complete agent dataset.
agent_likelihoods = np.exp(evaluate_data(model, agent_dataset, scalers))

plt.hist(agent_likelihoods, bins=30)
plt.gca().set(
    xlabel='Likelihood',
    ylabel='Frequency',
    title='Likelihoods of Agent Transitions',
)
plt.show()

In [None]:
# Exponentiated scores for the agent train part.
train_agent_likelihoods = np.exp(evaluate_data(model, train_agent_dataset, scalers))

plt.hist(train_agent_likelihoods, bins=30)
plt.gca().set(
    xlabel='Likelihood',
    ylabel='Frequency',
    title='Train: Likelihoods of Agent Transitions',
)
plt.show()

In [None]:
# Exponentiated scores for the agent test part.
test_agent_likelihoods = evaluate_data(model, test_agent_dataset, scalers)
test_agent_likelihoods = np.exp(test_agent_likelihoods)

# Plot the test-part scores: the figure is titled "Test", so it must not reuse
# the train-part array.
plt.hist(test_agent_likelihoods, bins=30)
plt.xlabel('Likelihood')
plt.ylabel('Frequency')
plt.title('Test: Likelihoods of Agent Transitions')
plt.show()

In [None]:
# Exponentiated scores for the synthetic dataset.
synthetic_likelihoods = np.exp(evaluate_data(model, synthetic_dataset, scalers))

plt.hist(synthetic_likelihoods, bins=30)
plt.gca().set(
    xlabel='Likelihood',
    ylabel='Frequency',
    title='Likelihoods of Synthetic Transitions',
)
plt.show()

In [None]:
# Mean score per dataset, in order: agent train part, agent test part, synthetic.
print(
    train_agent_likelihoods.mean().item(),
    test_agent_likelihoods.mean().item(),
    synthetic_likelihoods.mean().item(),
)

In [None]:
# Overlay of agent and synthetic scores on shared axes. To compare the agent
# train and test parts separately, histogram `train_agent_likelihoods` and
# `test_agent_likelihoods` in place of `agent_likelihoods`.
plt.hist(agent_likelihoods, bins=30, alpha=0.5, label="Agent Transitions")
plt.hist(synthetic_likelihoods, bins=30, alpha=0.5, label="Synthetic Transitions")

plt.gca().set(
    xlabel="Likelihood",
    ylabel="Frequency",
    title="Likelihoods of Agent and Synthetic Transitions",
)
plt.legend()
plt.show()

In [None]:
def filter_top_likelihood_transitions(synthetic_data, likelihoods, top_percent=1):
    """Keep the rows of ``synthetic_data`` with the highest likelihoods.

    Args:
        synthetic_data: Mapping from key to an array indexable by an integer
            index array along its first axis.
        likelihoods: One score per row.
        top_percent: Percentage of rows to keep. The number kept is
            ``int(len(likelihoods) * top_percent / 100)``, capped at the
            number of rows.

    Returns:
        Dict with the same keys as ``synthetic_data`` holding only the
        selected rows, ordered by ascending likelihood. When the percentage
        rounds down to zero rows, every array in the result is empty.

    Raises:
        ValueError: If ``top_percent`` is negative.
    """
    if top_percent < 0:
        raise ValueError(f"top_percent must be non-negative, got {top_percent}")

    likelihoods = np.asarray(likelihoods)
    total = len(likelihoods)
    num_top = min(int(total * (top_percent / 100)), total)
    # Slice from an explicit start index: `[-num_top:]` would select *every*
    # row when num_top is 0, because `-0` is the same as `0`.
    top_indices = np.argsort(likelihoods)[total - num_top:]
    filtered_data = {key: val[top_indices] for key, val in synthetic_data.items()}

    return filtered_data


In [None]:
# Keep the highest-likelihood synthetic rows, using the function's default percentage.
filtered_synthetic_dataset = filter_top_likelihood_transitions(
    synthetic_dataset,
    synthetic_likelihoods,
)

In [None]:
import pickle

# Output file, relative to the current working directory.
filename = 'filtered_transitions.pkl'

# Serialise the filtered synthetic data with pickle's default protocol.
with open(filename, mode='wb') as f:
    pickle.dump(filtered_synthetic_dataset, file=f)