# PyTorch Dataset Classes with Kaggle Integration

This notebook implements PyTorch dataset classes that automatically download datasets from Kaggle and provide proper `__getitem__` functionality for both detection and classification tasks.

## Install Required Packages

In [None]:
!pip install kagglehub torch torchvision albumentations opencv-python pillow numpy pandas

## Import Required Libraries

In [None]:
import os
import tempfile
import shutil
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import xml.etree.ElementTree as ET

import kagglehub
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import pandas as pd
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

## BoundingBox Data Structure

In [None]:
@dataclass
class BoundingBox:
    """Axis-aligned box described by two corners, a class label and a confidence score.

    Fields are stored exactly as given; no validation or clamping is performed.
    """
    x_min: float
    y_min: float
    x_max: float
    y_max: float
    label: str
    confidence: float = 1.0

    def to_dict(self) -> Dict:
        """Return the box as a plain dictionary keyed by field name."""
        field_names = ('x_min', 'y_min', 'x_max', 'y_max', 'label', 'confidence')
        return {field_name: getattr(self, field_name) for field_name in field_names}

    def area(self) -> float:
        """Return width times height of the box."""
        width = self.x_max - self.x_min
        height = self.y_max - self.y_min
        return width * height

    def to_albumentations_format(self) -> List[float]:
        """Return the corner coordinates as the list [x_min, y_min, x_max, y_max]."""
        corners = (self.x_min, self.y_min, self.x_max, self.y_max)
        return list(corners)

## Kaggle Dataset Manager

In [None]:
class KaggleDatasetManager:
    """Downloads Kaggle datasets through ``kagglehub`` and keeps them in a local cache directory.

    Each dataset is stored in a sub-directory of ``cache_dir`` named after its
    handle with ``/`` replaced by ``_``.
    """

    # Prefix of the scratch directories used while a fresh download is being
    # moved into the cache. Entries with this prefix are never reported as datasets.
    _STAGING_PREFIX = ".staging-"

    def __init__(self, cache_dir: Optional[str] = None):
        """Create the manager and make sure the cache directory exists.

        Args:
            cache_dir: Directory that holds the cached datasets. When omitted,
                ``kaggle_datasets_cache`` inside the system temp directory is used.
        """
        if cache_dir is None:
            self.cache_dir = Path(tempfile.gettempdir()) / "kaggle_datasets_cache"
        else:
            self.cache_dir = Path(cache_dir)

        self.cache_dir.mkdir(parents=True, exist_ok=True)
        print(f"Cache directory: {self.cache_dir}")

    def download_dataset(self, dataset_handle: str, force_download: bool = False) -> Path:
        """Return the cached path of a dataset, downloading it first when necessary.

        The download is first moved into a scratch directory inside the cache
        and only then renamed to its final name. An existing cache entry is
        therefore replaced only once the new copy is complete, and an
        interrupted move can never leave a partial directory that a later call
        would mistake for a valid cache entry.

        Args:
            dataset_handle: Kaggle dataset handle, e.g. ``"owner/dataset"``.
            force_download: Replace the cached copy even if one exists.

        Returns:
            Path of the dataset directory inside the cache.

        Raises:
            Exception: Any error raised by the download or by the file
                operations is printed and re-raised unchanged.
        """
        dataset_name = dataset_handle.replace('/', '_')
        cached_path = self.cache_dir / dataset_name

        if cached_path.exists() and not force_download:
            print(f"Using cached dataset: {cached_path}")
            return cached_path

        print(f"Downloading dataset: {dataset_handle}")
        staging_dir = None
        try:
            # kagglehub downloads into its own location and returns that path.
            temp_path = kagglehub.dataset_download(dataset_handle)

            # The cache directory may have been removed since __init__ ran.
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            staging_dir = Path(tempfile.mkdtemp(prefix=self._STAGING_PREFIX,
                                                dir=str(self.cache_dir)))
            staged_path = staging_dir / dataset_name
            shutil.move(str(temp_path), str(staged_path))

            # The new copy is complete; only now drop the old one and publish
            # the replacement with a single rename on the same filesystem.
            if cached_path.exists():
                shutil.rmtree(cached_path)
            staged_path.rename(cached_path)

            print(f"Dataset downloaded and cached: {cached_path}")
            return cached_path

        except Exception as e:
            print(f"Error downloading dataset {dataset_handle}: {e}")
            raise
        finally:
            # Remove the scratch directory whether or not the move succeeded.
            if staging_dir is not None:
                shutil.rmtree(staging_dir, ignore_errors=True)

    def clear_cache(self):
        """Delete every cached dataset and recreate an empty cache directory."""
        if self.cache_dir.exists():
            shutil.rmtree(self.cache_dir)
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            print("Cache cleared")

    def list_cached_datasets(self) -> List[str]:
        """Return the directory names of all cached datasets (scratch directories excluded)."""
        if not self.cache_dir.exists():
            return []
        return [
            entry.name
            for entry in self.cache_dir.iterdir()
            if entry.is_dir() and not entry.name.startswith(self._STAGING_PREFIX)
        ]

