In [13]:
import h5py
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import models

# Define the custom dataset class for handling the HDF5 data
class HDF5Dataset(Dataset):
    def __init__(self, file_path):
        self.file_path = file_path
        self.data = []
        with h5py.File(self.file_path, 'r') as file:
            data_group = file['data']  # Access the 'data' group
            for demo_key in [key for key in data_group.keys() if key.startswith('demo')]:
                images = data_group[demo_key]['obs']['corner2_image'][:]
                props = data_group[demo_key]['obs']['prop'][:]
                labels = data_group[demo_key]['mode'][:]
                # Reformat each image to match the expected input shape of [channels, height, width]
                for img, prop, label in zip(images, props, labels):
                    # Transpose the image to fit the expected order
                    img = img.transpose(1, 2, 0)  # Assuming original shape [height, channel, width]
                    self.data.append((img, prop, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, prop, label = self.data[idx]
        # Normalize and permute the image data
        img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1) / 255.0  # Now in [channels, height, width]
        prop = torch.tensor(prop, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)
        return img, prop, label

# Path to the HDF5 file
file_path = 'release/data/metaworld/Assembly_frame_stack_1_96x96_end_on_success/dataset_mode.hdf5'

# Initialize dataset and dataloader
dataset = HDF5Dataset(file_path)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Modify ResNet-18 to include property data correctly
class CustomResNet(nn.Module):
    def __init__(self, num_props):
        super().__init__()
        # Initialize the base ResNet model with its pretrained weights
        base_model = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(base_model.children())[:-1])  # Exclude the original FC layer
        num_ftrs = base_model.fc.in_features  # The number of features output by the convolutional layers
        self.fc = nn.Sequential(
            nn.Linear(num_ftrs + num_props, 100),  # Combine image features with properties
            nn.ReLU(),
            nn.Linear(100, 2)  # For two class classification
        )

    def forward(self, x, props):
        x = self.features(x)  # Pass the image through the feature extractor
        x = x.view(x.size(0), -1)  # Flatten the features
        x = torch.cat((x, props), dim=1)  # Concatenate the flattened features with the properties
        x = self.fc(x)  # Pass the combined features through the fully connected layers
        return x

# Instantiate the custom model
custom_model = CustomResNet(num_props=4)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(custom_model.parameters(), lr=0.001, momentum=0.9)

# Training function
def train_model(model, criterion, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, props, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(images, props)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        epoch_loss = running_loss / len(dataset)
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}')
    print('Training complete')
    return model

# Train the model
trained_model = train_model(custom_model, criterion, optimizer, num_epochs=10)


  label = torch.tensor(label, dtype=torch.long)


Epoch 1/10, Loss: 0.3378
Epoch 2/10, Loss: 0.2127
Epoch 3/10, Loss: 0.0706
Epoch 4/10, Loss: 0.0266
Epoch 5/10, Loss: 0.1038
Epoch 6/10, Loss: 0.0355
Epoch 7/10, Loss: 0.0081
Epoch 8/10, Loss: 0.0683
Epoch 9/10, Loss: 0.0613
Epoch 10/10, Loss: 0.0709
Training complete


In [7]:
import h5py

# Open the HDF5 file and print its structure to ensure it contains the expected keys
def print_structure(file_path):
    with h5py.File(file_path, 'r') as file:
        print("Keys in the file:")
        for key in file.keys():
            print(key)
            for subkey in file[key].keys():
                print(f"  {subkey}: {list(file[key][subkey].keys())}")

# Call the function with the correct file path
print_structure('release/data/metaworld/Assembly_frame_stack_1_96x96_end_on_success/dataset_mode.hdf5')


Keys in the file:
data
  demo_0: ['actions', 'dones', 'mode', 'obs', 'rewards', 'states']
  demo_1: ['actions', 'dones', 'mode', 'obs', 'rewards', 'states']
  demo_2: ['actions', 'dones', 'mode', 'obs', 'rewards', 'states']
  demo_3: ['actions', 'dones', 'mode', 'obs', 'rewards', 'states']
  demo_4: ['actions', 'dones', 'mode', 'obs', 'rewards', 'states']
