In [1]:
import os
from collections import defaultdict

import composuite
# NOTE: the wildcard import stays ahead of the explicit imports below so that,
# exactly as before, explicitly imported names win over anything it re-exports.
from diffusion.utils import *
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split

# Dataset locations. The defaults are the original local checkout paths; set
# COMPOSUITE_DATA_PATH / SYNTHETIC_DATA_PATH in the environment to run this
# notebook on another machine without editing the cell.
base_composuite_data_path = os.environ.get(
    'COMPOSUITE_DATA_PATH',
    '/Users/shubhankar/Developer/compositional-rl-synth-data/data',
)
base_synthetic_data_path = os.environ.get(
    'SYNTHETIC_DATA_PATH',
    '/Users/shubhankar/Developer/compositional-rl-synth-data/results',
)

def process_data(transitions_datasets):
    """Merge several transition datasets and convert them to model inputs.

    Args:
        transitions_datasets: Iterable of dict-like datasets. Each maps a
            field name to an array-like; arrays stored under the same field
            are concatenated along axis 0 in the order the datasets are given.

    Returns:
        A tuple ``(processed_data, which_dataset)``. ``processed_data`` is
        whatever ``make_inputs`` returns for the merged dict. ``which_dataset``
        is a 1-D array that holds, for every row a dataset contributed, the
        position of that dataset in ``transitions_datasets``.
    """
    combined_dict = defaultdict(list)
    which_dataset = []

    for idx, data in enumerate(transitions_datasets):
        keys = list(data.keys())
        for key in keys:
            combined_dict[key].append(data[key])
        # Read the row count from this dataset's own last field instead of a
        # loop variable that may be unset or left over from a previous
        # dataset. A dataset without any field contributes no rows.
        n_rows = len(data[keys[-1]]) if keys else 0
        which_dataset.extend([idx] * n_rows)

    combined_transitions_datasets = {
        key: np.concatenate(values, axis=0) for key, values in combined_dict.items()
    }
    processed_data = make_inputs(combined_transitions_datasets)

    return processed_data, np.array(which_dataset)

def remove_indicator_vectors(data, env):
    """Drop the identifier columns from both observations of each transition.

    Each row of ``data`` is read as the column-wise concatenation
    ``[observation | action | reward | next_observation | ... | terminal]``,
    where the observation widths and action width come from ``env.obs_dim``
    and ``env.action_dim``, the reward is one column, and the terminal flag is
    the last column.

    The columns removed from each observation start at the summed widths of
    the 'object-state', 'obstacle-state' and 'goal-state' modalities and span
    the summed widths of the 'object_id', 'robot_id', 'obstacle_id' and
    'subtask_id' modalities, as listed in ``env.modality_dims``.
    NOTE(review): this assumes the identifier entries sit directly after those
    three state modalities in the observation layout; it is not checked here.

    Args:
        data: 2-D array of flattened transitions.
        env: Object exposing ``obs_dim``, ``action_dim`` and ``modality_dims``
            (a mapping from modality name to a shape whose first entry is the
            modality width).

    Returns:
        A new 2-D array with the same rows and the identifier columns removed
        from the observation and next-observation segments.
    """
    obs_dim, action_dim = env.obs_dim, env.action_dim
    modality_dims = env.modality_dims

    state_keys = ('object-state', 'obstacle-state', 'goal-state')
    id_keys = ('object_id', 'robot_id', 'obstacle_id', 'subtask_id')

    drop_from = sum(shape[0] for name, shape in modality_dims.items() if name in state_keys)
    drop_to = drop_from + sum(shape[0] for name, shape in modality_dims.items() if name in id_keys)

    def strip_indicators(obs_block):
        # Keep every column except the [drop_from, drop_to) range.
        keep = np.ones(obs_block.shape[1], dtype=bool)
        keep[drop_from:drop_to] = False
        return obs_block[:, keep]

    # Column offsets of the individual segments inside a flattened transition.
    action_start = obs_dim
    reward_start = action_start + action_dim
    next_obs_start = reward_start + 1
    next_obs_end = next_obs_start + obs_dim

    segments = [
        strip_indicators(data[:, :obs_dim]),
        data[:, action_start:reward_start],
        data[:, reward_start:next_obs_start],
        strip_indicators(data[:, next_obs_start:next_obs_end]),
        data[:, -1:],
    ]
    return np.concatenate(segments, axis=1)

In [2]:
# Dataset variant passed to the CompoSuite loader.
dataset_type = "expert"

# Task axes. Each list is passed to the dataset loaders, and the first entry
# of each is used to build the environment.
robots = ["IIWA"]
objs = ["Box"]
obsts = ["None"]
tasks = ["Push"]

In [None]:
# Load the CompoSuite datasets for the configured robots/objects/obstacles/tasks.
composuite_datasets = load_multiple_composuite_datasets(
    base_path=base_composuite_data_path,
    dataset_type=dataset_type,
    robots=robots,
    objs=objs,
    obsts=obsts,
    tasks=tasks,
)

# Convert every dataset with transitions_dataset, then merge them into one
# flattened array (the per-row dataset index is not needed here).
composuite_transitions_datasets = list(map(transitions_dataset, composuite_datasets))
flattened_composuite_data, _ = process_data(composuite_transitions_datasets)

In [None]:
# Load the synthetic datasets for the same robots/objects/obstacles/tasks.
synthetic_datasets = load_multiple_synthetic_datasets(
    base_path=base_synthetic_data_path,
    robots=robots,
    objs=objs,
    obsts=obsts,
    tasks=tasks,
)

# These are passed to process_data directly, without a transitions_dataset step.
flattened_synthetic_data, _ = process_data(synthetic_datasets)

In [None]:
# Shapes of the merged CompoSuite and synthetic arrays, in that order.
print(f"{flattened_composuite_data.shape} {flattened_synthetic_data.shape}")

In [None]:
env = composuite.make(robots[0], objs[0], obsts[0], tasks[0], use_task_id_obs=True, ignore_done=False)

# Seeded generator so the same rows are drawn on every Restart & Run All.
SAMPLE_SEED = 0
# Upper bound on the number of rows drawn from each source.
MAX_SAMPLES_PER_SOURCE = 25000
sample_rng = np.random.default_rng(SAMPLE_SEED)

# Rows are drawn without replacement. The sample size is capped at the number
# of available rows because asking for more than exist raises ValueError.
n_composuite_rows = flattened_composuite_data.shape[0]
random_indices = sample_rng.choice(
    n_composuite_rows, min(MAX_SAMPLES_PER_SOURCE, n_composuite_rows), replace=False
)
sampled_composuite_data = flattened_composuite_data[random_indices]
sampled_composuite_data = remove_indicator_vectors(sampled_composuite_data, env)

n_synthetic_rows = flattened_synthetic_data.shape[0]
random_indices = sample_rng.choice(
    n_synthetic_rows, min(MAX_SAMPLES_PER_SOURCE, n_synthetic_rows), replace=False
)
sampled_synthetic_data = flattened_synthetic_data[random_indices]
sampled_synthetic_data = remove_indicator_vectors(sampled_synthetic_data, env)

In [None]:
# Pick one row index and render that row of each sampled array as a 1 x D heat map.
idx = np.random.choice(sampled_composuite_data.shape[0])

# First figure: CompoSuite row; second figure: synthetic row at the same index.
for sampled_rows in (sampled_composuite_data, sampled_synthetic_data):
    vector = sampled_rows[idx, :].reshape(1, -1)
    plt.figure(figsize=(10, 2))
    plt.imshow(vector, cmap="viridis", aspect="auto")
    plt.colorbar(label="Value")
    plt.yticks([])  # a single row has no meaningful y axis
    plt.show()

In [None]:
# Shapes of the subsampled CompoSuite and synthetic arrays, in that order.
print(f"{sampled_composuite_data.shape} {sampled_synthetic_data.shape}")

In [10]:
class SimpleMLP(nn.Module):
    """Three-layer perceptron that maps an input vector to one value in (0, 1).

    Layer widths are ``input_dim -> 128 -> 64 -> 1`` with ReLU after the two
    hidden layers and a sigmoid on the output.
    """

    def __init__(self, input_dim):
        """Create the layers.

        Args:
            input_dim: Number of features in each input vector.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid output for a batch ``x`` with ``input_dim`` features in its last dimension."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return self.sigmoid(logits)
    
# Training loop
def train_model(model, train_loader, criterion, optimizer, epochs=10):
    """Train ``model`` in place and print one loss line per epoch.

    Args:
        model: Module to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        optimizer: Optimizer that updates the parameters of ``model``.
        epochs: Number of passes over ``train_loader``.

    The printed loss is the mean of the per-batch losses of that epoch, not
    the loss of whichever batch happened to come last. If the loader yields
    no batches, the loss is reported as ``nan`` instead of failing.
    """
    model.train()
    for epoch in range(epochs):
        loss_sum = 0.0
        n_batches = 0
        for inputs, labels in train_loader:
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            n_batches += 1

        epoch_loss = loss_sum / n_batches if n_batches else float("nan")
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}")

# Evaluation loop
def evaluate_model(model, test_loader):
    """Print the binary classification accuracy of ``model`` on ``test_loader``.

    Args:
        model: Module whose outputs are thresholded at 0.5; it is switched to
            evaluation mode and run without gradient tracking.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.

    Labels with the same number of elements as the predictions but a different
    shape (for example ``(B,)`` against ``(B, 1)``) are reshaped to match, so
    the comparison stays element-wise instead of broadcasting. If the loader
    yields no samples, the accuracy is reported as ``nan``.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            predicted = (outputs > 0.5).float()
            # Without this, (B, 1) == (B,) broadcasts to (B, B) and over-counts.
            if labels.shape != predicted.shape and labels.numel() == predicted.numel():
                labels = labels.reshape(predicted.shape)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total if total else float("nan")
    print(f"Accuracy: {accuracy:.4f}")

In [11]:
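# NOTE(review): presumably a debugging stub. Uncommenting these lines replaces
# both sampled arrays with uniform random data of shape (1000, 164).
# sampled_composuite_data = np.random.rand(1000, 164)
# sampled_synthetic_data = np.random.rand(1000, 164)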

In [12]:
# Binary targets: 0 for every CompoSuite row, 1 for every synthetic row.
composuite_labels = np.zeros((len(sampled_composuite_data), 1))
synthetic_labels = np.ones((len(sampled_synthetic_data), 1))

# Stack both sources row-wise (CompoSuite first) and cast to float32.
data = np.concatenate([sampled_composuite_data, sampled_synthetic_data], axis=0).astype(np.float32)
labels = np.concatenate([composuite_labels, synthetic_labels], axis=0).astype(np.float32)

data_tensor, labels_tensor = torch.tensor(data), torch.tensor(labels)
dataset = TensorDataset(data_tensor, labels_tensor)

# Random 80/20 train/test split.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Only the training batches are shuffled.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

In [None]:
# Size the classifier to the number of feature columns.
input_dim = data.shape[1]
model = SimpleMLP(input_dim=input_dim)

# Binary cross-entropy on the model's sigmoid outputs, optimised with Adam.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=10,
)
evaluate_model(model=model, test_loader=test_loader)