In [None]:
from torchvision import transforms
from PIL import Image
from torchvision import transforms
from PIL import Image

# ImageNet channel statistics. The pretrained torchvision ResNet-18 backbone used by
# MultiTaskModel was trained on inputs normalised this way, so both pipelines end with
# Normalize right after ToTensor.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Heavy (randomised) augmentation for all non-Nepal countries.
heavy_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomRotation(12),  # random angle in [-12, 12] degrees
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.5),
    # RandomPerspective has its own `p` (default 0.5). Wrapping it in RandomApply(p=0.5)
    # gated it twice (~25% effective); pass p directly for a true 50%, like the blur above.
    transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Light, deterministic preprocessing (no random transforms) for Nepal only.
light_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Country folder name -> class id for the country head. Keys must match the dataset
# folder names exactly (note the upper-case "USA"); the integer values are the output
# indices the model is trained on, so renumbering invalidates trained weights.
country_map = {
    "nepali": 0,
    "indian": 1,
    "bangladesh": 2,
    "pakistan": 3,
    "USA": 4,
    "euro": 5,
}

# Denomination folder name -> class id for the denomination head. Ids are not in
# numeric order of the face value; they are kept as-is because they define the
# head's output indices.
denomination_map = {
    "5": 0,
    "10": 1,
    "20": 2,
    "50": 3,
    "100": 4,
    "500": 5,
    "1000": 6,
    "2000": 7,
    "5000": 8,
    "2": 9,
    "1": 10,
    "200": 11,
}



In [None]:
import os
from torch.utils.data import Dataset

# NOTE(review): a later cell defines CurrencyDataset again and silently replaces this one;
# keep a single copy of the class.
class CurrencyDataset(Dataset):
    """Banknote images stored as ``root_dir/<country>/<denomination>/<image file>``.

    Each item is ``(image_tensor, country_label, denomination_label)``. The labels are
    looked up in the module-level ``country_map`` / ``denomination_map``, so folder names
    must match those keys exactly (e.g. ``"USA"`` is upper-case).

    Args:
        root_dir: Directory whose sub-folders are country names.
        train: If True (default, the original behaviour), non-Nepal images go through the
            random ``heavy_aug`` and Nepal images through ``light_aug``. If False, every
            image uses the deterministic ``light_aug``. Use False for validation/test sets
            so their metrics are not perturbed by random augmentation.

    Raises:
        ValueError: if a country or denomination folder has no entry in the label maps.
            This is checked at construction so a bad folder fails immediately rather than
            inside a DataLoader worker partway through an epoch.
    """

    def __init__(self, root_dir, train=True):
        self.train = train
        self.samples = []  # (image_path, country_folder, denomination_folder)
        unknown = []
        for country in self._sorted_entries(root_dir):
            c_path = os.path.join(root_dir, country)
            if not os.path.isdir(c_path):
                continue  # stray file next to the country folders
            if country not in country_map:
                unknown.append(c_path)
                continue
            for denom in self._sorted_entries(c_path):
                d_path = os.path.join(c_path, denom)
                if not os.path.isdir(d_path):
                    continue  # stray file next to the denomination folders
                if denom not in denomination_map:
                    unknown.append(d_path)
                    continue
                for img in self._sorted_entries(d_path):
                    img_path = os.path.join(d_path, img)
                    if os.path.isfile(img_path):
                        self.samples.append((img_path, country, denom))
        if unknown:
            raise ValueError(
                "Folders with no entry in country_map/denomination_map: "
                + ", ".join(unknown)
            )

    @staticmethod
    def _sorted_entries(path):
        """Return the non-hidden entries of ``path`` in sorted order.

        ``os.listdir`` order is filesystem-dependent. Sorting makes ``samples`` (and any
        index-based train/val split) reproducible. Hidden entries such as ``.DS_Store`` or
        ``.ipynb_checkpoints`` are skipped.
        """
        return sorted(name for name in os.listdir(path) if not name.startswith("."))

    def __len__(self):
        """Number of image files collected under ``root_dir``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, country_label, denom_label)``."""
        img_path, country, denom = self.samples[idx]
        # The context manager guarantees the file handle is released.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")

        # Random heavy augmentation only while training, and never for Nepal.
        transform = heavy_aug if self.train and country != "nepali" else light_aug
        return transform(img), country_map[country], denomination_map[denom]



