In [26]:
import h5py

# Open the HDF5 file and print its structure to ensure it contains the expected keys
def print_structure(file_path):
    """Print the top two levels of an HDF5 file's hierarchy.

    The file is opened read-only. Each top-level key is printed on its own
    line. For every top-level node that can be descended into, each child is
    printed indented by two spaces, followed by the list of the child's own
    keys when the child is itself a container.

    Args:
        file_path: Path of the HDF5 file to inspect.
    """
    with h5py.File(file_path, 'r') as h5_file:
        print("Keys in the file:")
        for key in h5_file.keys():
            print(key)
            node = h5_file[key]
            # Only container nodes expose keys(); a leaf stored at the top
            # level has nothing below it to list.
            if not hasattr(node, 'keys'):
                continue
            for subkey in node.keys():
                child = node[subkey]
                if hasattr(child, 'keys'):
                    print(f"  {subkey}: {list(child.keys())}")
                else:
                    # Leaf directly under a top-level group: print its name only
                    print(f"  {subkey}")

# Inspect the layout of the training file before building a dataset from it
print_structure(
    file_path='release/data/metaworld/Assembly_frame_stack_1_96x96_end_on_success/train.hdf5'
)


Keys in the file:
data
  demo_0: ['actions', 'dones', 'mode', 'obs', 'rewards', 'states']
  demo_1: ['actions', 'dones', 'mode', 'obs', 'rewards', 'states']
  demo_2: ['actions', 'dones', 'mode', 'obs', 'rewards', 'states']
  demo_3: ['actions', 'dones', 'mode', 'obs', 'rewards', 'states']


In [31]:
import h5py
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from tqdm import tqdm