## Data Transforms

In [None]:
class DetectionTransforms:
    """Factory for the Albumentations pipelines used with bounding-box data."""

    @staticmethod
    def get_train_transforms(image_size: Tuple[int, int] = (416, 416)):
        """Build the augmenting pipeline for training.

        Args:
            image_size: Target size as (height, width).

        Returns:
            An ``A.Compose`` that resizes, randomly flips and colour-jitters,
            normalizes and converts to a tensor. Boxes are expected in
            ``pascal_voc`` format with labels passed as ``class_labels``;
            ``min_visibility`` is set to 0.3.
        """
        target_height = image_size[0]
        target_width = image_size[1]

        box_settings = A.BboxParams(
            format='pascal_voc',
            label_fields=['class_labels'],
            min_visibility=0.3
        )
        steps = [
            A.Resize(height=target_height, width=target_width),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.3),
            A.HueSaturationValue(p=0.3),
            # Per-channel mean/std (the commonly used ImageNet statistics).
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
        return A.Compose(steps, bbox_params=box_settings)

    @staticmethod
    def get_val_transforms(image_size: Tuple[int, int] = (416, 416)):
        """Build the deterministic pipeline for validation.

        Args:
            image_size: Target size as (height, width).

        Returns:
            An ``A.Compose`` that only resizes, normalizes and converts to a
            tensor, with boxes in ``pascal_voc`` format and labels passed as
            ``class_labels``.
        """
        target_height = image_size[0]
        target_width = image_size[1]

        box_settings = A.BboxParams(
            format='pascal_voc',
            label_fields=['class_labels']
        )
        steps = [
            A.Resize(height=target_height, width=target_width),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
        return A.Compose(steps, bbox_params=box_settings)


class ClassificationTransforms:
    """Factory for the Albumentations pipelines used with image-level labels."""

    @staticmethod
    def get_train_transforms(image_size: Tuple[int, int] = (224, 224)):
        """Build the augmenting pipeline for training.

        Args:
            image_size: Target size as (height, width).

        Returns:
            An ``A.Compose`` that resizes, applies random flip, colour and
            shift/scale/rotate augmentations, normalizes and converts to a tensor.
        """
        target_height = image_size[0]
        target_width = image_size[1]

        steps = [
            A.Resize(height=target_height, width=target_width),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.3),
            A.HueSaturationValue(p=0.3),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
            # Per-channel mean/std (the commonly used ImageNet statistics).
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
        return A.Compose(steps)

    @staticmethod
    def get_val_transforms(image_size: Tuple[int, int] = (224, 224)):
        """Build the deterministic pipeline for validation.

        Args:
            image_size: Target size as (height, width).

        Returns:
            An ``A.Compose`` that only resizes, normalizes and converts to a tensor.
        """
        target_height = image_size[0]
        target_width = image_size[1]

        steps = [
            A.Resize(height=target_height, width=target_width),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
        return A.Compose(steps)

## Detection Dataset Class

In [None]:
class AndrewMVDPyTorchDataset(Dataset):
    """PyTorch Dataset for the andrewmvd face-mask detection dataset with XML annotations.

    The dataset directory must contain an ``images`` and an ``annotations``
    folder; an image is used only if ``annotations/<image stem>.xml`` exists.
    Images are sorted by path and cut into contiguous train/val/test slices.
    """

    _VALID_SPLITS = ("train", "val", "test")

    def __init__(self,
                 kaggle_handle: str = "andrewmvd/face-mask-detection",
                 split: str = "train",
                 transform=None,
                 train_ratio: float = 0.8,
                 val_ratio: float = 0.1,
                 test_ratio: float = 0.1,
                 cache_dir: Optional[str] = None,
                 force_download: bool = False):
        """
        Initialize the dataset.

        Args:
            kaggle_handle: Kaggle dataset handle
            split: Dataset split ('train', 'val', 'test')
            transform: Albumentations transform pipeline; it is called with
                ``image``, ``bboxes`` and ``class_labels`` keyword arguments
            train_ratio: Fraction of the samples assigned to the training split
            val_ratio: Fraction of the samples assigned to the validation split
            test_ratio: Accepted for interface compatibility but not used:
                the test split always receives whatever remains after the
                train and validation slices
            cache_dir: Directory to cache downloaded datasets
            force_download: Force re-download even if cached

        Raises:
            ValueError: If ``split`` is not a known split name, the ratios are
                negative or ``train_ratio + val_ratio`` exceeds 1, or no image
                has a matching XML annotation.
            FileNotFoundError: If the images or annotations directory is missing.
        """
        # Fail fast on bad arguments, before any download work is done.
        self._validate_split_args(split, train_ratio, val_ratio)

        self.split = split
        self.transform = transform

        # Download dataset
        self.manager = KaggleDatasetManager(cache_dir)
        self.dataset_path = self.manager.download_dataset(kaggle_handle, force_download)

        # Find images and annotations
        self.images_dir = self.dataset_path / "images"
        self.annotations_dir = self.dataset_path / "annotations"

        if not self.images_dir.exists():
            raise FileNotFoundError(f"Images directory not found: {self.images_dir}")
        if not self.annotations_dir.exists():
            raise FileNotFoundError(f"Annotations directory not found: {self.annotations_dir}")

        # Keep only images that have a corresponding XML annotation file.
        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}
        annotated_images = [
            img_file
            for img_file in self.images_dir.iterdir()
            if img_file.suffix.lower() in image_extensions
            and (self.annotations_dir / f"{img_file.stem}.xml").exists()
        ]

        if not annotated_images:
            raise ValueError("No images with corresponding XML annotations found")

        # Sort for consistent ordering, then take the slice of the requested split.
        annotated_images.sort()
        self.image_files = self._slice_for_split(annotated_images, split, train_ratio, val_ratio)

        print(f"Loaded {len(self.image_files)} images for {split} split")

        # Create label mapping
        self.label_to_idx = {'with_mask': 0, 'without_mask': 1, 'mask_weared_incorrect': 2}
        self.idx_to_label = {v: k for k, v in self.label_to_idx.items()}

    @classmethod
    def _validate_split_args(cls, split: str, train_ratio: float, val_ratio: float) -> None:
        """Raise ValueError for an unknown split name or unusable split ratios."""
        if split not in cls._VALID_SPLITS:
            raise ValueError(f"Invalid split: {split}. Must be 'train', 'val', or 'test'")
        # The small tolerance absorbs float rounding in sums such as 0.7 + 0.3.
        if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
            raise ValueError(
                "train_ratio and val_ratio must be non-negative and sum to at most 1 "
                f"(got train_ratio={train_ratio}, val_ratio={val_ratio})"
            )

    @staticmethod
    def _slice_for_split(items: List, split: str, train_ratio: float, val_ratio: float) -> List:
        """Return the contiguous slice of ``items`` that belongs to ``split``.

        The first ``int(n * train_ratio)`` items are train, the next
        ``int(n * val_ratio)`` are val, and everything after that is test.
        """
        total_samples = len(items)
        train_end = int(total_samples * train_ratio)
        val_end = train_end + int(total_samples * val_ratio)

        if split == "train":
            return items[:train_end]
        if split == "val":
            return items[train_end:val_end]
        if split == "test":
            return items[val_end:]
        raise ValueError(f"Invalid split: {split}. Must be 'train', 'val', or 'test'")

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.image_files)

    def _parse_xml_annotation(self, xml_path: Path) -> "List[BoundingBox]":
        """Parse an XML annotation file and return its bounding boxes.

        Every ``object`` element under the root must contain a ``name`` element
        (the class label) and a ``bndbox`` element with ``xmin``, ``ymin``,
        ``xmax`` and ``ymax`` children.
        """
        tree = ET.parse(xml_path)
        root = tree.getroot()

        bboxes = []
        for obj in root.findall('object'):
            name = obj.find('name').text
            bbox_elem = obj.find('bndbox')

            bboxes.append(BoundingBox(
                x_min=float(bbox_elem.find('xmin').text),
                y_min=float(bbox_elem.find('ymin').text),
                x_max=float(bbox_elem.find('xmax').text),
                y_max=float(bbox_elem.find('ymax').text),
                label=name
            ))

        return bboxes

    def _encode_labels(self, class_labels: List[str]) -> List[int]:
        """Map label names to class indices.

        Raises:
            ValueError: If a label is not in ``label_to_idx``. Unknown labels
                are rejected instead of being silently counted as class 0.
        """
        unknown = sorted({label for label in class_labels if label not in self.label_to_idx}, key=str)
        if unknown:
            raise ValueError(
                f"Unknown class label(s) {unknown}; expected one of {sorted(self.label_to_idx)}"
            )
        return [self.label_to_idx[label] for label in class_labels]

    def __getitem__(self, idx: int) -> Dict:
        """Load one image with its boxes and labels.

        Negative indices count from the end, as with a list.

        Returns:
            Dict with keys ``image`` (the transform's output, or a float CHW
            tensor scaled to [0, 1] when no transform is set), ``bboxes``
            (float32 tensor of shape (N, 4) in [x_min, y_min, x_max, y_max]
            order), ``labels`` (int64 tensor of shape (N,)), ``image_path``
            and ``image_id`` (the non-negative sample index).

        Raises:
            IndexError: If ``idx`` is out of range.
            OSError: If the image file cannot be decoded.
            ValueError: If an annotation carries an unknown class label.
        """
        num_samples = len(self.image_files)
        requested_idx = idx
        if idx < 0:
            idx += num_samples
        if idx < 0 or idx >= num_samples:
            raise IndexError(f"Index {requested_idx} out of range for dataset of size {num_samples}")

        # Load image; cv2.imread signals failure by returning None, not by raising.
        img_path = self.image_files[idx]
        image = cv2.imread(str(img_path))
        if image is None:
            raise OSError(f"Failed to read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Load annotations
        xml_path = self.annotations_dir / f"{img_path.stem}.xml"
        bboxes = self._parse_xml_annotation(xml_path)

        # Prepare data for transforms
        bbox_coords = [bbox.to_albumentations_format() for bbox in bboxes]
        class_labels = [bbox.label for bbox in bboxes]

        # Apply transforms
        if self.transform:
            transformed = self.transform(
                image=image,
                bboxes=bbox_coords,
                class_labels=class_labels
            )
            image = transformed['image']
            bbox_coords = transformed['bboxes']
            class_labels = transformed['class_labels']
        else:
            # Convert HWC uint8 to a CHW float tensor if no transforms
            image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0

        # Convert labels to indices
        class_indices = self._encode_labels(class_labels)

        # len() rather than truthiness so array-like transform outputs work too.
        if len(bbox_coords) > 0:
            bboxes_tensor = torch.tensor(bbox_coords, dtype=torch.float32)
        else:
            bboxes_tensor = torch.empty((0, 4), dtype=torch.float32)

        if class_indices:
            labels_tensor = torch.tensor(class_indices, dtype=torch.long)
        else:
            labels_tensor = torch.empty((0,), dtype=torch.long)

        return {
            'image': image,
            'bboxes': bboxes_tensor,
            'labels': labels_tensor,
            'image_path': str(img_path),
            'image_id': idx
        }

## Classification Dataset Class

In [None]:
class Face12kPyTorchDataset(Dataset):
    """PyTorch Dataset for the Face12k face-mask classification dataset.

    The class of an image is the name of the directory that contains it.
    Two layouts are understood:

    * ``<root>/<class>/<image>``: class folders directly under the root;
    * ``<root>/<group>/<class>/<image>``: used only when no top-level folder
      holds images directly (e.g. pre-split ``Train``/``Test`` folders, which
      is presumably how the default Kaggle dataset is organized; verify against
      the download). Images of the same class are pooled across the groups.

    The train/val/test split is made per class, so every split keeps the
    class proportions of the full dataset.
    """

    _VALID_SPLITS = ("train", "val", "test")
    _IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp'})

    def __init__(self,
                 kaggle_handle: str = "ashishjangra27/face-mask-12k-images-dataset",
                 split: str = "train",
                 transform=None,
                 train_ratio: float = 0.8,
                 val_ratio: float = 0.1,
                 test_ratio: float = 0.1,
                 cache_dir: Optional[str] = None,
                 force_download: bool = False):
        """
        Initialize the dataset.

        Args:
            kaggle_handle: Kaggle dataset handle
            split: Dataset split ('train', 'val', 'test')
            transform: Albumentations transform pipeline; it is called with
                the ``image`` keyword argument
            train_ratio: Fraction of each class assigned to the training split
            val_ratio: Fraction of each class assigned to the validation split
            test_ratio: Accepted for interface compatibility but not used:
                the test split always receives whatever remains after the
                train and validation slices
            cache_dir: Directory to cache downloaded datasets
            force_download: Force re-download even if cached

        Raises:
            ValueError: If ``split`` is not a known split name, the ratios are
                negative or ``train_ratio + val_ratio`` exceeds 1, or no image
                is found.
            FileNotFoundError: If the dataset root has no sub-directories.
        """
        # Fail fast on bad arguments, before any download work is done.
        self._validate_split_args(split, train_ratio, val_ratio)

        self.split = split
        self.transform = transform

        # Download dataset
        self.manager = KaggleDatasetManager(cache_dir)
        self.dataset_path = self.manager.download_dataset(kaggle_handle, force_download)

        # Find the main dataset directory; fall back to the dataset path itself.
        main_dir = self.dataset_path
        for dir_name in ('Face Mask Dataset', 'dataset', 'data'):
            potential_path = self.dataset_path / dir_name
            if potential_path.exists():
                main_dir = potential_path
                break

        images_by_class = self._discover_images(main_dir)
        if not any(images_by_class.values()):
            raise ValueError("No images found in the dataset")

        self.classes = sorted(images_by_class)
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        # Slice every class separately. Slicing the combined, path-sorted list
        # would put only the last class(es) into the val and test splits.
        self.samples = []
        for class_name in self.classes:
            class_idx = self.class_to_idx[class_name]
            class_files = self._slice_for_split(
                images_by_class[class_name], split, train_ratio, val_ratio
            )
            self.samples.extend((img_file, class_idx, class_name) for img_file in class_files)

        # Sort for consistent ordering
        self.samples.sort(key=lambda x: str(x[0]))

        print(f"Loaded {len(self.samples)} images for {split} split")
        print(f"Classes: {self.classes}")

    @classmethod
    def _validate_split_args(cls, split: str, train_ratio: float, val_ratio: float) -> None:
        """Raise ValueError for an unknown split name or unusable split ratios."""
        if split not in cls._VALID_SPLITS:
            raise ValueError(f"Invalid split: {split}. Must be 'train', 'val', or 'test'")
        # The small tolerance absorbs float rounding in sums such as 0.7 + 0.3.
        if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
            raise ValueError(
                "train_ratio and val_ratio must be non-negative and sum to at most 1 "
                f"(got train_ratio={train_ratio}, val_ratio={val_ratio})"
            )

    @staticmethod
    def _slice_for_split(items: List, split: str, train_ratio: float, val_ratio: float) -> List:
        """Return the contiguous slice of ``items`` that belongs to ``split``.

        The first ``int(n * train_ratio)`` items are train, the next
        ``int(n * val_ratio)`` are val, and everything after that is test.
        """
        total_samples = len(items)
        train_end = int(total_samples * train_ratio)
        val_end = train_end + int(total_samples * val_ratio)

        if split == "train":
            return items[:train_end]
        if split == "val":
            return items[train_end:val_end]
        if split == "test":
            return items[val_end:]
        raise ValueError(f"Invalid split: {split}. Must be 'train', 'val', or 'test'")

    @classmethod
    def _list_images(cls, directory: Path) -> List[Path]:
        """Return the image files directly inside ``directory``, sorted by path."""
        return sorted(
            (entry for entry in directory.iterdir()
             if entry.is_file() and entry.suffix.lower() in cls._IMAGE_EXTENSIONS),
            key=str,
        )

    @classmethod
    def _discover_images(cls, main_dir: Path) -> Dict[str, List[Path]]:
        """Map each class name to its sorted image paths.

        Top-level folders are treated as classes when at least one of them
        holds images directly. Otherwise the folders one level further down
        are treated as classes and pooled by name across the top-level folders.

        Raises:
            FileNotFoundError: If ``main_dir`` has no sub-directories.
        """
        top_dirs = [entry for entry in main_dir.iterdir() if entry.is_dir()]
        if not top_dirs:
            raise FileNotFoundError(f"No class directories found in {main_dir}")

        images_by_class = {top_dir.name: cls._list_images(top_dir) for top_dir in top_dirs}
        if any(images_by_class.values()):
            return images_by_class

        # No top-level folder holds images directly: look one level deeper.
        pooled: Dict[str, List[Path]] = {}
        for group_dir in top_dirs:
            for class_dir in group_dir.iterdir():
                if class_dir.is_dir():
                    pooled.setdefault(class_dir.name, []).extend(cls._list_images(class_dir))
        for class_files in pooled.values():
            class_files.sort(key=str)
        return pooled

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict:
        """Load one image with its class label.

        Negative indices count from the end, as with a list.

        Returns:
            Dict with keys ``image`` (the transform's output, or a float CHW
            tensor scaled to [0, 1] when no transform is set), ``label``
            (int64 scalar tensor), ``class_name``, ``image_path`` and
            ``image_id`` (the non-negative sample index).

        Raises:
            IndexError: If ``idx`` is out of range.
            OSError: If the image file cannot be decoded.
        """
        num_samples = len(self.samples)
        requested_idx = idx
        if idx < 0:
            idx += num_samples
        if idx < 0 or idx >= num_samples:
            raise IndexError(f"Index {requested_idx} out of range for dataset of size {num_samples}")

        img_path, class_idx, class_name = self.samples[idx]

        # Load image; cv2.imread signals failure by returning None, not by raising.
        image = cv2.imread(str(img_path))
        if image is None:
            raise OSError(f"Failed to read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Apply transforms
        if self.transform:
            transformed = self.transform(image=image)
            image = transformed['image']
        else:
            # Convert HWC uint8 to a CHW float tensor if no transforms
            image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0

        return {
            'image': image,
            'label': torch.tensor(class_idx, dtype=torch.long),
            'class_name': class_name,
            'image_path': str(img_path),
            'image_id': idx
        }

## Test the Dataset Classes

In [None]:
# Test Detection Dataset
# Smoke test: build the train and val splits, fetch one sample from each and
# check the invariants the training code relies on before reporting success.
print("Testing Detection Dataset...")
try:
    # Create transforms
    train_transforms = DetectionTransforms.get_train_transforms()
    val_transforms = DetectionTransforms.get_val_transforms()

    # Create dataset
    detection_dataset = AndrewMVDPyTorchDataset(
        split="train",
        transform=train_transforms,
        train_ratio=0.8,
        val_ratio=0.1,
        test_ratio=0.1
    )

    print(f"Detection dataset size: {len(detection_dataset)}")
    assert len(detection_dataset) > 0, "train split is empty"

    # Test __getitem__
    sample = detection_dataset[0]
    print(f"Sample keys: {sample.keys()}")
    print(f"Image shape: {sample['image'].shape}")
    print(f"Number of bboxes: {len(sample['bboxes'])}")
    print(f"Number of labels: {len(sample['labels'])}")
    print(f"Labels: {sample['labels']}")

    # 416x416 is the default image_size of DetectionTransforms.
    assert tuple(sample['image'].shape) == (3, 416, 416), "unexpected train image shape"
    assert sample['bboxes'].shape[-1] == 4, "bboxes must have 4 coordinates each"
    assert len(sample['bboxes']) == len(sample['labels']), "bboxes and labels differ in length"

    # Exercise the validation pipeline as well.
    detection_val_dataset = AndrewMVDPyTorchDataset(
        split="val",
        transform=val_transforms,
        train_ratio=0.8,
        val_ratio=0.1,
        test_ratio=0.1
    )
    assert len(detection_val_dataset) > 0, "val split is empty"
    val_sample = detection_val_dataset[0]
    assert tuple(val_sample['image'].shape) == (3, 416, 416), "unexpected val image shape"
    assert len(val_sample['bboxes']) == len(val_sample['labels']), "val bboxes and labels differ in length"

    print("Detection dataset test PASSED!")

except Exception as e:
    print(f"Detection dataset test FAILED: {e}")

In [None]:
# Test Classification Dataset
# Smoke test: build the train and val splits, fetch one sample from each and
# check the invariants the training code relies on before reporting success.
print("\nTesting Classification Dataset...")
try:
    # Create transforms
    train_transforms = ClassificationTransforms.get_train_transforms()
    val_transforms = ClassificationTransforms.get_val_transforms()

    # Create dataset
    classification_dataset = Face12kPyTorchDataset(
        split="train",
        transform=train_transforms,
        train_ratio=0.8,
        val_ratio=0.1,
        test_ratio=0.1
    )

    print(f"Classification dataset size: {len(classification_dataset)}")
    assert len(classification_dataset) > 0, "train split is empty"

    # Test __getitem__
    sample = classification_dataset[0]
    print(f"Sample keys: {sample.keys()}")
    print(f"Image shape: {sample['image'].shape}")
    print(f"Label: {sample['label']}")
    print(f"Class name: {sample['class_name']}")

    # 224x224 is the default image_size of ClassificationTransforms.
    assert tuple(sample['image'].shape) == (3, 224, 224), "unexpected train image shape"
    label_idx = int(sample['label'])
    assert 0 <= label_idx < len(classification_dataset.classes), "label index out of range"
    assert classification_dataset.classes[label_idx] == sample['class_name'], "label and class name disagree"

    # Exercise the validation pipeline as well.
    classification_val_dataset = Face12kPyTorchDataset(
        split="val",
        transform=val_transforms,
        train_ratio=0.8,
        val_ratio=0.1,
        test_ratio=0.1
    )
    assert len(classification_val_dataset) > 0, "val split is empty"
    val_sample = classification_val_dataset[0]
    assert tuple(val_sample['image'].shape) == (3, 224, 224), "unexpected val image shape"

    print("Classification dataset test PASSED!")

except Exception as e:
    print(f"Classification dataset test FAILED: {e}")

## Usage Examples

In [None]:
# Example: Create DataLoader for training
from torch.utils.data import DataLoader


def detection_collate_fn(batch):
    """Collate detection samples into one batch dictionary.

    Images are stacked into a single tensor, which requires every image in the
    batch to have the same shape (the detection transforms resize to a fixed
    size). Boxes and labels stay as per-image lists because the number of
    boxes differs from image to image.

    Defined as a named top-level function rather than a lambda so that it can
    be pickled when DataLoader worker processes are started with "spawn".
    """
    return {
        'images': torch.stack([item['image'] for item in batch]),
        'bboxes': [item['bboxes'] for item in batch],
        'labels': [item['labels'] for item in batch],
        'image_paths': [item['image_path'] for item in batch],
        'image_ids': [item['image_id'] for item in batch]
    }


# Detection dataset with DataLoader
detection_train_dataset = AndrewMVDPyTorchDataset(
    split="train",
    transform=DetectionTransforms.get_train_transforms()
)

detection_train_loader = DataLoader(
    detection_train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    collate_fn=detection_collate_fn
)

print(f"Detection DataLoader created with {len(detection_train_loader)} batches")

# Classification dataset with DataLoader
# The default collate function is sufficient here: every sample has the same structure.
classification_train_dataset = Face12kPyTorchDataset(
    split="train",
    transform=ClassificationTransforms.get_train_transforms()
)

classification_train_loader = DataLoader(
    classification_train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2
)

print(f"Classification DataLoader created with {len(classification_train_loader)} batches")

## Summary

This notebook provides:

1. **KaggleDatasetManager**: Automatically downloads Kaggle datasets and caches them in a local directory (a folder inside the system temp directory by default)
2. **AndrewMVDPyTorchDataset**: PyTorch dataset class for object detection with XML annotations
3. **Face12kPyTorchDataset**: PyTorch dataset class for image classification
4. **Transform classes**: Albumentations-based data augmentation for both tasks
5. **BoundingBox dataclass**: Structured representation for detection annotations

### Key Features:
- ✅ Proper `__getitem__` methods that return formatted data points
- ✅ Automatic Kaggle dataset downloading into a local cache directory (inside the system temp directory by default)
- ✅ Intelligent caching system to avoid re-downloading
- ✅ Train/validation/test splits with configurable ratios
- ✅ Data augmentation with Albumentations
- ✅ Compatible with PyTorch DataLoader
- ✅ Self-contained implementation (no external imports from main codebase)

### Usage:
1. Run all cells in order
2. The datasets will be automatically downloaded from Kaggle
3. Use the dataset classes in your training loops
4. Datasets are cached locally to avoid re-downloading