# In [None]:
# ============================================================
# Synthetic Nepali Braille Detection - YOLOv11 Enhanced Model
# ============================================================
# Based on: Improved YOLOv11 Architecture (C3k2_GBC, ULSAM, SDIoU)
# Dataset: Synthetic Nepali Braille Dataset (47 classes, grayscale)
# ============================================================

import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import pandas as pd
from tqdm import tqdm

# ============================================================
# 1. Dataset Definition
# ============================================================

class BrailleDataset(Dataset):
    """Grayscale Braille image-classification dataset described by a CSV file.

    The CSV at ``root/csv_file`` needs a ``filename`` column (image path,
    relative to ``root``) and a ``latin`` column (class label). Class indices
    come from sorting the unique ``latin`` values, so two instances built from
    the same CSV always share the same label mapping.

    Args:
        root: dataset directory holding the CSV and the images.
        csv_file: CSV file name inside ``root`` (default ``'labels.csv'``).
        transform: optional callable applied to the grayscale PIL image.
    """

    def __init__(self, root, csv_file='labels.csv', transform=None):
        self.root = root
        self.csv_path = os.path.join(root, csv_file)
        self.data = pd.read_csv(self.csv_path)
        self.transform = transform

        self.classes = sorted(self.data['latin'].unique())
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

    def __len__(self):
        return len(self.data)

    def _resolve_path(self, row):
        """Return the on-disk image path for CSV *row*.

        Tries ``root/filename`` first. If that does not exist, falls back to
        ``root/<latin>/<basename>``, where the basename is whatever follows the
        last ``/`` or ``\\`` in ``filename`` (so Windows-style paths in the CSV
        still resolve on other platforms).

        Raises:
            FileNotFoundError: if neither candidate path exists.
        """
        primary = os.path.join(self.root, row['filename'])
        if os.path.exists(primary):
            return primary
        basename = row['filename'].replace('\\', '/').rsplit('/', 1)[-1]
        fallback = os.path.join(self.root, row['latin'], basename)
        if not os.path.exists(fallback):
            raise FileNotFoundError(
                f"Image not found: tried {primary!r} and {fallback!r}"
            )
        return fallback

    def __getitem__(self, idx):
        """Return ``(image, label_index)`` for sample *idx*."""
        row = self.data.iloc[idx]
        # The context manager guarantees the underlying file is closed;
        # convert('L') yields a separate grayscale image we keep.
        with Image.open(self._resolve_path(row)) as img:
            image = img.convert('L')
        if self.transform:
            image = self.transform(image)
        label = self.class_to_idx[row['latin']]
        return image, label


# ============================================================
# 2. Modules: Bottleneck, GBC, C3k2_GBC, ULSAM, SDIoU
# ============================================================

class BottConv(nn.Module):
    """PW–DW–PW bottleneck: 1×1 conv -> 3×3 depthwise conv -> 1×1 conv.

    Every conv is followed by BatchNorm and ReLU. The hidden width is
    ``out_ch // 2``, clamped to at least 1 so narrow outputs (``out_ch == 1``)
    still build a valid depthwise convolution.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # groups=mid_ch must be a positive integer, so never let it reach 0.
        mid_ch = max(1, out_ch // 2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, mid_ch, 3, padding=1, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class GBC(nn.Module):
    """Gated Bottleneck Convolution.

    Runs a ``BottConv`` bottleneck, then rescales its output channel by
    channel with a sigmoid gate computed from globally average-pooled
    features passed through a 1×1 conv. The gate is per-channel only; it has
    no spatial component.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (also the gate width).
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.bottleneck = BottConv(in_ch, out_ch)
        # (B, C, H, W) -> (B, C, 1, 1) weights in (0, 1).
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_ch, out_ch, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        features = self.bottleneck(x)
        return features * self.gate(features)