In [None]:
import os
from torch.utils.data import Dataset

# NOTE(review): an earlier cell also defines CurrencyDataset; this later definition
# silently replaces it. Keep a single copy of the class.
class CurrencyDataset(Dataset):
    """Banknote images stored as ``root_dir/<country>/<denomination>/<image file>``.

    Each item is ``(image_tensor, country_label, denomination_label)``. The labels are
    looked up in the module-level ``country_map`` / ``denomination_map``, so folder names
    must match those keys exactly (e.g. ``"USA"`` is upper-case).

    Args:
        root_dir: Directory whose sub-folders are country names.
        train: If True (default, the original behaviour), non-Nepal images go through the
            random ``heavy_aug`` and Nepal images through ``light_aug``. If False, every
            image uses the deterministic ``light_aug``. Use False for validation/test sets
            so their metrics are not perturbed by random augmentation.

    Raises:
        ValueError: if a country or denomination folder has no entry in the label maps.
            This is checked at construction so a bad folder fails immediately rather than
            inside a DataLoader worker partway through an epoch.
    """

    def __init__(self, root_dir, train=True):
        self.train = train
        self.samples = []  # (image_path, country_folder, denomination_folder)
        unknown = []
        for country in self._sorted_entries(root_dir):
            c_path = os.path.join(root_dir, country)
            if not os.path.isdir(c_path):
                continue  # stray file next to the country folders
            if country not in country_map:
                unknown.append(c_path)
                continue
            for denom in self._sorted_entries(c_path):
                d_path = os.path.join(c_path, denom)
                if not os.path.isdir(d_path):
                    continue  # stray file next to the denomination folders
                if denom not in denomination_map:
                    unknown.append(d_path)
                    continue
                for img in self._sorted_entries(d_path):
                    img_path = os.path.join(d_path, img)
                    if os.path.isfile(img_path):
                        self.samples.append((img_path, country, denom))
        if unknown:
            raise ValueError(
                "Folders with no entry in country_map/denomination_map: "
                + ", ".join(unknown)
            )

    @staticmethod
    def _sorted_entries(path):
        """Return the non-hidden entries of ``path`` in sorted order.

        ``os.listdir`` order is filesystem-dependent. Sorting makes ``samples`` (and any
        index-based train/val split) reproducible. Hidden entries such as ``.DS_Store`` or
        ``.ipynb_checkpoints`` are skipped.
        """
        return sorted(name for name in os.listdir(path) if not name.startswith("."))

    def __len__(self):
        """Number of image files collected under ``root_dir``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, country_label, denom_label)``."""
        img_path, country, denom = self.samples[idx]
        # The context manager guarantees the file handle is released.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")

        # Random heavy augmentation only while training, and never for Nepal.
        transform = heavy_aug if self.train and country != "nepali" else light_aug
        return transform(img), country_map[country], denomination_map[denom]



In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

device = "cuda" if torch.cuda.is_available() else "cpu"

# Head sizes follow the label maps (6 countries and 12 denominations with the maps above).
num_countries = len(country_map)
num_denoms = len(denomination_map)


