In [None]:
# In dataset.ipynb (code cell)
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
import numpy as np

class SinhalaDataset(Dataset):
    """Dataset of Sinhala character images stored as ``.png`` files.

    Every regular file in ``data_dir`` whose name ends with ``.png`` is one
    sample. The file name without its extension is the character label
    (e.g. ``a.png`` -> ``'a'``). Labels are mapped to consecutive integers
    ``0..N-1`` following the sorted order of the file names.
    """

    def __init__(self, data_dir, transform=None):
        """Scan ``data_dir`` and build the sample list and label map.

        Args:
            data_dir (str): Directory containing .png files named by character (e.g., a.png).
            transform (callable, optional): Transformations to apply to images.

        Raises:
            FileNotFoundError: If ``data_dir`` does not exist.
            ValueError: If ``data_dir`` contains no ``.png`` files.
        """
        super().__init__()
        self.data_dir = data_dir
        self.transform = transform
        # One dict per sample: {'image_path': <str>, 'label': <character label>}
        self.data = []
        # Maps character label -> integer class index
        self.label_map = {}

        if not os.path.exists(data_dir):
            raise FileNotFoundError(f"Data directory not found: {data_dir}")

        # Sorted listing keeps the label -> index assignment deterministic.
        for img_name in sorted(os.listdir(data_dir)):
            if not img_name.endswith('.png'):
                continue
            img_path = os.path.join(data_dir, img_name)
            if not os.path.isfile(img_path):
                continue
            # Extract character label from filename (e.g., 'a.png' -> 'a')
            char_label = os.path.splitext(img_name)[0]
            # Number labels by how many images were accepted so far, not by
            # the position in the raw directory listing: skipped entries
            # (other file types, sub-directories) would otherwise leave gaps
            # and produce class indices >= len(label_map).
            self.label_map.setdefault(char_label, len(self.label_map))
            self.data.append({'image_path': img_path, 'label': char_label})

        if not self.data:
            raise ValueError(f"No valid .png images found in {data_dir}")

    def __len__(self):
        """Return the number of image samples found in ``data_dir``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx (int): Position of the sample in ``self.data``.

        Returns:
            tuple: ``(img, label)`` where ``img`` is a float32 tensor with a
            leading channel axis of size 1 (after ``transform``, if one was
            given) and ``label`` is the integer class index.

        Raises:
            ValueError: If OpenCV cannot read the image file.
        """
        item = self.data[idx]
        # Load image in grayscale
        img = cv2.imread(item['image_path'], cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise ValueError(f"Failed to load image: {item['image_path']}")
        # Convert to float32 and scale by 1/255
        img = img.astype(np.float32) / 255.0
        # Add channel dimension (1, H, W) for PyTorch
        img = np.expand_dims(img, axis=0)
        img = torch.from_numpy(img)

        if self.transform:
            img = self.transform(img)

        # Convert label to integer
        label = self.label_map[item['label']]
        return img, label

def get_dataloaders(data_dir, batch_size=32, train_split=0.8, val_split=0.1, transform=None):
    """Create DataLoaders for training, validation, and testing.

    The dataset indices are shuffled once with a fixed seed (42) and cut into
    three contiguous parts: the first ``train_split`` fraction for training,
    the next ``val_split`` fraction for validation, and the remainder for
    testing. All three loaders share one dataset instance and draw from their
    own part through a ``SubsetRandomSampler``.

    Args:
        data_dir (str): Directory containing the dataset.
        batch_size (int): Number of samples per batch.
        train_split (float): Proportion of data for training.
        val_split (float): Proportion of data for validation.
        transform (callable, optional): Transformations to apply to images.

    Returns:
        tuple: (train_loader, val_loader, test_loader, label_map)

    Raises:
        ValueError: If a split is negative, if the two splits sum to more
            than 1, or if the dataset is empty.
    """
    # Small tolerance so that sums such as 0.7 + 0.3 are not rejected because
    # of binary floating-point representation.
    if train_split < 0 or val_split < 0 or train_split + val_split > 1 + 1e-9:
        raise ValueError(
            "train_split and val_split must be non-negative and sum to at most 1, "
            f"got train_split={train_split}, val_split={val_split}"
        )

    # Initialize dataset
    dataset = SinhalaDataset(data_dir, transform=transform)

    # Get dataset size and indices
    dataset_size = len(dataset)
    if dataset_size == 0:
        raise ValueError(f"Dataset is empty in {data_dir}")
    indices = list(range(dataset_size))
    # A private generator seeded with 42 gives the same permutation as
    # seeding the global NumPy RNG would, without resetting the random state
    # of every other piece of code in the process.
    np.random.RandomState(42).shuffle(indices)

    def _boundary(fraction):
        # Round before flooring: products like (0.7 + 0.1) * 10 evaluate to
        # 7.999..., which a bare floor would turn into 7 instead of 8.
        cut = int(np.floor(np.round(fraction * dataset_size, 9)))
        return min(cut, dataset_size)

    # Calculate split indices
    train_end = _boundary(train_split)
    val_end = _boundary(train_split + val_split)

    # Split indices
    train_indices = indices[:train_end]
    val_indices = indices[train_end:val_end]
    test_indices = indices[val_end:]

    # Create one sampler and one DataLoader per subset
    train_loader, val_loader, test_loader = (
        DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(subset))
        for subset in (train_indices, val_indices, test_indices)
    )

    return train_loader, val_loader, test_loader, dataset.label_map