# 1. 3D Tensor dataset

In this notebook, we create a custom dataset class named `Random3DTensorDataset` that extends PyTorch's `torch.utils.data.Dataset`. Each sample is a randomly generated 3D tensor (for example, channels × height × width) paired with a random binary label (0 or 1).

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader

class Random3DTensorDataset(Dataset):
    """Map-style dataset of random 3D tensors paired with random binary labels.

    Samples are generated on access rather than stored. With the default
    ``seed=None`` every access draws fresh values from PyTorch's global random
    number generator; with an integer ``seed`` each index maps to a fixed
    sample, so repeated accesses of the same index return identical data.
    """

    def __init__(self, num_samples: int, tensor_shape: tuple = (3, 32, 32), seed=None):
        """
        Args:
            num_samples (int): Number of samples in the dataset. Must be >= 0.
            tensor_shape (tuple): Shape of the 3D tensor (e.g., (channels, height, width)).
            seed (int, optional): Base seed for reproducible samples. When given,
                sample ``i`` is generated from a private generator seeded with
                ``seed + i``. When ``None`` (default), the global RNG is used.

        Raises:
            ValueError: If ``num_samples`` is negative.
        """
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")
        self.num_samples = num_samples
        self.tensor_shape = tensor_shape
        self.seed = seed

    def __len__(self):
        """Return the number of samples the dataset exposes."""
        return self.num_samples

    def __getitem__(self, idx):
        """Generate the sample at position ``idx``.

        Args:
            idx (int): Sample index. Negative values count from the end.

        Returns:
            tuple: ``(tensor, label)`` where ``tensor`` has shape
            ``tensor_shape`` with values drawn by ``torch.rand`` and ``label``
            is a Python int equal to 0 or 1.

        Raises:
            IndexError: If ``idx`` falls outside ``[-num_samples, num_samples)``.
        """
        index = int(idx)
        if index < 0:
            index += self.num_samples
        # Raising IndexError is what terminates plain iteration
        # (`for sample in dataset`, `list(dataset)`), since Dataset relies on
        # the __getitem__ iteration protocol; without it the loop never ends.
        if not 0 <= index < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.num_samples}"
            )

        # A per-index generator makes seeded samples independent of access order.
        generator = None
        if self.seed is not None:
            generator = torch.Generator().manual_seed(self.seed + index)

        # Create a random 3D tensor with the specified shape.
        tensor = torch.rand(self.tensor_shape, generator=generator)
        # Generate a random label: 0 or 1 (upper bound of randint is exclusive).
        label = torch.randint(0, 2, (1,), generator=generator).item()
        return tensor, label

# Example usage:
if __name__ == "__main__":
    # Build 100 samples, each a 3x32x32 tensor paired with a 0/1 label.
    dataset = Random3DTensorDataset(num_samples=100, tensor_shape=(3, 32, 32))

    # Wrap the dataset so it is served in shuffled mini-batches of 10.
    dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

    # Pull only the first batch rather than looping over the whole loader.
    batch_tensors, batch_labels = next(iter(dataloader))
    print("Tensor batch shape:", batch_tensors.shape)  # Expected: [10, 3, 32, 32]
    print("Labels:", batch_labels)                     # Ten labels, each 0 or 1


Tensor batch shape: torch.Size([10, 3, 32, 32])
Labels: tensor([1, 0, 0, 1, 1, 1, 1, 1, 0, 1])