class MultiTaskModel(nn.Module):
    """ResNet-18 feature extractor shared by a country head and a denomination head.

    Args:
        pretrained: Initialise the backbone with torchvision's ImageNet weights
            (default True, as before).
        n_countries: Output size of the country head. ``None`` (default) uses the
            module-level ``num_countries`` at construction time.
        n_denoms: Output size of the denomination head. ``None`` (default) uses the
            module-level ``num_denoms`` at construction time.

    ``forward`` returns ``(country_logits, denom_logits)``.
    """

    def __init__(self, pretrained=True, n_countries=None, n_denoms=None):
        super().__init__()
        backbone = self._resnet18(pretrained)
        in_features = backbone.fc.in_features
        # Keep every child except the final fc layer; the two heads below replace it.
        self.features = nn.Sequential(*list(backbone.children())[:-1])
        self.country_head = nn.Linear(
            in_features, num_countries if n_countries is None else n_countries
        )
        self.denom_head = nn.Linear(
            in_features, num_denoms if n_denoms is None else n_denoms
        )

    @staticmethod
    def _resnet18(pretrained):
        """Build torchvision's resnet18, using the ``weights=`` API when available."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the boolean ``pretrained`` flag.
            return models.resnet18(pretrained=pretrained)
        # ``pretrained=`` is deprecated since torchvision 0.13; IMAGENET1K_V1 is the
        # checkpoint ``pretrained=True`` loaded.
        return models.resnet18(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x):
        """Return ``(country_logits, denom_logits)`` for an image batch ``x``."""
        feats = torch.flatten(self.features(x), 1)  # (batch, in_features)
        return self.country_head(feats), self.denom_head(feats)


model = MultiTaskModel().to(device)


In [None]:
import torch.nn.functional as F
import torch.optim as optim

# A single Adam optimiser over every parameter of the multi-task model.
LEARNING_RATE = 1e-4
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

def compute_loss(country_pred, denom_pred, country_label, denom_label,
                 country_weight=1.0, denom_weight=1.0):
    """Weighted sum of the two cross-entropy losses of the multi-task model.

    Args:
        country_pred: Country logits (presumably ``(batch, num_countries)``).
        denom_pred: Denomination logits (presumably ``(batch, num_denoms)``).
        country_label: Integer country class ids for the batch.
        denom_label: Integer denomination class ids for the batch.
        country_weight: Multiplier for the country loss (default 1.0).
        denom_weight: Multiplier for the denomination loss (default 1.0).

    Returns:
        Scalar tensor ``country_weight * CE_country + denom_weight * CE_denom``, where each
        cross-entropy is the batch mean. The defaults reproduce plain equal weighting.
    """
    loss_country = F.cross_entropy(country_pred, country_label)
    loss_denom = F.cross_entropy(denom_pred, denom_label)
    return country_weight * loss_country + denom_weight * loss_denom


In [None]:
num_epochs = 10


def _train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader``; return the mean per-sample loss.

    Returns ``float("nan")`` when the loader yields no samples, instead of dividing by zero.
    """
    model.train()
    loss_sum, seen = 0.0, 0
    for imgs, country_labels, denom_labels in loader:
        imgs = imgs.to(device)
        country_labels = country_labels.to(device)
        denom_labels = denom_labels.to(device)

        optimizer.zero_grad()
        c_pred, d_pred = model(imgs)
        loss = compute_loss(c_pred, d_pred, country_labels, denom_labels)
        loss.backward()
        optimizer.step()

        # F.cross_entropy averages over the batch by default. Weight by batch size so a
        # short final batch does not count as much as a full one.
        batch_size = imgs.size(0)
        loss_sum += loss.item() * batch_size
        seen += batch_size
    return loss_sum / seen if seen else float("nan")


@torch.no_grad()
def _evaluate(model, loader, device):
    """Return ``(country_accuracy, denomination_accuracy)`` of ``model`` on ``loader``.

    Both values are ``float("nan")`` when the loader is empty.
    """
    model.eval()
    correct_country = correct_denom = total = 0
    for imgs, country_labels, denom_labels in loader:
        imgs = imgs.to(device)
        c_pred, d_pred = model(imgs)
        total += imgs.size(0)
        correct_country += (c_pred.argmax(dim=1) == country_labels.to(device)).sum().item()
        correct_denom += (d_pred.argmax(dim=1) == denom_labels.to(device)).sum().item()
    if total == 0:
        return float("nan"), float("nan")
    return correct_country / total, correct_denom / total


# NOTE(review): train_loader and val_loader are not created in any cell shown here, so a
# fresh kernel raises NameError below. Build them (e.g. from CurrencyDataset + DataLoader)
# in a cell above this one.
for epoch in range(num_epochs):
    train_loss = _train_one_epoch(model, train_loader, optimizer, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}")

    val_country_acc, val_denom_acc = _evaluate(model, val_loader, device)
    print(f"Val Accuracy - Country: {val_country_acc:.4f}, Denomination: {val_denom_acc:.4f}")
