# In [None]:
# data_utils.py

import os
from pathlib import Path
from typing import Tuple, Dict, List

import albumentations as A
import numpy as np
import pandas as pd
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class ArtDataset(Dataset):
    """Dataset of art images labelled by style, driven by a metadata CSV.

    The CSV is expected to contain a ``filename`` column (image path relative
    to ``data_dir``) and a ``style`` column (the class label). Class indices
    are assigned in order of first appearance of each style in the CSV.

    Both albumentations transforms (called with ``image=<ndarray>`` and
    returning a dict) and torchvision-style callables (called positionally
    with the PIL image) are supported.
    """

    def __init__(
        self,
        data_dir: str,
        metadata_file: str,
        transform=None,
        mode: str = 'train'
    ):
        """
        Args:
            data_dir (str): Directory with all the images
            metadata_file (str): Path to the metadata CSV file
            transform: Optional transform to be applied on images; either an
                albumentations transform/pipeline or a callable taking a
                PIL image
            mode (str): 'train', 'val', or 'test'. Stored on the instance
                only; it does not filter the metadata rows.
        """
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.mode = mode

        # Read metadata
        self.metadata = pd.read_csv(metadata_file)

        # Map each style to an integer index, in order of first appearance.
        self.class_to_idx = {
            cls: idx for idx, cls in enumerate(self.metadata['style'].unique())
        }

    def __len__(self) -> int:
        """Return the number of rows in the metadata CSV."""
        return len(self.metadata)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load the image at position ``idx`` and return it with its label.

        Args:
            idx (int): Positional row index into the metadata

        Returns:
            Tuple of (image, class index). The image is whatever the
            transform produces, or the RGB PIL image if no transform is set.
        """
        row = self.metadata.iloc[idx]
        img_path = self.data_dir / row['filename']
        label = self.class_to_idx[row['style']]

        # convert() returns a fully loaded copy, so the underlying file can
        # be closed as soon as the with-block exits.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')

        return self._apply_transform(image), label

    def _apply_transform(self, image):
        """Apply ``self.transform`` using the calling convention it expects."""
        if self.transform is None:
            return image

        # Checked at call time so that reassigning ``self.transform`` after
        # construction keeps working.
        if isinstance(self.transform, (A.BasicTransform, A.BaseCompose)):
            # albumentations requires named ndarray inputs and returns a dict.
            return self.transform(image=np.array(image))['image']

        # torchvision-style callables take the PIL image positionally.
        return self.transform(image)

def get_data_transforms(img_size: int = 224) -> Dict[str, A.Compose]:
    """
    Get train and validation transforms.

    Both pipelines are albumentations ``Compose`` objects, so they must be
    invoked with a named NumPy image and return a dict:
    ``pipeline(image=array)['image']``.

    Args:
        img_size (int): Side length of the square output images

    Returns:
        Dict with keys 'train' and 'val' mapping to the train and
        validation pipelines
    """
    # Normalization statistics shared by both pipelines.
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    train_transform = A.Compose([
        # RandomCrop fails on images smaller than the crop window; pad those
        # up first. Images that are already large enough pass through as-is.
        A.PadIfNeeded(min_height=img_size, min_width=img_size),
        A.RandomCrop(height=img_size, width=img_size),
        A.Rotate(limit=10),
        A.RandomBrightnessContrast(
            brightness_limit=0.1,
            contrast_limit=0.1
        ),
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2()
    ])

    # Validation is deterministic: resize, normalize, convert to tensor.
    val_transform = A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2()
    ])

    return {'train': train_transform, 'val': val_transform}

def create_data_loaders(
    data_dir: str,
    metadata_file: str,
    batch_size: int = 32,
    num_workers: int = 4,
    img_size: int = 224
) -> Dict[str, DataLoader]:
    """
    Create train and validation data loaders.

    NOTE(review): both datasets are built from the same ``metadata_file``
    and no row filtering by mode happens here, so the 'train' and 'val'
    loaders iterate over the same samples. Confirm whether a separate
    split is intended.

    Args:
        data_dir (str): Directory containing images
        metadata_file (str): Path to metadata CSV
        batch_size (int): Batch size used by both loaders
        num_workers (int): Number of workers for data loading
        img_size (int): Size to resize images to

    Returns:
        Dict with keys 'train' and 'val' mapping to DataLoaders; only the
        'train' loader shuffles
    """
    modes = ('train', 'val')

    # Named so it does not shadow the module-level ``transforms`` import.
    transform_by_mode = get_data_transforms(img_size)

    # Pinned host memory only benefits host-to-GPU copies, so enable it
    # only when CUDA is actually available.
    pin_memory = torch.cuda.is_available()

    datasets = {
        mode: ArtDataset(
            data_dir=data_dir,
            metadata_file=metadata_file,
            transform=transform_by_mode[mode],
            mode=mode
        )
        for mode in modes
    }

    return {
        mode: DataLoader(
            datasets[mode],
            batch_size=batch_size,
            shuffle=(mode == 'train'),
            num_workers=num_workers,
            pin_memory=pin_memory
        )
        for mode in modes
    }

def get_class_weights(metadata_file: str) -> torch.Tensor:
    """
    Calculate class weights for imbalanced datasets.

    Each class receives ``total / (num_classes * class_count)``. Weights are
    ordered by first appearance of each style in the CSV, so position ``i``
    corresponds to the class index produced by
    ``enumerate(metadata['style'].unique())``.

    Args:
        metadata_file (str): Path to metadata CSV with a 'style' column

    Returns:
        torch.Tensor: float32 class weights for a weighted loss function,
        one entry per distinct style
    """
    styles = pd.read_csv(metadata_file)['style']
    total = len(styles)

    # value_counts() sorts by descending frequency, which does not match the
    # first-appearance class indexing; realign the counts to unique() order.
    class_order = styles.unique()
    per_class = (
        styles.value_counts().reindex(class_order).to_numpy(dtype=float)
    )

    weights = total / (len(class_order) * per_class)

    # float32 matches the default dtype of model outputs fed to the loss.
    return torch.tensor(weights, dtype=torch.float32)

# Usage example:
if __name__ == "__main__":
    # Demo: build both loaders from placeholder paths and inspect one batch.
    loaders = create_data_loaders(
        data_dir="path/to/images",
        metadata_file="path/to/metadata.csv",
        batch_size=32,
    )

    # Fetch only the first training batch; print nothing if there is none.
    first_batch = next(iter(loaders['train']), None)
    if first_batch is not None:
        images, labels = first_batch
        print(f"Batch shape: {images.shape}")
        print(f"Labels shape: {labels.shape}")