In [None]:
from torchvision import transforms
from PIL import Image
from torchvision import transforms
from PIL import Image

# Heavy augmentation for all non-Nepal countries.
# Rotation and colour jitter run on every call (with random parameters); the
# horizontal flip and the two RandomApply-wrapped ops each fire with p=0.5.
heavy_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomRotation(12),  # random angle in [-12, 12] degrees
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    # NOTE(review): flipping mirrors the printed text/numerals on a note, a
    # view real photos never show — confirm this augmentation is intended.
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.5),
    # RandomPerspective has its own `p` (default 0.5); nested inside
    # RandomApply(p=0.5) that compounded to an effective 0.25. With p=1.0 the
    # outer RandomApply is the only gate, so perspective fires at 0.5 like
    # the blur above.
    transforms.RandomApply(
        [transforms.RandomPerspective(distortion_scale=0.5, p=1.0)], p=0.5
    ),
    transforms.ToTensor(),
])

# Deterministic preprocessing for Nepal images: no random augmentation at all,
# just a fixed resize to 224x224 followed by conversion to a tensor.
light_aug = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Country folder name -> integer country label. Keys must match the dataset's
# directory names exactly (note the upper-case "USA"). Labels are positional,
# so new countries must be appended to keep existing indices stable.
country_map = {
    folder: label
    for label, folder in enumerate(
        ("nepali", "indian", "bangladesh", "pakistan", "USA", "euro")
    )
}

# Denomination folder name -> integer denomination label. This is one flat
# mapping keyed only by the folder name, so e.g. "100" gets the same label
# regardless of which country it belongs to. Labels are positional: the list
# is not in numeric order, and new denominations must be appended at the end
# so existing label indices stay stable.
denomination_map = {
    folder: label
    for label, folder in enumerate(
        (
            "5", "10", "20", "50", "100", "500", "1000", "2000", "5000",
            "2", "1", "200",
        )
    )
}



In [None]:
import os
from torch.utils.data import Dataset

class CurrencyDataset(Dataset):
    """Banknote images laid out as ``root_dir/<country>/<denomination>/<image>``.

    Each item is ``(image_tensor, country_label, denomination_label)``, where
    the labels are looked up from the folder names via the module-level
    ``country_map`` and ``denomination_map``. A folder name missing from
    those maps raises ``KeyError`` when the sample is fetched.

    Args:
        root_dir: Dataset root directory.
        augment: Keyword-only. When True (default), non-Nepal images go
            through ``heavy_aug`` and Nepal images through ``light_aug``.
            When False, every image goes through ``light_aug`` (no random
            transforms), e.g. for validation/test splits.
    """

    def __init__(self, root_dir, *, augment=True):
        self.augment = augment
        self.samples = []  # list of (image_path, country_folder, denom_folder)
        # Listings are sorted so sample order (and any seeded split over it)
        # is reproducible; raw os.listdir order is filesystem-dependent.
        for country in self._visible_subdirs(root_dir):
            c_path = os.path.join(root_dir, country)
            for denom in self._visible_subdirs(c_path):
                d_path = os.path.join(c_path, denom)
                for img in sorted(os.listdir(d_path)):
                    img_path = os.path.join(d_path, img)
                    # Skip hidden files (e.g. .DS_Store) and stray subfolders.
                    if img.startswith(".") or not os.path.isfile(img_path):
                        continue
                    self.samples.append((img_path, country, denom))

    @staticmethod
    def _visible_subdirs(path):
        """Return the sorted names of non-hidden subdirectories of ``path``.

        Hidden entries (``.ipynb_checkpoints``, ``.DS_Store``, ...) and plain
        files are excluded: descending into a file raises NotADirectoryError,
        and a hidden folder would become a label with no entry in the maps.
        """
        return sorted(
            name
            for name in os.listdir(path)
            if not name.startswith(".")
            and os.path.isdir(os.path.join(path, name))
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, country, denom = self.samples[idx]
        # The context manager closes the file handle promptly; convert()
        # returns a new, fully loaded image that stays valid after close.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        # Heavy augmentation for all except Nepal, only while augmenting.
        if self.augment and country != "nepali":
            img = heavy_aug(img)
        else:
            img = light_aug(img)

        c_label = country_map[country]
        d_label = denomination_map[denom]
        return img, c_label, d_label

