# 🎯 Train Unified Classifier - Improved with EMNIST, Focal Loss, and RGB Shapes

Extended recognition system:
- **10 digits**: 0-9 (MNIST + EMNIST digits)
- **52 letters**: A-Z, a-z (EMNIST letters)
- **9 geometric shapes**: Circle, Triangle, Square, Pentagon, Hexagon, Heptagon, Octagon, Nonagon, Star
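The class index layout behind this list can be sketched as a small lookup. The helper name `class_group` is hypothetical and not part of the notebook; the ranges are taken from the `NUM_CLASSES` comment in the configuration cell.

```python
def class_group(class_id: int) -> str:
    """Return the coarse group for a class ID in the 71-class layout."""
    if 0 <= class_id <= 9:
        return "digit"
    if 10 <= class_id <= 35:
        return "uppercase"
    if 36 <= class_id <= 61:
        return "lowercase"
    if 62 <= class_id <= 70:
        return "shape"
    raise ValueError(f"class_id out of range: {class_id}")


# Count how many class IDs fall into each group; the sizes add up to 71.
counts = {}
for cid in range(71):
    group = class_group(cid)
    counts[group] = counts.get(group, 0) + 1
print(counts)
```

A lookup like this is handy when breaking validation accuracy down per group instead of reporting a single number.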

## New improvements:
- ✅ **MNIST + EMNIST**: More digit data, plus letters
- ✅ **Focal Loss + Class Weight**: Reduces confusion errors (4↔9, 3↔5, 6↔5, 1↔7)
- ✅ **RGB for Shapes**: Handles differing background and fill colors
- ✅ **Stronger augmentation**: 45° rotation, perspective, and strong ColorJitter for shapes
- ✅ **Top-2 Prediction Logic**: Handles cases where the confusion falls within the top-2 predictions
- ✅ **INPUT_SIZE: 128x128** (for better discrimination)

---


In [1]:
# Import Libraries
import os
import json
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from torchvision.datasets import EMNIST
from PIL import Image
from sklearn.model_selection import train_test_split

print(f"✅ PyTorch version: {torch.__version__}")
print(f"✅ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")


✅ PyTorch version: 2.5.1+cu121
✅ CUDA available: True
✅ GPU: NVIDIA GeForce RTX 4050 Laptop GPU


## 📋 Configuration


In [2]:
class Config:
    # Paths
    MNIST_TRAIN_DIR = 'mnist_competition/train'
    MNIST_TRAIN_CSV = 'mnist_competition/train_label.csv'
    SHAPES_DIR = 'Shapes_Classifier/dataset/output'
    EMNIST_DATA_DIR = './data/emnist'
    
    # Training
    EPOCHS = 25  # Increased from 20
    BATCH_SIZE = 64
    LEARNING_RATE = 1e-4
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Model
    # Classes: 0-9 (digits), 10-35 (A-Z), 36-61 (a-z), 62-70 (shapes)
    NUM_CLASSES = 71  # 10 + 26 + 26 + 9
    INPUT_SIZE = 128  # Increased from 64 to better distinguish high-edge shapes
    
    # Output
    MODEL_PATH = 'unified_model_71classes_improved.pth'
    LABEL_MAPPING_PATH = 'label_mapping_71classes.json'
    
    # Class weights for easily confused digits (based on error analysis)
    DIGIT_WEIGHTS = {
        0: 1.0,   # 0
        1: 2.0,   # 1 - often confused with 7 (65.5%)
        2: 1.5,   # 2 - often confused with 3, 1
        3: 2.0,   # 3 - often confused with 5 (50%)
        4: 2.5,   # 4 - often confused with 9 (75%) - most severe
        5: 2.0,   # 5 - often confused with 3 (67.7%)
        6: 2.0,   # 6 - often confused with 5 (62.7%)
        7: 1.5,   # 7 - often confused with 3 (41.7%)
        8: 1.5,   # 8 - confused with several digits
        9: 2.5,   # 9 - often confused with 4 (48.5%) - most severe
    }

print("="*60)
print("CONFIGURATION")
print("="*60)
print(f"Device: {Config.DEVICE}")
print(f"Epochs: {Config.EPOCHS}")
print(f"Batch size: {Config.BATCH_SIZE}")
print(f"Learning rate: {Config.LEARNING_RATE}")
print(f"Input size: {Config.INPUT_SIZE}x{Config.INPUT_SIZE}")
print(f"Num classes: {Config.NUM_CLASSES}")
print(f"  - Digits (0-9): 10")
print(f"  - Letters A-Z (10-35): 26")
print(f"  - Letters a-z (36-61): 26")
print(f"  - Shapes (62-70): 9")
print("="*60)


CONFIGURATION
Device: cuda
Epochs: 25
Batch size: 64
Learning rate: 0.0001
Input size: 128x128
Num classes: 71
  - Digits (0-9): 10
  - Letters A-Z (10-35): 26
  - Letters a-z (36-61): 26
  - Shapes (62-70): 9


In [3]:
print("Loading datasets...")

# Load MNIST
mnist_df = pd.read_csv(Config.MNIST_TRAIN_CSV)
mnist_df['source'] = 'mnist'  # Add a source column to distinguish from EMNIST
print(f"✅ MNIST: {len(mnist_df)} images")

# Load EMNIST
print("\n📂 Loading EMNIST...")
emnist_data = []
emnist_letters_data = []
os.makedirs('emnist_images', exist_ok=True)

try:
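    # EMNIST balanced: 47 classes (labels 0-46): digits 0-9, uppercase A-Z (10-35),
    # and only 11 lowercase letters (36-46: a, b, d, e, f, g, h, n, q, r, t).
    # It does not cover the full a-z range, so labels 47-61 never occur here.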
    emnist_train = EMNIST(
        root=Config.EMNIST_DATA_DIR,
        split='balanced',
        train=True,
        download=True,
        transform=None
    )
    emnist_test = EMNIST(
        root=Config.EMNIST_DATA_DIR,
        split='balanced',
        train=False,
        download=True,
        transform=None
    )
    
    print(f"✅ EMNIST train: {len(emnist_train)} images")
    print(f"✅ EMNIST test: {len(emnist_test)} images")
    
    print("\n📝 Converting EMNIST to image files...")
    for dataset, split_name in [(emnist_train, 'train'), (emnist_test, 'test')]:
        for idx, (img, label) in enumerate(tqdm(dataset, desc=f"Processing {split_name}")):
            img_path = f'emnist_images/emnist_{split_name}_{idx:06d}.png'
            img.save(img_path)
            
            label_int = int(label)
            if label_int < 10:
                # Digits: merge with MNIST
                emnist_data.append({
                    'image_name': os.path.basename(img_path),
                    'label': label_int,
                    'source': 'emnist'
                })
            else:
                # Letters: keep the EMNIST label unchanged (intended range 10-61;
                # the balanced split only emits 10-46)
                emnist_letters_data.append({
                    'image_name': os.path.basename(img_path),
                    'label': label_int,
                    'source': 'emnist'
                })
    
    emnist_digits_df = pd.DataFrame(emnist_data)
    emnist_letters_df = pd.DataFrame(emnist_letters_data)
    
    # Merge MNIST + EMNIST digits
    combined_digits_df = pd.concat([mnist_df, emnist_digits_df], ignore_index=True)
    print(f"\n✅ Combined digits: {len(combined_digits_df)} images (MNIST + EMNIST)")
    print(f"✅ EMNIST letters: {len(emnist_letters_df)} images")
    
except Exception as e:
    print(f"❌ Error loading EMNIST: {e}")
    print("⚠️  Using MNIST only (EMNIST unavailable)")
    combined_digits_df = mnist_df.copy()
    emnist_letters_df = pd.DataFrame(columns=['image_name', 'label', 'source'])

# Load Shapes
shape_files = [f for f in os.listdir(Config.SHAPES_DIR) if f.endswith('.png')]
shape_labels = [f.split('_')[0] for f in shape_files]
shapes_df = pd.DataFrame({'image_name': shape_files, 'label': shape_labels})
print(f"✅ Shapes: {len(shapes_df)} images")

# Create label mapping
shape_names = sorted(shapes_df['label'].unique())
shape_to_id = {name: idx + 62 for idx, name in enumerate(shape_names)}  # 62-70

# Build the full id_to_label mapping
id_to_label = {}
# Digits 0-9
for i in range(10):
    id_to_label[str(i)] = str(i)
# Letters A-Z (10-35)
for i, letter in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ'):
    id_to_label[str(i + 10)] = letter
# Letters a-z (36-61)
for i, letter in enumerate('abcdefghijklmnopqrstuvwxyz'):
    id_to_label[str(i + 36)] = letter
# Shapes (62-70)
for name, class_id in shape_to_id.items():
    id_to_label[str(class_id)] = name

print(f"\n📋 Label Mapping ({len(id_to_label)} classes):")
for class_id, label_name in sorted(id_to_label.items()):
    print(f"   Class {int(class_id):2d}: {label_name}")

# Save label mapping
with open(Config.LABEL_MAPPING_PATH, 'w') as f:
    json.dump(id_to_label, f, indent=2)
print(f"\n✅ Saved {Config.LABEL_MAPPING_PATH}")


Loading datasets...
✅ MNIST: 60000 images

📂 Loading EMNIST...
Downloading https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip to ./data/emnist\EMNIST\raw\gzip.zip

100%|██████████| 562M/562M [00:50<00:00, 11.2MB/s]

Extracting ./data/emnist\EMNIST\raw\gzip.zip to ./data/emnist\EMNIST\raw
✅ EMNIST train: 112800 images
✅ EMNIST test: 18800 images

📝 Converting EMNIST to image files...

Processing train: 100%|██████████| 112800/112800 [01:26<00:00, 1305.16it/s]
Processing test: 100%|██████████| 18800/18800 [00:14<00:00, 1277.50it/s]

✅ Combined digits: 88000 images (MNIST + EMNIST)
✅ EMNIST letters: 103600 images
✅ Shapes: 90000 images

📋 Label Mapping (71 classes):
   Class  0: 0
   Class  1: 1
   Class 10: A
   Class 11: B
   Class 12: C
   Class 13: D
   Class 14: E
   Class 15: F
   Class 16: G
   Class 17: H
   Class 18: I
   Class 19: J
   Class  2: 2
   Class 20: K
   Class 21: L
   Class 22: M
   Class 23: N
   Class 24: O
   Class 25: P
   Class 26: Q
   Class 27: R
   Class 28: S
   Class 29: T
   Class  3: 3
   Class 30: U
   Class 31: V
   Class 32: W
   Class 33: X
   Class 34: Y
   Class 35: Z
   Class 36: a
   Class 37: b
   Class 38: c
   Class 39: d
   Class  4: 4
   Class 40: e
   Class 41: f
   Class 42: g
   Class 43: h
   Class 44: i
   Class 45: j
   Class 46: k
   Class 47: l
   Class 48: m
   Class 49: n
   Class  5: 5
   Class 50: o
   Class 51: p
   Class 52: q
   Class 53: r
   Class 54: s
   Class 55: t
   Class 56: u
   Class 57: v
   Class 58: w
   Class 59: x
   Class  6: 6
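One caveat about the letter rows in this mapping: the EMNIST `balanced` split has 47 classes (labels 0-46), not 62. Labels 36-46 stand for the eleven lowercase letters a, b, d, e, f, g, h, n, q, r, t, so reading them as a-k mislabels most of them, and classes 47-61 receive no training data from this split. A minimal sketch of a remapping into the notebook's 62-class digit and letter layout follows. `BALANCED_CHARS` and `remap_balanced_label` are hypothetical names, not part of the notebook or of torchvision.

```python
# Character for each EMNIST "balanced" label, in label order (0-46).
BALANCED_CHARS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabdefghnqrt"


def remap_balanced_label(label: int) -> int:
    """Map a balanced-split label to the 0-61 layout (digits, A-Z, a-z)."""
    if not 0 <= label < len(BALANCED_CHARS):
        raise ValueError(f"not a balanced-split label: {label}")
    ch = BALANCED_CHARS[label]
    if ch.isdigit():
        return int(ch)                    # 0-9 stay 0-9
    if ch.isupper():
        return 10 + ord(ch) - ord("A")    # A-Z -> 10-35
    return 36 + ord(ch) - ord("a")        # a-z -> 36-61


# Show where each lowercase label of the balanced split ends up.
for label in range(36, 47):
    print(label, BALANCED_CHARS[label], remap_balanced_label(label))
```

Applied at the point where `label_int` is read, a remapping of this kind would keep `id_to_label` consistent with the characters actually shown in the images. Whether to do that, or to switch to a split with all 62 classes, is a design choice this sketch does not settle.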


In [4]:
class UnifiedDataset(Dataset):
    """
    Unified dataset for digits, letters and geometric shapes.
    - Digits/Letters: Grayscale (convert to RGB 3 channels)
    - Shapes: RGB (keep color to recognize background/fill)
    """
    
    def __init__(self, digits_df, letters_df, shapes_df, 
                 mnist_dir, emnist_dir, shapes_dir,
                 shape_to_id, transform=None, sample_fraction=0.67):
        """
        Args:
            digits_df: DataFrame with digits data (MNIST + EMNIST)
            letters_df: DataFrame with letters data (EMNIST)
            shapes_df: DataFrame with shapes data
            mnist_dir: Directory with MNIST images
            emnist_dir: Directory with EMNIST images
            shapes_dir: Directory with shape images
            shape_to_id: Mapping from shape name to class ID (62-70)
            transform: Image transforms
            sample_fraction: Fraction of shapes to use (balance with digits)
        """
        self.data_list = []
        
        # Add digits (0-9)
        if digits_df is not None and len(digits_df) > 0:
            for idx, row in digits_df.iterrows():
                # Check the source to pick the right directory
                source = row.get('source', 'mnist')  # Defaults to 'mnist' if missing
                if source == 'mnist':
                    img_path = os.path.join(mnist_dir, row['image_name'])
                else:  # emnist
                    img_path = os.path.join(emnist_dir, row['image_name'])
                
                # Check that the file exists before adding it
                if os.path.exists(img_path):
                    self.data_list.append({
                        'path': img_path,
                        'label': int(row['label']),  # 0-9
                        'source': 'digit'
                    })
                else:
                    print(f"⚠️  Warning: File not found: {img_path}")
        
        # Add letters (10-61)
        if letters_df is not None and len(letters_df) > 0:
            for idx, row in letters_df.iterrows():
                img_path = os.path.join(emnist_dir, row['image_name'])
                # Check that the file exists before adding it
                if os.path.exists(img_path):
                    self.data_list.append({
                        'path': img_path,
                        'label': int(row['label']),  # 10-61
                        'source': 'letter'
                    })
                else:
                    print(f"⚠️  Warning: File not found: {img_path}")
        
        # Add shapes (62-70)
        if shapes_df is not None:
            shapes_df_sampled = shapes_df.sample(frac=sample_fraction, random_state=42)
            for idx, row in shapes_df_sampled.iterrows():
                img_path = os.path.join(shapes_dir, row['image_name'])
                # Check that the file exists before adding it
                if os.path.exists(img_path):
                    self.data_list.append({
                        'path': img_path,
                        'label': shape_to_id[row['label']],  # 62-70
                        'source': 'shape'
                    })
                else:
                    print(f"⚠️  Warning: File not found: {img_path}")
        
        self.transform = transform
        
        # Statistics
        digit_count = sum(1 for item in self.data_list if item['source'] == 'digit')
        letter_count = sum(1 for item in self.data_list if item['source'] == 'letter')
        shape_count = sum(1 for item in self.data_list if item['source'] == 'shape')
        
        print(f"✅ Dataset created: {len(self.data_list)} images")
        print(f"   - Digits: {digit_count} images (classes 0-9)")
        print(f"   - Letters: {letter_count} images (classes 10-61)")
        print(f"   - Shapes: {shape_count} images (classes 62-70)")
    
    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        item = self.data_list[idx]
        
        # Shapes: Keep RGB to recognize background/fill colors
        # Digits/Letters: Grayscale (will convert to RGB in transform)
        if item['source'] == 'shape':
            image = Image.open(item['path']).convert('RGB')  # Keep RGB for shapes
        else:
            image = Image.open(item['path']).convert('L')  # Grayscale for digits/letters
        
        if self.transform:
            image = self.transform(image, is_shape=(item['source'] == 'shape'))
        
        return image, item['label']

print("✅ UnifiedDataset class defined")


✅ UnifiedDataset class defined


## 🏋️ Training Functions


In [5]:
class FocalLoss(nn.Module):
    """
    Focal loss, to focus training on hard-to-classify samples.
    Helps reduce confusion between similar digits (4↔9, 3↔5, 6↔5, 1↔7).
    """
    def __init__(self, alpha=None, gamma=2.0, weight=None):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        
    def forward(self, inputs, targets):
        # Note: because `weight` is applied inside cross_entropy, pt below is
        # exp(-w * CE), i.e. p**w, not the raw predicted probability p.
        ce_loss = F.cross_entropy(inputs, targets, weight=self.weight, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        
        if self.alpha is not None:
            if isinstance(self.alpha, (list, np.ndarray)):
                alpha_t = torch.tensor(self.alpha, device=inputs.device)
                alpha_t = alpha_t[targets]
                focal_loss = alpha_t * focal_loss
        
        return focal_loss.mean()

def create_class_weights():
    """Create class weights based on error analysis."""
    weights = torch.ones(Config.NUM_CLASSES)
    
    # Digits (0-9): apply the weights from Config
    for digit, weight in Config.DIGIT_WEIGHTS.items():
        weights[digit] = weight
    
    return weights.to(Config.DEVICE)

def train_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        pbar.set_postfix({
            'loss': f"{running_loss/(pbar.n+1):.4f}",
            'acc': f"{100.*correct/total:.2f}%"
        })
    
    return running_loss / len(loader), 100. * correct / total

def validate(model, loader, criterion, device, use_top2=False):
    """Validate the model, optionally using the top-2 logic."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        pbar = tqdm(loader, desc='Validation')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            if use_top2:
                # Use the top-2 logic (re-runs the model once per image, so it is slow)
                batch_preds = []
                for i in range(images.size(0)):
                    pred, _, _ = predict_with_top2(model, images[i], device)
                    batch_preds.append(pred.item())
                predicted = torch.tensor(batch_preds, device=device)
            else:
                _, predicted = outputs.max(1)
            
            running_loss += loss.item()
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            pbar.set_postfix({
                'loss': f"{running_loss/(pbar.n+1):.4f}",
                'acc': f"{100.*correct/total:.2f}%"
            })
    
    return running_loss / len(loader), 100. * correct / total

print("✅ Training functions defined")


✅ Training functions defined
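The per-sample arithmetic inside `FocalLoss.forward` can be checked by hand. The sketch below mirrors it in plain Python for a single sample whose true-class probability is `p`. The helper `focal_term` is hypothetical; it only restates the three lines of `forward` without PyTorch.

```python
import math


def focal_term(p: float, gamma: float = 2.0, weight: float = 1.0) -> float:
    """Per-sample focal loss, following the same steps as FocalLoss.forward."""
    ce = -weight * math.log(p)      # weighted cross-entropy for the true class
    pt = math.exp(-ce)              # equals p ** weight
    return (1.0 - pt) ** gamma * ce


easy = focal_term(0.9)                  # confident, correct prediction
hard = focal_term(0.5)                  # uncertain prediction
weighted = focal_term(0.5, weight=2.5)  # same prediction, class weight 2.5

print(f"easy={easy:.6f} hard={hard:.6f} weighted={weighted:.6f}")
```

With `gamma=2.0`, the confident sample contributes far less than the uncertain one, which is the intended focusing effect. The third value shows the interaction with class weights: since the weight enters before `pt` is computed, a weight of 2.5 raises the loss by more than a factor of 2.5.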


## 🔄 Step 4: Stronger Augmentation


In [6]:
class AdaptiveTransform:
    """
    Adaptive transform:
    - Digits/Letters: Grayscale -> RGB (3 identical channels)
    - Shapes: RGB (colors kept, with strong color augmentation to learn background/fill)
    """
    def __init__(self, is_train=True):
        self.is_train = is_train
        
    def __call__(self, image, is_shape=False):
        # Resize first
        image = transforms.Resize((Config.INPUT_SIZE, Config.INPUT_SIZE))(image)
        
        if is_shape:
            # Shapes: RGB with strong color augmentation to learn varied colors
            if self.is_train:
                image = transforms.RandomRotation(45)(image)
                image = transforms.RandomAffine(
                    degrees=0,
                    translate=(0.2, 0.2),
                    scale=(0.8, 1.2),
                    shear=10
                )(image)
                image = transforms.RandomPerspective(distortion_scale=0.3, p=0.5)(image)
                # STRONG color augmentation for shapes, to learn background/fill colors
                image = transforms.ColorJitter(
                    brightness=0.5,    # Increased to learn varied background colors
                    contrast=0.5,      # Increased to learn varied fill colors
                    saturation=0.5,    # Added saturation for shapes
                    hue=0.1            # Added hue variation
                )(image)
                image = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5))(image)
        else:
            # Digits/Letters: Grayscale -> RGB
            image = transforms.Grayscale(num_output_channels=3)(image)
            
            if self.is_train:
                image = transforms.RandomRotation(45)(image)
                image = transforms.RandomAffine(
                    degrees=0,
                    translate=(0.2, 0.2),
                    scale=(0.8, 1.2),
                    shear=10
                )(image)
                image = transforms.RandomPerspective(distortion_scale=0.3, p=0.5)(image)
                # Lighter color jitter for grayscale
                image = transforms.ColorJitter(brightness=0.4, contrast=0.4)(image)
                image = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5))(image)
        
        # ToTensor and Normalize
        image = transforms.ToTensor()(image)
        image = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )(image)
        
        return image

# Create transforms
train_transform = AdaptiveTransform(is_train=True)
val_transform = AdaptiveTransform(is_train=False)

# Split data
digits_train, digits_val = train_test_split(
    combined_digits_df, test_size=0.15, random_state=42, 
    stratify=combined_digits_df['label'] if 'label' in combined_digits_df.columns else None
)

if len(emnist_letters_df) > 0:
    letters_train, letters_val = train_test_split(
        emnist_letters_df, test_size=0.15, random_state=42,
        stratify=emnist_letters_df['label'] if 'label' in emnist_letters_df.columns else None
    )
else:
    letters_train = pd.DataFrame(columns=['image_name', 'label', 'source'])
    letters_val = pd.DataFrame(columns=['image_name', 'label', 'source'])

shapes_train, shapes_val = train_test_split(
    shapes_df, test_size=0.15, random_state=42, stratify=shapes_df['label']
)

print(f"📊 Data Split:")
print(f"   Train: Digits {len(digits_train)} + Letters {len(letters_train)} + Shapes ~{int(len(shapes_train)*0.67)}")
print(f"   Val:   Digits {len(digits_val)} + Letters {len(letters_val)} + Shapes ~{int(len(shapes_val)*0.67)}")


📊 Data Split:
   Train: Digits 74800 + Letters 88060 + Shapes ~51255
   Val:   Digits 13200 + Letters 15540 + Shapes ~9045
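`Grayscale(num_output_channels=3)` copies one intensity into three identical channels, but the ImageNet `Normalize` step in `AdaptiveTransform` then shifts and scales each channel differently. A small plain-Python sketch of that arithmetic, using the same mean and std values as the cell above (the helper `normalize_gray` is hypothetical), shows what a digit pixel looks like to the network:

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_gray(value):
    """Normalize one grayscale intensity in [0, 1] replicated to 3 channels."""
    return [(value - m) / s for m, s in zip(MEAN, STD)]


white = normalize_gray(1.0)   # stroke pixel after ToTensor
black = normalize_gray(0.0)   # background pixel after ToTensor

print([round(v, 4) for v in white])
print([round(v, 4) for v in black])
```

So the three channels are no longer identical after normalization. This is expected when reusing ImageNet-pretrained weights and is not an error in the transform.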


## 🎯 Top-2 Prediction Logic (Optional)


In [7]:
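def apply_top2_logic(predicted, top2_indices, top2_probs, threshold=0.1):
    """Handle confusable cases in the top-2 predictions.

    Note: as written, every code path returns `predicted` unchanged, so this
    function is currently a placeholder and does not alter the top-1 class.
    """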
    prob_diff = top2_probs[0] - top2_probs[1]
    
    if prob_diff < threshold:
        pred_class = top2_indices[0].item()
        second_class = top2_indices[1].item()
        
        confusion_pairs = {
            (4, 9): 'mirror', (9, 4): 'mirror',
            (3, 5): 'top_curve', (5, 3): 'top_curve',
            (6, 5): 'bottom_curve', (5, 6): 'bottom_curve',
            (1, 7): 'diagonal', (7, 1): 'diagonal',
            (0, 8): 'thickness', (8, 0): 'thickness',
        }
        
        if (pred_class, second_class) in confusion_pairs:
            return predicted
    
    return predicted

def predict_with_top2(model, image, device, threshold=0.1):
    """Predict using the top-2 logic."""
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0).to(device)
        outputs = model(image)
        probs = torch.softmax(outputs, dim=1)
        top2_probs, top2_indices = probs.topk(2, dim=1)
        predicted = apply_top2_logic(
            top2_indices[0, 0], top2_indices[0], top2_probs[0], threshold
        )
        return predicted, top2_probs[0], top2_indices[0]

print("✅ Top-2 prediction logic defined")


✅ Top-2 prediction logic defined
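As written, `apply_top2_logic` returns the top-1 class on every code path, which is consistent with the near-zero top-2 differences reported in the training log. The sketch below covers only the detection half, in plain Python: it flags predictions whose top-2 margin is below the threshold and whose two classes form one of the listed confusion pairs. What to do with a flagged sample is left open. `top2_margin` and `is_ambiguous` are hypothetical names, not the notebook's functions.

```python
# Unordered confusion pairs, taken from the dictionary in apply_top2_logic.
CONFUSION_PAIRS = {frozenset(p) for p in [(4, 9), (3, 5), (6, 5), (1, 7), (0, 8)]}


def top2_margin(probs):
    """Return (top1_index, top2_index, margin) for a list of class probabilities."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    first, second = order[0], order[1]
    return first, second, probs[first] - probs[second]


def is_ambiguous(probs, threshold=0.1):
    """True if the top-2 classes are close in probability and a known confusion pair."""
    first, second, margin = top2_margin(probs)
    return margin < threshold and frozenset((first, second)) in CONFUSION_PAIRS


# A 4-vs-9 case with a small margin, and a confident 4 for comparison.
close_call = [0.0, 0.0, 0.0, 0.0, 0.48, 0.0, 0.0, 0.07, 0.0, 0.45]
confident = [0.0, 0.0, 0.0, 0.0, 0.90, 0.0, 0.0, 0.05, 0.0, 0.05]
print(is_ambiguous(close_call), is_ambiguous(confident))
```

Counting flagged samples on the validation set would show how often a tie-breaking rule could even apply, before investing in one.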


## 🎲 Create Datasets & DataLoaders


In [8]:
# Create datasets
train_dataset = UnifiedDataset(
    digits_train, letters_train, shapes_train,
    Config.MNIST_TRAIN_DIR, 'emnist_images', Config.SHAPES_DIR,
    shape_to_id, transform=train_transform,
    sample_fraction=0.67
)

val_dataset = UnifiedDataset(
    digits_val, letters_val, shapes_val,
    Config.MNIST_TRAIN_DIR, 'emnist_images', Config.SHAPES_DIR,
    shape_to_id, transform=val_transform,
    sample_fraction=0.67
)

# DataLoaders
train_loader = DataLoader(
    train_dataset, batch_size=Config.BATCH_SIZE,
    shuffle=True, num_workers=0, pin_memory=True
)
val_loader = DataLoader(
    val_dataset, batch_size=Config.BATCH_SIZE,
    shuffle=False, num_workers=0, pin_memory=True
)

print(f"\n✅ DataLoaders ready")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")


✅ Dataset created: 214115 images
   - Digits: 74800 images (classes 0-9)
   - Letters: 88060 images (classes 10-61)
   - Shapes: 51255 images (classes 62-70)
✅ Dataset created: 37785 images
   - Digits: 13200 images (classes 0-9)
   - Letters: 15540 images (classes 10-61)
   - Shapes: 9045 images (classes 62-70)

✅ DataLoaders ready
   Train batches: 3346
   Val batches: 591
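The dataset and batch counts above follow directly from the 15% validation split, the 0.67 shape sampling fraction, and the batch size of 64. A short arithmetic sketch reproduces them. It uses integer arithmetic as a stand-in for the rounding done by scikit-learn and pandas, which is an assumption that happens to be exact for these sizes; the helper `split` is hypothetical.

```python
import math

VAL_PCT = 15       # test_size=0.15
SHAPE_PCT = 67     # sample_fraction=0.67
BATCH_SIZE = 64


def split(n):
    """Return (train, val) sizes for a 15% validation split."""
    n_val = n * VAL_PCT // 100
    return n - n_val, n_val


digits_train, digits_val = split(88000)
letters_train, letters_val = split(103600)
shapes_train, shapes_val = split(90000)

# Only a fraction of the shapes is kept, in both train and validation sets.
shapes_train_used = shapes_train * SHAPE_PCT // 100
shapes_val_used = shapes_val * SHAPE_PCT // 100

train_total = digits_train + letters_train + shapes_train_used
val_total = digits_val + letters_val + shapes_val_used

# The last partial batch is kept, hence the ceiling.
train_batches = math.ceil(train_total / BATCH_SIZE)
val_batches = math.ceil(val_total / BATCH_SIZE)
print(train_total, val_total, train_batches, val_batches)
```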


## 🧠 Load Model: EfficientNet-B0


In [9]:
print("Loading EfficientNet-B0...")
model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)

# Modify classifier for 71 classes
num_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_features, Config.NUM_CLASSES)
model = model.to(Config.DEVICE)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✅ Model ready")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Model on device: {next(model.parameters()).device}")


Loading EfficientNet-B0...
✅ Model ready
   Total parameters: 4,098,499
   Trainable parameters: 4,098,499
   Model on device: cuda:0
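The parameter count printed above can be cross-checked by hand: swapping the 1000-way ImageNet head for a 71-way head changes only the final linear layer. The sketch assumes torchvision's published parameter count for `efficientnet_b0` (5,288,548) and the 1280 input features of its classifier; `linear_params` is a hypothetical helper.

```python
IMAGENET_B0_PARAMS = 5_288_548   # assumed: torchvision's published count for efficientnet_b0
IN_FEATURES = 1280               # assumed: input width of the B0 classifier layer


def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


old_head = linear_params(IN_FEATURES, 1000)   # ImageNet classifier
new_head = linear_params(IN_FEATURES, 71)     # 71-class classifier
total = IMAGENET_B0_PARAMS - old_head + new_head
print(f"{total:,}")
```

All of these parameters are trainable because no layers are frozen, which matches the identical "Total" and "Trainable" figures in the output.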


## ⚙️ Training Setup


In [10]:
# Create class weights
class_weights = create_class_weights()

# Focal loss with class weights
criterion = FocalLoss(alpha=None, gamma=2.0, weight=class_weights)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=Config.LEARNING_RATE)

# Scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=3, verbose=True
)

print(f"✅ Training setup complete")
print(f"   Criterion: FocalLoss (gamma=2.0) with class weights")
print(f"   Optimizer: Adam (lr={Config.LEARNING_RATE})")
print(f"   Scheduler: ReduceLROnPlateau (factor=0.5, patience=3)")
print(f"\nClass weights (digits):")
for i in range(10):
    print(f"   Class {i}: {class_weights[i]:.2f}x")


✅ Training setup complete
   Criterion: FocalLoss (gamma=2.0) with class weights
   Optimizer: Adam (lr=0.0001)
   Scheduler: ReduceLROnPlateau (factor=0.5, patience=3)

Class weights (digits):
   Class 0: 1.00x
   Class 1: 2.00x
   Class 2: 1.50x
   Class 3: 2.00x
   Class 4: 2.50x
   Class 5: 2.00x
   Class 6: 2.00x
   Class 7: 1.50x
   Class 8: 1.50x
   Class 9: 2.50x




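## 🚀 Training Loop

**⚠️ Warning:** Training is slow. In the run logged below, one epoch (training plus both validation passes) took roughly 50 to 95 minutes on an RTX 4050 Laptop GPU, so the full 25 epochs take far longer than 1-2 hours.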


In [11]:
best_val_acc = 0.0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

print("="*60)
print("STARTING TRAINING")
print("="*60 + "\n")

for epoch in range(Config.EPOCHS):
    print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")
    print("-" * 60)
    
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, Config.DEVICE
    )
    
    # Validation without the top-2 logic (for a fair comparison)
    val_loss, val_acc = validate(
        model, val_loader, criterion, Config.DEVICE, use_top2=False
    )
    
    # Validation with the top-2 logic (to see the improvement)
    val_loss_top2, val_acc_top2 = validate(
        model, val_loader, criterion, Config.DEVICE, use_top2=True
    )
    
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    
    scheduler.step(val_acc)
    
    print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
    print(f"Val Acc (top-2): {val_acc_top2:.2f}% (improvement: +{val_acc_top2 - val_acc:.2f}%)")
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_acc_top2': val_acc_top2,
            'label_mapping': id_to_label,
            'config': {
                'epochs': Config.EPOCHS,
                'batch_size': Config.BATCH_SIZE,
                'lr': Config.LEARNING_RATE,
                'input_size': Config.INPUT_SIZE,
                'num_classes': Config.NUM_CLASSES
            }
        }, Config.MODEL_PATH)
        print(f"✅ Saved best model: {Config.MODEL_PATH}")
        print(f"   Val Acc: {val_acc:.2f}% | Val Acc (top-2): {val_acc_top2:.2f}%")

print(f"\n{'='*60}")
print(f"TRAINING COMPLETED!")
print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
print(f"Model saved: {Config.MODEL_PATH}")
print(f"{'='*60}")


STARTING TRAINING


Epoch 1/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [1:12:17<00:00,  1.30s/it, loss=0.7308, acc=73.02%]
Validation: 100%|██████████| 591/591 [08:57<00:00,  1.10it/s, loss=0.2149, acc=87.37%]
Validation: 100%|██████████| 591/591 [12:15<00:00,  1.24s/it, loss=0.2149, acc=87.38%]

Train Loss: 0.7308 | Train Acc: 73.02%
Val Loss: 0.2149 | Val Acc: 87.37%
Val Acc (top-2): 87.38% (improvement: +0.00%)
✅ Saved best model: unified_model_71classes_improved.pth
   Val Acc: 87.37% | Val Acc (top-2): 87.38%

Epoch 2/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [1:11:15<00:00,  1.28s/it, loss=0.2721, acc=85.24%]
Validation: 100%|██████████| 591/591 [07:28<00:00,  1.32it/s, loss=0.1402, acc=90.56%]
Validation: 100%|██████████| 591/591 [12:11<00:00,  1.24s/it, loss=0.1402, acc=90.56%]

Train Loss: 0.2721 | Train Acc: 85.24%
Val Loss: 0.1402 | Val Acc: 90.56%
Val Acc (top-2): 90.56% (improvement: +0.00%)
✅ Saved best model: unified_model_71classes_improved.pth
   Val Acc: 90.56% | Val Acc (top-2): 90.56%

Epoch 3/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [1:05:28<00:00,  1.17s/it, loss=0.2188, acc=87.43%]
Validation: 100%|██████████| 591/591 [07:41<00:00,  1.28it/s, loss=0.1251, acc=91.46%]
Validation: 100%|██████████| 591/591 [11:51<00:00,  1.20s/it, loss=0.1251, acc=91.45%]

Train Loss: 0.2188 | Train Acc: 87.43%
Val Loss: 0.1251 | Val Acc: 91.46%
Val Acc (top-2): 91.45% (improvement: +-0.00%)
✅ Saved best model: unified_model_71classes_improved.pth
   Val Acc: 91.46% | Val Acc (top-2): 91.45%

Epoch 4/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [58:53<00:00,  1.06s/it, loss=0.1912, acc=88.57%]
Validation: 100%|██████████| 591/591 [02:23<00:00,  4.11it/s, loss=0.1124, acc=92.17%]
Validation: 100%|██████████| 591/591 [11:34<00:00,  1.18s/it, loss=0.1124, acc=92.17%]

Train Loss: 0.1912 | Train Acc: 88.57%
Val Loss: 0.1124 | Val Acc: 92.17%
Val Acc (top-2): 92.17% (improvement: +0.00%)
✅ Saved best model: unified_model_71classes_improved.pth
   Val Acc: 92.17% | Val Acc (top-2): 92.17%

Epoch 5/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [36:39<00:00,  1.52it/s, loss=0.1739, acc=89.33%]
Validation: 100%|██████████| 591/591 [02:19<00:00,  4.23it/s, loss=0.1105, acc=92.38%]
Validation: 100%|██████████| 591/591 [11:34<00:00,  1.17s/it, loss=0.1105, acc=92.38%]

Train Loss: 0.1739 | Train Acc: 89.33%
Val Loss: 0.1105 | Val Acc: 92.38%
Val Acc (top-2): 92.38% (improvement: +-0.00%)
✅ Saved best model: unified_model_71classes_improved.pth
   Val Acc: 92.38% | Val Acc (top-2): 92.38%

Epoch 6/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [37:25<00:00,  1.49it/s, loss=0.1592, acc=90.01%]
Validation: 100%|██████████| 591/591 [02:17<00:00,  4.30it/s, loss=0.1010, acc=92.90%]
Validation: 100%|██████████| 591/591 [11:30<00:00,  1.17s/it, loss=0.1010, acc=92.90%]

Train Loss: 0.1592 | Train Acc: 90.01%
Val Loss: 0.1010 | Val Acc: 92.90%
Val Acc (top-2): 92.90% (improvement: +0.00%)
✅ Saved best model: unified_model_71classes_improved.pth
   Val Acc: 92.90% | Val Acc (top-2): 92.90%

Epoch 7/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [52:13<00:00,  1.07it/s, loss=0.1488, acc=90.41%]
Validation: 100%|██████████| 591/591 [07:57<00:00,  1.24it/s, loss=0.0967, acc=93.24%]
Validation: 100%|██████████| 591/591 [11:34<00:00,  1.17s/it, loss=0.0967, acc=93.24%]

Train Loss: 0.1488 | Train Acc: 90.41%
Val Loss: 0.0967 | Val Acc: 93.24%
Val Acc (top-2): 93.24% (improvement: +-0.00%)
✅ Saved best model: unified_model_71classes_improved.pth
   Val Acc: 93.24% | Val Acc (top-2): 93.24%

Epoch 8/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [1:10:53<00:00,  1.27s/it, loss=0.1440, acc=90.62%]
Validation: 100%|██████████| 591/591 [09:34<00:00,  1.03it/s, loss=0.0956, acc=93.00%]
Validation: 100%|██████████| 591/591 [13:25<00:00,  1.36s/it, loss=0.0956, acc=93.00%]



Train Loss: 0.1440 | Train Acc: 90.62%
Val Loss: 0.0956 | Val Acc: 93.00%
Val Acc (top-2): 93.00% (improvement: +0.00%)

Epoch 9/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [1:07:56<00:00,  1.22s/it, loss=0.1373, acc=91.01%]
Validation: 100%|██████████| 591/591 [08:51<00:00,  1.11it/s, loss=0.0940, acc=93.44%]
Validation: 100%|██████████| 591/591 [11:34<00:00,  1.18s/it, loss=0.0940, acc=93.44%]



Train Loss: 0.1373 | Train Acc: 91.01%
Val Loss: 0.0940 | Val Acc: 93.44%
Val Acc (top-2): 93.44% (improvement: +-0.00%)
✅ Saved best model: unified_model_71classes_improved.pth
   Val Acc: 93.44% | Val Acc (top-2): 93.44%

Epoch 10/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [1:07:19<00:00,  1.21s/it, loss=0.1303, acc=91.26%]
Validation: 100%|██████████| 591/591 [07:19<00:00,  1.34it/s, loss=0.0909, acc=93.64%]
Validation: 100%|██████████| 591/591 [05:24<00:00,  1.82it/s, loss=0.0909, acc=93.64%]



Train Loss: 0.1303 | Train Acc: 91.26%
Val Loss: 0.0909 | Val Acc: 93.64%
Val Acc (top-2): 93.64% (improvement: +-0.00%)
✅ Saved best model: unified_model_71classes_improved.pth
   Val Acc: 93.64% | Val Acc (top-2): 93.64%

Epoch 11/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [1:01:54<00:00,  1.11s/it, loss=0.1268, acc=91.45%]
Validation: 100%|██████████| 591/591 [11:09<00:00,  1.13s/it, loss=0.0920, acc=93.61%]
Validation: 100%|██████████| 591/591 [20:56<00:00,  2.13s/it, loss=0.0920, acc=93.60%]



Train Loss: 0.1268 | Train Acc: 91.45%
Val Loss: 0.0920 | Val Acc: 93.61%
Val Acc (top-2): 93.60% (improvement: +-0.01%)

Epoch 12/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [1:53:36<00:00,  2.04s/it, loss=0.1232, acc=91.63%]
Validation: 100%|██████████| 591/591 [04:06<00:00,  2.40it/s, loss=0.0897, acc=93.63%]
Validation: 100%|██████████| 591/591 [20:17<00:00,  2.06s/it, loss=0.0897, acc=93.62%]



Train Loss: 0.1232 | Train Acc: 91.63%
Val Loss: 0.0897 | Val Acc: 93.63%
Val Acc (top-2): 93.62% (improvement: +-0.01%)

Epoch 13/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [1:00:48<00:00,  1.09s/it, loss=0.1193, acc=91.83%]
Validation: 100%|██████████| 591/591 [04:04<00:00,  2.42it/s, loss=0.0867, acc=93.90%]
Validation: 100%|██████████| 591/591 [22:05<00:00,  2.24s/it, loss=0.0867, acc=93.90%]



Train Loss: 0.1193 | Train Acc: 91.83%
Val Loss: 0.0867 | Val Acc: 93.90%
Val Acc (top-2): 93.90% (improvement: +0.00%)
✅ Saved best model: unified_model_71classes_improved.pth
   Val Acc: 93.90% | Val Acc (top-2): 93.90%

Epoch 14/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [55:29<00:00,  1.01it/s, loss=0.1167, acc=91.93%]
Validation: 100%|██████████| 591/591 [07:23<00:00,  1.33it/s, loss=0.0881, acc=93.69%]
Validation: 100%|██████████| 591/591 [11:44<00:00,  1.19s/it, loss=0.0881, acc=93.69%]



Train Loss: 0.1167 | Train Acc: 91.93%
Val Loss: 0.0881 | Val Acc: 93.69%
Val Acc (top-2): 93.69% (improvement: +-0.00%)

Epoch 15/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [1:07:41<00:00,  1.21s/it, loss=0.1136, acc=92.05%]
Validation: 100%|██████████| 591/591 [07:51<00:00,  1.25it/s, loss=0.0894, acc=93.56%]
Validation: 100%|██████████| 591/591 [11:40<00:00,  1.18s/it, loss=0.0894, acc=93.56%]



Train Loss: 0.1136 | Train Acc: 92.05%
Val Loss: 0.0894 | Val Acc: 93.56%
Val Acc (top-2): 93.56% (improvement: +-0.00%)

Epoch 16/25
------------------------------------------------------------


Training: 100%|██████████| 3346/3346 [1:03:52<00:00,  1.15s/it, loss=0.1104, acc=92.25%]
Validation: 100%|██████████| 591/591 [10:15<00:00,  1.04s/it, loss=0.0871, acc=93.80%]
Validation: 100%|██████████| 591/591 [19:06<00:00,  1.94s/it, loss=0.0871, acc=93.80%]



Train Loss: 0.1104 | Train Acc: 92.25%
Val Loss: 0.0871 | Val Acc: 93.80%
Val Acc (top-2): 93.80% (improvement: +-0.00%)

Epoch 17/25
------------------------------------------------------------


Training:   2%|▏         | 58/3346 [01:12<1:08:25,  1.25s/it, loss=0.1067, acc=92.56%]


KeyboardInterrupt: 

## 📈 Visualize Training History


In [None]:
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Plot Loss
ax1.plot(history['train_loss'], label='Train Loss', marker='o')
ax1.plot(history['val_loss'], label='Val Loss', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training & Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot Accuracy
ax2.plot(history['train_acc'], label='Train Acc', marker='o')
ax2.plot(history['val_acc'], label='Val Acc', marker='s')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training & Validation Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_history_improved.png', dpi=150, bbox_inches='tight')
print("✅ Saved training_history_improved.png")
plt.show()
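
Alongside the plot, a short numeric summary can show which epoch produced the saved checkpoint. The sketch below is only an illustration: it assumes `history` is a dict of per-epoch lists (as the plotting cell suggests), and `summarize_history` and `first_epoch` are hypothetical names that are not part of the notebook. The sample values are the train and validation accuracies logged above for epochs 4 to 16.

```python
def summarize_history(history, first_epoch=1):
    """Find the epoch with the highest validation accuracy.

    Assumes history["val_acc"] and history["train_acc"] are lists with
    one entry per completed epoch (an assumption about the notebook).
    """
    val_acc = history["val_acc"]
    best_idx = max(range(len(val_acc)), key=lambda i: val_acc[i])
    return {
        "best_epoch": first_epoch + best_idx,
        "best_val_acc": val_acc[best_idx],
        "train_acc_at_best": history["train_acc"][best_idx],
        "epochs_since_best": len(val_acc) - 1 - best_idx,
    }


# Accuracies copied from the training log above (epochs 4 through 16).
logged = {
    "train_acc": [88.57, 89.33, 90.01, 90.41, 90.62, 91.01, 91.26,
                  91.45, 91.63, 91.83, 91.93, 92.05, 92.25],
    "val_acc": [92.17, 92.38, 92.90, 93.24, 93.00, 93.44, 93.64,
                93.61, 93.63, 93.90, 93.69, 93.56, 93.80],
}

summary = summarize_history(logged, first_epoch=4)
print(summary)
```

On the logged values this reports epoch 13 as the best epoch, which matches the last "Saved best model" message before the run was interrupted.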


## 🎯 Next Steps

After training completes, run evaluation:

```python
# Run in a notebook cell; in a terminal, drop the leading "!"
!python evaluate_model.py
```

Expected improvements:
- **Digits accuracy**: fewer confusions between 4↔9, 3↔5, 6↔5, and 1↔7
- **Shapes accuracy**: shapes are recognized across different background and fill colors
- **Letters**: A-Z and a-z are recognized
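
One way to check the digit confusions listed above is to count how often each pair is swapped in the evaluation predictions. The following is a minimal sketch, not the contents of `evaluate_model.py`: `confusion_pair_counts` is a hypothetical helper, the toy labels are invented for illustration, and it assumes digit classes are encoded as the integers 0 to 9, which may not match the 71-class label mapping used in this notebook.

```python
import numpy as np


def confusion_pair_counts(y_true, y_pred, pairs):
    """Count swaps in both directions for each (a, b) pair.

    A swap is a sample whose true label is a but was predicted as b,
    or whose true label is b but was predicted as a.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    counts = {}
    for a, b in pairs:
        a_as_b = int(np.sum((y_true == a) & (y_pred == b)))
        b_as_a = int(np.sum((y_true == b) & (y_pred == a)))
        counts[(a, b)] = a_as_b + b_as_a
    return counts


# Pairs named in the notebook; label encoding is an assumption.
DIGIT_PAIRS = [(4, 9), (3, 5), (6, 5), (1, 7)]

# Toy labels for illustration only, not real evaluation results.
y_true = [4, 9, 3, 5, 6, 1, 7, 4]
y_pred = [9, 9, 5, 5, 5, 1, 1, 4]

pair_counts = confusion_pair_counts(y_true, y_pred, DIGIT_PAIRS)
print(pair_counts)
```

Comparing these counts between the previous checkpoint and the new one gives a direct measure of whether the targeted confusions actually dropped.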

Then test on your sample image:

```python
!python pipeline.py --image Shapes_Classifier/Sample.png --output Sample_result_new.png
```

---

**✅ Training notebook created successfully!**