class C3k2_GBC(nn.Module):
    """Residual pair of GBC blocks (a simplified C3k2-style stage).

    Computes ``GBC(GBC(x)) + shortcut(x)``. The shortcut is a 1×1 conv
    projection when the channel count changes and the identity otherwise.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.main = nn.Sequential(GBC(in_ch, out_ch), GBC(out_ch, out_ch))
        needs_projection = in_ch != out_ch
        self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if needs_projection else nn.Identity()

    def forward(self, x):
        branch = self.main(x)
        return branch + self.shortcut(x)


class ULSAM(nn.Module):
    """Ultra-Lightweight Subspace Attention Module.

    Splits the channels into ``subspaces`` equal groups. For each group,
    a depthwise 3×3 conv -> ReLU -> 3×3 max-pool (stride 1) -> 1×1 conv
    produces a single-channel logit map. A softmax over all H×W positions
    turns that map into a spatial attention map ``m``, and the group output
    is ``s * m + s``. The output has the same shape as the input.

    Args:
        channels: number of input channels; must be a positive multiple of
            ``subspaces``.
        subspaces: number of channel groups (default 4).

    Raises:
        ValueError: if ``channels`` is not a positive multiple of ``subspaces``.
    """
    def __init__(self, channels, subspaces=4):
        super().__init__()
        # torch.chunk would produce groups whose sizes disagree with the
        # per-group convs unless the split is exact.
        if subspaces < 1 or channels < 1 or channels % subspaces:
            raise ValueError(
                f"channels ({channels}) must be a positive multiple of subspaces ({subspaces})"
            )
        self.subspaces = subspaces
        self.split_channels = channels // subspaces
        # The softmax is applied in forward() over the flattened spatial map.
        # It has no parameters, so leaving it out of these Sequentials does not
        # affect state_dict keys.
        self.attentions = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(self.split_channels, self.split_channels, 3, padding=1, groups=self.split_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(3, stride=1, padding=1),
                nn.Conv2d(self.split_channels, 1, 1),
            ) for _ in range(subspaces)
        ])

    def forward(self, x):
        outs = []
        for s, att in zip(torch.chunk(x, self.subspaces, dim=1), self.attentions):
            logits = att(s)  # one logit map per group: (B, 1, H, W)
            # Normalize over every spatial position, not just the last axis.
            att_map = F.softmax(logits.flatten(2), dim=-1).view_as(logits)
            outs.append(s * att_map + s)  # residual connection
        return torch.cat(outs, dim=1)


# ============================================================
# 3. SDIoU Loss Function
# ============================================================

class SDIoULoss(nn.Module):
    """Placeholder for the SDIoU bounding-box regression loss.

    This script trains a classifier only, so no box regression is computed:
    ``forward`` ignores its inputs and returns a scalar zero tensor (with
    ``requires_grad=True``) on ``pred_boxes``' device. ``alpha`` and ``beta``
    are stored but not used by this placeholder.
    """
    def __init__(self, alpha=0.5, beta=0.2):
        super().__init__()
        self.alpha, self.beta = alpha, beta

    def forward(self, pred_boxes, gt_boxes):
        zero = torch.zeros((), requires_grad=True)
        return zero.to(pred_boxes.device)


# ============================================================
# 4. YOLOv11 Braille Model (Backbone + Neck + Head)
# ============================================================

class YOLOv11_Braille(nn.Module):
    """Grayscale Braille-cell classifier assembled from the blocks above.

    Pipeline:
      * backbone: four C3k2_GBC stages (1 -> 32 -> 64 -> 128 -> 256 channels)
        with a 2×2 max-pool between consecutive stages;
      * two stacked ULSAM(256) attention modules;
      * neck: 1×1 conv to 128 channels, 2× upsample, 3×3 conv + ReLU;
      * head: global average pooling, flatten, linear layer to logits.

    Args:
        num_classes: number of output logits (default 47).
    """

    def __init__(self, num_classes=47):
        super().__init__()

        # Backbone (single-channel input); pooling sits between stages only.
        widths = (1, 32, 64, 128, 256)
        stages = []
        for depth, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            if depth > 0:
                stages.append(nn.MaxPool2d(2))
            stages.append(C3k2_GBC(c_in, c_out))
        self.backbone = nn.Sequential(*stages)

        # Attention on the 256-channel feature map.
        self.ulsam1 = ULSAM(256)
        self.ulsam2 = ULSAM(256)

        # Neck (simplified PANet-like): reduce, upsample, refine.
        self.neck = nn.Sequential(
            nn.Conv2d(256, 128, 1),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        # Classification head: pooled features -> class logits.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map a single-channel image batch to ``(B, num_classes)`` logits."""
        for stage in (self.backbone, self.ulsam1, self.ulsam2, self.neck, self.head):
            x = stage(x)
        return x


