In [1]:
import composuite
# The star import stays ahead of the explicit imports below so that those
# explicit names keep taking precedence over anything it re-exports.
from diffusion.utils import *
import os
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Data locations. The defaults are the original machine-specific paths; set the
# environment variables below to run the notebook on another machine without
# editing this cell.
base_agent_data_path = os.environ.get(
    'COMPOSUITE_AGENT_DATA_PATH',
    '/Users/shubhankar/Developer/compositional-rl-synth-data/data',
)
base_synthetic_data_path = os.environ.get(
    'COMPOSUITE_SYNTHETIC_DATA_PATH',
    '/Users/shubhankar/Developer/compositional-rl-synth-data/cluster_results/diffusion/',
)

In [None]:
# --- Experiment selection -------------------------------------------------
dataset_type = 'expert'

run = 20        # index used to build the "cond_diff_{run}" results directory
mode = 'train'  # sub-directory of the diffusion run to read from

# Task description passed to composuite.make and to both dataset loaders.
robot = 'IIWA'
obj = 'Dumbbell'
obst = 'ObjectDoor'
subtask = 'Trashcan'

# Two environments for the same task: one without and one with the task-id
# observation. Only the second one's `modality_dims` is used in this cell.
representative_env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=False, ignore_done=False)
representative_indicators_env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=True, ignore_done=False)
modality_dims = representative_indicators_env.modality_dims

# --- Agent (CompoSuite) data ----------------------------------------------
agent_dataset = load_single_composuite_dataset(base_path=base_agent_data_path,
                                               dataset_type=dataset_type,
                                               robot=robot, obj=obj,
                                               obst=obst, task=subtask)
agent_dataset = transitions_dataset(agent_dataset)
agent_dataset, indicators = remove_indicator_vectors(modality_dims, agent_dataset)
agent_dataset = make_inputs(agent_dataset)

# --- Synthetic (diffusion) data -------------------------------------------
synthetic_data_path = os.path.join(base_synthetic_data_path, f"cond_diff_{run}", mode)
synthetic_dataset = load_single_synthetic_dataset(base_path=synthetic_data_path,
                                                  robot=robot, obj=obj,
                                                  obst=obst, task=subtask)
synthetic_dataset = make_inputs(synthetic_dataset)

# Fail fast: the two datasets are stacked row-wise further down, which only
# works when every row has the same number of features.
if agent_dataset.shape[1] != synthetic_dataset.shape[1]:
    raise ValueError(
        "Feature width mismatch between agent and synthetic data: "
        f"{agent_dataset.shape[1]} vs {synthetic_dataset.shape[1]}"
    )

In [None]:
# Report the shape of the agent dataset followed by the synthetic dataset.
print(*(dataset_array.shape for dataset_array in (agent_dataset, synthetic_dataset)))

In [4]:
# Draw the same number of rows, without replacement, from each data source.
SAMPLE_SEED = 0                  # fixed so that Restart & Run All reproduces the same subsets
MAX_SAMPLES_PER_SOURCE = 25000   # upper bound; clamped below to what is actually available

sample_rng = np.random.default_rng(SAMPLE_SEED)

# Sampling without replacement cannot draw more rows than a dataset holds, and
# equal counts keep the two classes balanced for the classifier.
n_samples = min(MAX_SAMPLES_PER_SOURCE, agent_dataset.shape[0], synthetic_dataset.shape[0])

random_indices = sample_rng.choice(agent_dataset.shape[0], n_samples, replace=False)
sampled_agent_data = agent_dataset[random_indices]

random_indices = sample_rng.choice(synthetic_dataset.shape[0], n_samples, replace=False)
sampled_synthetic_data = synthetic_dataset[random_indices]

In [None]:
def plot_sample_heatmap(sample_row, title):
    """Render a single (1, n_features) row as a one-line heatmap.

    Args:
        sample_row: 2-D array with exactly one row, as produced by
            ``row.reshape(1, -1)``.
        title: Text shown above the figure so the data source is identifiable.
    """
    fig, ax = plt.subplots(figsize=(10, 2))
    image = ax.imshow(sample_row, cmap="viridis", aspect="auto")
    fig.colorbar(image, ax=ax, label="Value")
    ax.set_yticks([])  # there is only one row, so a y axis carries no information
    ax.set_title(title)
    plt.show()


# One random row index that is valid for both samples (they may differ in length).
idx = np.random.choice(min(sampled_agent_data.shape[0], sampled_synthetic_data.shape[0]))

for source_name, source_samples in (("Agent sample", sampled_agent_data),
                                    ("Synthetic sample", sampled_synthetic_data)):
    vector = source_samples[idx, :].reshape(1, -1)
    plot_sample_heatmap(vector, f"{source_name} (row {idx})")

In [None]:
# Report the shape of the agent sample followed by the synthetic sample.
print(*(sample_array.shape for sample_array in (sampled_agent_data, sampled_synthetic_data)))

In [7]:
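# Disabled: uncommenting the two lines below overwrites both samples with
# uniform random noise of the same shapes before the classifier data is built.
# sampled_agent_data = np.random.rand(*sampled_agent_data.shape)
# sampled_synthetic_data = np.random.rand(*sampled_synthetic_data.shape)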

In [None]:
# Shapes of the two samples that feed the classifier cells below.
print(*(sample_array.shape for sample_array in (sampled_agent_data, sampled_synthetic_data)))

In [9]:
class SimpleMLP(nn.Module):
    """Small fully connected binary classifier.

    Architecture:
        Linear(input_dim, h1) -> ReLU -> Linear(h1, h2) -> ReLU
        -> Linear(h2, 1) -> Sigmoid

    The final sigmoid means the output is a probability-like value in [0, 1],
    so the model pairs with ``nn.BCELoss`` (not ``nn.BCEWithLogitsLoss``).

    Args:
        input_dim: Number of features in each input row.
        hidden_dims: Widths ``(h1, h2)`` of the two hidden layers. Defaults to
            ``(128, 64)``, the sizes this model originally hard-coded.

    Raises:
        ValueError: If ``hidden_dims`` does not contain exactly two entries.
    """

    def __init__(self, input_dim, hidden_dims=(128, 64)):
        super().__init__()
        if len(hidden_dims) != 2:
            raise ValueError(
                f"hidden_dims must contain exactly two widths, got {len(hidden_dims)}"
            )
        first_width, second_width = hidden_dims
        # Attribute names fc1/fc2/fc3 are kept so existing state_dict keys stay valid.
        self.fc1 = nn.Linear(input_dim, first_width)
        self.fc2 = nn.Linear(first_width, second_width)
        self.fc3 = nn.Linear(second_width, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a (batch, input_dim) tensor to (batch, 1) sigmoid outputs."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x
    
# Training loop
def train_model(model, train_loader, criterion, optimizer, epochs=10):
    """Train ``model`` in place and report the mean loss of every epoch.

    Args:
        model: Module to optimise; it is put into training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches. It is
            iterated once per epoch.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.

    Returns:
        List with one float per epoch: the mean of that epoch's batch losses.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches; nothing to train on")

        # Report the epoch average rather than whichever mini-batch happened
        # to come last, which is a noisy signal.
        mean_loss = running_loss / num_batches
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {mean_loss:.4f}")

    return epoch_losses

# Evaluation loop
def evaluate_model(model, test_loader):
    """Print and return the accuracy of ``model`` at a 0.5 decision threshold.

    Args:
        model: Module whose outputs are compared against 0.5; it is put into
            evaluation mode.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Fraction of correctly classified samples as a float.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            predicted = (outputs > 0.5).float()
            # Align labels with the predictions: comparing (B, 1) against (B,)
            # would silently broadcast to (B, B) and inflate the count.
            labels = labels.reshape(predicted.shape)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")

    accuracy = correct / total
    print(f"Accuracy: {accuracy:.4f}")
    return accuracy

In [10]:
# Settings for the real-vs-synthetic classification dataset.
SPLIT_SEED = 0        # fixes the train/test partition and the shuffle order
TRAIN_FRACTION = 0.8  # remaining rows form the test set
BATCH_SIZE = 32

# Label convention: 0 = CompoSuite agent row, 1 = synthetic row.
composuite_labels = np.zeros((sampled_agent_data.shape[0], 1))
synthetic_labels = np.ones((sampled_synthetic_data.shape[0], 1))

# Stack both sources row-wise; float32 matches the default dtype of nn.Linear.
data = np.vstack((sampled_agent_data, sampled_synthetic_data)).astype(np.float32)
labels = np.vstack((composuite_labels, synthetic_labels)).astype(np.float32)

data_tensor = torch.tensor(data)
labels_tensor = torch.tensor(labels)

dataset = TensorDataset(data_tensor, labels_tensor)

train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size

# A dedicated, seeded generator makes the split and the batch order repeatable
# without touching the global torch RNG state.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = random_split(dataset, [train_size, test_size],
                                           generator=split_generator)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          generator=split_generator)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [None]:
# Size the classifier from the feature width of the stacked data.
input_dim = data.shape[1]
model = SimpleMLP(input_dim)

# Adam on a binary cross-entropy objective; the model already ends in a sigmoid.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.BCELoss()

# Fit on the training split, then report accuracy on the held-out split.
train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=10,
)
evaluate_model(model=model, test_loader=test_loader);  # ';' stops the notebook echoing any return value