In [1]:
from diffusion.utils import *
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
import torch.nn as nn
import torch.optim as optim

base_composuite_data_path = '/Users/shubhankar/Developer/compositional-rl-synth-data/data'
base_synthetic_data_path = '/Users/shubhankar/Developer/compositional-rl-synth-data/results'

def process_data(transitions_datasets):

    combined_dict = defaultdict(list)
    which_dataset = []

    for idx, data in enumerate(transitions_datasets):
        for key in data.keys():
            combined_dict[key].append(data[key])
        which_dataset.extend([idx] * len(data[key]))

    combined_transitions_datasets = {key: np.concatenate(values, axis=0) for key, values in combined_dict.items()}
    processed_data = make_inputs(combined_transitions_datasets)

    return processed_data, np.array(which_dataset)

In [2]:
dataset_type = 'expert-iiwa-offline-comp-data'

robots = ['IIWA']
# objs = ['Box']
objs = ['Box', 'Dumbbell', 'Hollowbox', 'Plate']
obsts = ['None']
tasks = ['Push']

In [3]:
composuite_datasets = load_multiple_composuite_datasets(base_path=base_composuite_data_path,
                                                        dataset_type=dataset_type,
                                                        robots=robots, objs=objs, 
                                                        obsts=obsts, tasks=tasks)

composuite_transitions_datasets = [transitions_dataset(dataset) for dataset in composuite_datasets]
flattened_composuite_data, _ = process_data(composuite_transitions_datasets)

Loading data: 100%|██████████| 4/4 [00:07<00:00,  1.77s/it]


In [4]:
synthetic_datasets = load_multiple_synthetic_datasets(base_path=base_synthetic_data_path,
                                                      robots=robots, objs=objs, 
                                                      obsts=obsts, tasks=tasks)
flattened_synthetic_data, _ = process_data(synthetic_datasets)

Loading data: 100%|██████████| 4/4 [00:00<00:00, 285.00it/s]


In [5]:
print(flattened_composuite_data.shape, flattened_synthetic_data.shape)

(3999996, 196) (400000, 196)


In [6]:
random_indices = np.random.choice(flattened_composuite_data.shape[0], 25000, replace=False)
sampled_composuite_data = flattened_composuite_data[random_indices]

random_indices = np.random.choice(flattened_synthetic_data.shape[0], 25000, replace=False)
sampled_synthetic_data = flattened_synthetic_data[random_indices]

In [7]:
print(sampled_composuite_data.shape, sampled_synthetic_data.shape)

(25000, 196) (25000, 196)


In [8]:
class SimpleMLP(nn.Module):
    def __init__(self, input_dim):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x
    
# Training loop
def train_model(model, train_loader, criterion, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        for inputs, labels in train_loader:
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")

# Evaluation loop
def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            predicted = (outputs > 0.5).float()
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    print(f"Accuracy: {correct / total:.4f}")

In [9]:
# sampled_composuite_data = np.random.rand(1000, 196)
# sampled_synthetic_data = np.random.rand(1000, 196)

In [10]:
composuite_labels = np.zeros((sampled_composuite_data.shape[0], 1))
synthetic_labels = np.ones((sampled_synthetic_data.shape[0], 1))

data = np.vstack((sampled_composuite_data, sampled_synthetic_data)).astype(np.float32)
labels = np.vstack((composuite_labels, synthetic_labels)).astype(np.float32)

data_tensor = torch.tensor(data)
labels_tensor = torch.tensor(labels)

dataset = TensorDataset(data_tensor, labels_tensor)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

In [11]:
input_dim = data.shape[1]
model = SimpleMLP(input_dim)

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

train_model(model, train_loader, criterion, optimizer, epochs=10)
evaluate_model(model, test_loader)

Epoch 1/10, Loss: 0.3018
Epoch 2/10, Loss: 0.1135
Epoch 3/10, Loss: 0.0141
Epoch 4/10, Loss: 0.0061
Epoch 5/10, Loss: 0.0043
Epoch 6/10, Loss: 0.0340
Epoch 7/10, Loss: 0.0142
Epoch 8/10, Loss: 0.0006
Epoch 9/10, Loss: 0.0017
Epoch 10/10, Loss: 0.0139
Accuracy: 0.9907