# ============================================================
# 5. Training Pipeline
# ============================================================

def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimization pass over *loader*; return mean per-batch loss (0.0 if empty)."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    for imgs, labels in tqdm(loader, desc=desc):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(imgs), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0


def _evaluate(model, loader, device):
    """Return top-1 accuracy in percent on *loader*, or None if it yields no samples."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            predicted = model(imgs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else None


def train_model(model, train_loader, val_loader, num_epochs=50, lr=0.001, device='cuda'):
    """Train *model* with cross-entropy + SGD/cosine LR and report validation accuracy.

    Each epoch runs one training pass (printing the mean per-batch loss),
    steps a ``CosineAnnealingLR`` schedule with ``T_max=num_epochs``, and then
    evaluates top-1 accuracy on ``val_loader``.

    Args:
        model: classifier returning ``(B, num_classes)`` logits.
        train_loader: iterable of ``(images, labels)`` batches for training.
        val_loader: iterable of ``(images, labels)`` batches for validation;
            if it yields no samples, accuracy is reported as n/a.
        num_epochs: number of training epochs.
        lr: initial SGD learning rate (momentum 0.937, weight decay 5e-4).
        device: device string the model and batches are moved to.

    Returns:
        The trained model (same object as *model*).
    """
    # Move the model first so the optimizer is built on the on-device parameters.
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.937, weight_decay=0.0005)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device,
            desc=f"Epoch {epoch+1}/{num_epochs}",
        )
        scheduler.step()
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}")

        acc = _evaluate(model, val_loader, device)
        if acc is None:
            print("Validation Accuracy: n/a (empty validation set)")
        else:
            print(f"Validation Accuracy: {acc:.2f}%")

    print("Training Complete ✅")
    return model


# ============================================================
# 6. Main
# ============================================================

if __name__ == "__main__":
    import multiprocessing

    from torch.utils.data import Subset

    ROOT = "synthetic_braille_dataset"
    SPLIT_SEED = 42

    # Training transforms: light random augmentation, then normalize.
    train_tfms = transforms.Compose([
        transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),
        transforms.RandomApply([transforms.ColorJitter(brightness=0.2, contrast=0.2)], p=0.3),
        transforms.RandomRotation(5),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    # Validation transforms: same normalization, no randomness, so accuracy
    # reflects the model rather than the augmentation draw.
    eval_tfms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Two views of the same CSV (same rows, same class mapping) that differ
    # only in transform. Splitting a single dataset object would make the
    # validation subset inherit train_tfms.
    train_base = BrailleDataset(root=ROOT, transform=train_tfms)
    val_base = BrailleDataset(root=ROOT, transform=eval_tfms)

    # Split 90/10 over indices with a fixed seed so the split is reproducible.
    train_size = int(0.9 * len(train_base))
    val_size = len(train_base) - train_size
    train_idx, val_idx = random_split(
        range(len(train_base)), [train_size, val_size],
        generator=torch.Generator().manual_seed(SPLIT_SEED),
    )
    train_dataset = Subset(train_base, train_idx.indices)
    val_dataset = Subset(val_base, val_idx.indices)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Workers started via 'spawn'/'forkserver' (the default on Windows and
    # macOS) must unpickle the dataset class. Classes defined in a notebook's
    # __main__ often cannot be unpickled, and the loader then stalls or fails.
    # Load in-process unless the platform default is 'fork'.
    num_workers = 2 if multiprocessing.get_all_start_methods()[0] == 'fork' else 0
    loader_kwargs = dict(batch_size=32, num_workers=num_workers, pin_memory=(device == 'cuda'))
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Model
    model = YOLOv11_Braille(num_classes=len(train_base.classes))

    # Train
    trained_model = train_model(model, train_loader, val_loader, num_epochs=50, lr=0.001, device=device)

    # Save
    torch.save(trained_model.state_dict(), "braille_yolov11.pth")
    print("✅ Model saved as 'braille_yolov11.pth'")


# [cell output] Epoch 1/50:   0%|          | 0/34 [00:00<?, ?it/s]