class SparseDenseDataset(Dataset):
    """In-memory dataset of (image, prop, mode) samples read from an HDF5 file.

    Every group under ``data`` contributes its ``obs/corner2_image``,
    ``obs/prop`` and ``mode`` arrays. The per-group arrays are concatenated
    along axis 0, so a single index addresses one row across all groups.

    Args:
        hdf5_path: Path of the HDF5 file to read.
        transform: Optional callable applied to each image in ``__getitem__``.

    Raises:
        ValueError: If the ``data`` group contains no entries.
    """

    def __init__(self, hdf5_path, transform=None):
        self.transform = transform
        images, props, modes = [], [], []

        # All arrays are copied into memory with [:], so the file handle is
        # released as soon as loading finishes instead of staying open for the
        # lifetime of the dataset.
        with h5py.File(hdf5_path, 'r') as h5_file:
            data_group = h5_file['data']
            for demo_key in data_group.keys():
                demo_group = data_group[demo_key]
                obs_group = demo_group['obs']

                # Move axis 1 to the end: [N, C, H, W] -> [N, H, W, C].
                # NOTE(review): assumes the file stores images channel-first — confirm.
                images.append(obs_group['corner2_image'][:].transpose(0, 2, 3, 1))
                props.append(obs_group['prop'][:])
                modes.append(demo_group['mode'][:])

        if not modes:
            raise ValueError(f"No entries found under 'data' in {hdf5_path}")

        self.images = np.concatenate(images, axis=0)
        self.props = np.concatenate(props, axis=0)
        self.modes = np.concatenate(modes, axis=0)

    def __len__(self):
        """Return the total number of rows across all loaded groups."""
        return len(self.modes)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict with keys 'image', 'prop' and 'mode'."""
        image = self.images[idx]
        prop = self.props[idx]
        mode = self.modes[idx]

        if self.transform:
            # The image was moved to channel-last ([H, W, C]) in __init__
            image = self.transform(image)

        # Standardise each prop vector with its own mean and std (per sample,
        # not dataset-wide statistics).
        centered = prop - np.mean(prop)
        spread = np.std(prop)
        # A constant prop vector has zero std; dividing would yield NaNs, so
        # return the centred (all-zero) vector in that case.
        prop = centered if spread == 0 else centered / spread

        return {'image': image, 'prop': prop, 'mode': mode}


# Transformations for the image data
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((96, 96)),  # Ensure image size if varying
    transforms.ToTensor(),
    # Per-channel mean/std conventionally used with ImageNet-pretrained torchvision models
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = SparseDenseDataset('release/data/metaworld/Assembly_frame_stack_1_96x96_end_on_success/train.hdf5', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


# Split the dataset into training and validation (80% / 20%).
# The split is drawn from a seeded generator so the same samples land in the
# same subset on every Restart & Run All; without it the validation numbers
# are not comparable between runs.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [32]:
class HybridResNet(nn.Module):
    """Two-branch classifier combining ResNet-18 image features with a small MLP over `prop`.

    The image branch is a pre-trained ResNet-18 with its last child module
    removed; the prop branch maps `num_prop_features` inputs to 20 features.
    Both feature vectors are concatenated and fed to a linear layer that
    produces 2 outputs.

    Args:
        num_prop_features: Length of each `prop` vector passed to `forward`.
    """

    def __init__(self, num_prop_features):
        super().__init__()

        # Pre-trained ResNet-18 without its final child (the fc head)
        backbone_layers = list(resnet18(pretrained=True).children())
        self.features = nn.Sequential(*backbone_layers[:-1])

        # Small MLP for the `prop` vector: num_prop_features -> 50 -> 20
        self.prop_processor = nn.Sequential(
            nn.Linear(num_prop_features, 50),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.ReLU(),
        )

        # Joint head: 512 image features + 20 prop features -> 2 outputs
        self.classifier = nn.Linear(512 + 20, 2)

    def forward(self, images, props):
        """Return the classifier outputs for a batch of images and prop vectors."""
        pooled = self.features(images)
        # Collapse everything after the batch axis into one feature vector per sample
        image_vec = pooled.view(pooled.size(0), -1)
        prop_vec = self.prop_processor(props)

        joint = torch.cat((image_vec, prop_vec), dim=1)
        return self.classifier(joint)

# Setup the model
model = HybridResNet(num_prop_features=4)  # Assuming `prop` has 4 features
model = model.to(device)

# Redefine criterion and optimizer as needed
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Defined in this cell so the loop does not rely on a variable left over from
# another cell; 10 matches the "Epoch n/10" progress output recorded below.
num_epochs = 10


def _unpack_batch(batch):
    """Move one batch dict to `device`, casting props to float and modes to long."""
    images = batch['image'].to(device)
    props = batch['prop'].to(device).float()
    modes = batch['mode'].to(device).long()
    return images, props, modes


# Training loop: the model receives both images and props
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for data in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        images, props, modes = _unpack_batch(data)

        # Forward pass with both images and props
        outputs = model(images, props)
        loss = criterion(outputs, modes)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Validation loss (no gradient tracking, eval-mode layers)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in val_loader:
            images, props, modes = _unpack_batch(data)
            outputs = model(images, props)
            val_loss += criterion(outputs, modes).item()

    # Average over batches; max(..., 1) avoids dividing by zero if a loader is empty
    avg_train_loss = running_loss / max(len(train_loader), 1)
    avg_val_loss = val_loss / max(len(val_loader), 1)
    print(f"Training Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}")


print("Training complete")


Epoch 1/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 16.72it/s]


Training Loss: 1.2889820158481597, Validation Loss: 8.393877744674683


Epoch 2/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 16.99it/s]


Training Loss: 0.3566894414834678, Validation Loss: 5.108250737190247


Epoch 3/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 16.62it/s]


Training Loss: 0.09665083140134811, Validation Loss: 1.9458958357572556


Epoch 4/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 17.09it/s]


Training Loss: 0.0103766362182796, Validation Loss: 2.587146535515785


Epoch 5/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 16.98it/s]


Training Loss: 0.036093513434752825, Validation Loss: 0.6805043783970177


Epoch 6/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 17.37it/s]


Training Loss: 0.08175538985524326, Validation Loss: 0.006220915878657252


Epoch 7/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 18.13it/s]


Training Loss: 0.004559734219219536, Validation Loss: 1.3165777698159218


Epoch 8/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 17.96it/s]


Training Loss: 0.04257897436618805, Validation Loss: 0.6592525923624635


Epoch 9/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 17.75it/s]


Training Loss: 0.005700486374553293, Validation Loss: 0.0030521782318828627


Epoch 10/10: 100%|████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 18.17it/s]


Training Loss: 0.0027157488977536557, Validation Loss: 0.0003315820067655295
Training complete


In [36]:
# Persist only the weights (state dict), not the whole module object
torch.save(obj=model.state_dict(), f='now.pth')

In [37]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import h5py
from torchvision import transforms

# Relies on SparseDenseDataset, HybridResNet and `device` from the cells above.
# Rebuild the network, restore the saved weights and switch to eval mode.
model_path = 'now.pth'  # Adjust path as necessary
model = HybridResNet(num_prop_features=4)
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device).eval()

# Same preprocessing pipeline as used for training
preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize((96, 96)),  # Ensure image size if varying
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

# Test split, iterated in file order (no shuffling)
dataset = SparseDenseDataset(
    'release/data/metaworld/Assembly_frame_stack_1_96x96_end_on_success/test.hdf5',
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=False)

# Function to calculate accuracy
def calculate_accuracy(dataloader, model, device=None):
    """Return the percentage of samples whose predicted class equals `mode`.

    Args:
        dataloader: Iterable of batch dicts with 'image', 'prop' and 'mode' tensors.
        model: Module called as ``model(images, props)`` returning per-class scores.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if it has no parameters).

    Returns:
        Accuracy in percent (0-100). Returns 0.0 when `dataloader` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Evaluate in eval mode, then put the model back in whatever mode the
    # caller left it in.
    was_training = model.training
    model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data in dataloader:
                images = data['image'].to(device)
                props = data['prop'].to(device).float()
                modes = data['mode'].to(device).long()

                # Predicted class = index of the highest score along the class axis
                predicted = model(images, props).argmax(dim=1)
                total += modes.size(0)
                correct += (predicted == modes).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        return 0.0
    return 100 * correct / total

# Score the reloaded model on the test loader and report the percentage
accuracy = calculate_accuracy(dataloader=dataloader, model=model)
print(f'Accuracy of the model on the test dataset: {accuracy:.2f}%')


Accuracy of the model on the test dataset: 100.00%
