# Image Classification: Transfer Learning & Fine-tuning with PyTorch

Chương trình này tải hình ảnh từ một thư mục cục bộ (nhãn lấy từ tên thư mục con), chia chúng thành tập train/validation (nếu cần), huấn luyện một mô hình **pretrained** (đã được huấn luyện trước) với tùy chọn **partial freezing** (đóng băng một phần) các lớp, và báo cáo các số liệu (accuracy – độ chính xác, precision – độ chính xác dương, recall – độ thu hồi, F1) cùng với một ma trận nhầm lẫn (confusion matrix).

**Định dạng thư mục** (hai tùy chọn):
1. Một thư mục gốc duy nhất chứa các thư mục con theo lớp (chương trình lấy tên lớp từ tên thư mục con và tự chia ngẫu nhiên thành train/val):
```
data_root/
  class_a/ img1.jpg ...
  class_b/ img2.jpg ...
```
2. Đã chia sẵn thành các thư mục `train/` và `val/` (chương trình dùng trực tiếp hai thư mục này làm tập huấn luyện và tập kiểm định — đặt `NEED_RANDOM_SPLIT = False`):
```
data_root/
  train/ class_a/..., class_b/...
  val/   class_a/..., class_b/...
```

In [None]:
# ==== Configuration (edit me) ===============================================
from pathlib import Path

# --- Data --------------------------------------------------------------------
# Root folder of the image dataset -- change this to your own path,
# e.g. Path(r"D:/datasets/flowers").
DATA_ROOT = Path("data_root")
# True : DATA_ROOT/<class>/... is split randomly into train/val.
# False: DATA_ROOT already contains 'train/' and 'val/' subfolders.
NEED_RANDOM_SPLIT = True
TRAIN_RATIO = 0.8   # share of images used for training when splitting randomly (rest -> val)
RANDOM_SEED = 42    # seed for that random split

# --- Model -------------------------------------------------------------------
# One of: 'resnet18', 'resnet50', 'efficientnet_b0', 'mobilenet_v3_small'
MODEL_NAME = "resnet18"
IMAGE_SIZE = 224    # side length of the square crop fed to the network

# --- Training ----------------------------------------------------------------
BATCH_SIZE = 32
NUM_EPOCHS = 10
LR_BACKBONE = 1e-4    # LR for backbone parameters left trainable (frozen ones are never optimized)
LR_CLASSIFIER = 1e-3  # LR for the freshly initialized head (usually higher)

# FREEZE_STRATEGY:
#   'all_frozen' -> train only the new classifier head
#   'partial'    -> head + the last UNFREEZE_LAST_BLOCKS backbone stages
#   'none'       -> full fine-tuning of every layer
FREEZE_STRATEGY = "partial"
UNFREEZE_LAST_BLOCKS = 1  # only used by 'partial'; stage mapping is best-effort per architecture

USE_AMP = True      # mixed-precision training (speeds up modern GPUs)
NUM_WORKERS = 2     # DataLoader worker processes

# --- Output ------------------------------------------------------------------
OUTPUT_DIR = Path("./outputs")
CHECKPOINT_PATH = OUTPUT_DIR / "best_model.pt"   # weights of the best-validation-accuracy epoch
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


In [None]:
# ==== Imports ================================================================
import os, math, time, itertools, json, random
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import datasets, transforms, models

import numpy as np
import matplotlib.pyplot as plt

# Try to import sklearn (used for metrics); fall back to manual metrics if it is missing.
# Only a missing package should trigger the fallback: catching ImportError (instead of any
# Exception) keeps genuine errors raised while importing sklearn visible.
try:
    from sklearn.metrics import confusion_matrix, classification_report
    HAVE_SKLEARN = True
except ImportError:
    HAVE_SKLEARN = False

# Train on the first CUDA GPU when available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# ==== Data transforms ========================================================
# Augmentations for train, lighter transforms for val
def _normalized_tensor_steps():
    """Return fresh ToTensor + Normalize steps (the common tail of both pipelines).

    The mean/std are the widely used ImageNet RGB channel statistics.
    """
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]


# Training: random scale/aspect-ratio crop plus horizontal flip as augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    *_normalized_tensor_steps(),
])

# Validation: deterministic -- resize the short side to IMAGE_SIZE + 32, then center-crop.
val_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE + 32),
    transforms.CenterCrop(IMAGE_SIZE),
    *_normalized_tensor_steps(),
])

In [None]:
# ==== Dataset & split logic ==================================================
def build_datasets(data_root: Path, need_random_split=True, train_ratio=0.8, seed=42,
                   train_tfm=None, val_tfm=None):
    """Build the train/validation datasets found under ``data_root``.

    Two layouts are supported:

    * ``data_root/train/<class>/...`` and ``data_root/val/<class>/...`` -- each
      split is loaded as its own ``ImageFolder``.
    * ``data_root/<class>/...`` -- one folder whose samples are split
      deterministically into train/val ``Subset``s.

    Args:
        data_root: dataset root directory.
        need_random_split: request a random split of ``data_root``. If the root
            already has ``train/`` and ``val/`` subfolders, those are used anyway
            (a random split of the root would turn "train"/"val" into class labels),
            and a warning is emitted.
        train_ratio: fraction of samples assigned to training; must be in (0, 1).
        seed: seed of the deterministic split.
        train_tfm: transform for training samples; defaults to ``train_transform``.
        val_tfm: transform for validation samples; defaults to ``val_transform``.

    Returns:
        ``(train_ds, val_ds)``: two ``ImageFolder``s (pre-split layout) or two
        ``Subset``s wrapping ``ImageFolder``s (random-split layout).

    Raises:
        FileNotFoundError: ``data_root`` does not exist.
        ValueError: ``train_ratio`` outside (0, 1), a split that would be empty,
            or pre-split folders whose class lists differ.
    """
    import warnings

    data_root = Path(data_root)
    if not data_root.exists():
        raise FileNotFoundError(f"DATA_ROOT not found: {data_root.resolve()}")

    train_tfm = train_transform if train_tfm is None else train_tfm
    val_tfm = val_transform if val_tfm is None else val_tfm

    has_train_val = (data_root / 'train').is_dir() and (data_root / 'val').is_dir()

    if has_train_val:
        if need_random_split:
            warnings.warn(
                "DATA_ROOT contains 'train/' and 'val/' subfolders; using them as the "
                "split instead of splitting randomly (a random split of the root would "
                "treat 'train'/'val' as class labels).")
        train_ds = datasets.ImageFolder(data_root / 'train', transform=train_tfm)
        val_ds = datasets.ImageFolder(data_root / 'val', transform=val_tfm)
        # Label indices come from each folder's own sorted class list, so both splits
        # must contain exactly the same class subfolders for the labels to agree.
        if train_ds.classes != val_ds.classes:
            raise ValueError(
                f"Class folders differ between train ({train_ds.classes}) "
                f"and val ({val_ds.classes}).")
        return train_ds, val_ds

    if not need_random_split:
        warnings.warn(
            "need_random_split=False but DATA_ROOT has no 'train'/'val' subfolders; "
            "splitting randomly.")

    if not 0.0 < train_ratio < 1.0:
        raise ValueError(f"train_ratio must be in (0, 1), got {train_ratio}")

    # Two ImageFolder views over the same directory: same sorted sample order and
    # class mapping, different per-split transforms.
    train_base = datasets.ImageFolder(data_root, transform=train_tfm)
    val_base = datasets.ImageFolder(data_root, transform=val_tfm)

    n_total = len(train_base)
    n_train = int(n_total * train_ratio)
    n_val = n_total - n_train
    if n_train == 0 or n_val == 0:
        raise ValueError(
            f"Split of {n_total} images with train_ratio={train_ratio} leaves an "
            f"empty split (train={n_train}, val={n_val}).")

    # Deterministic split of the sample indices.
    g = torch.Generator().manual_seed(seed)
    idx_train, idx_val = torch.utils.data.random_split(range(n_total), [n_train, n_val], generator=g)

    return Subset(train_base, idx_train.indices), Subset(val_base, idx_val.indices)

train_ds, val_ds = build_datasets(DATA_ROOT, NEED_RANDOM_SPLIT, TRAIN_RATIO, RANDOM_SEED)

# A random split yields Subsets (the class list lives on the wrapped ImageFolder);
# pre-split folders yield ImageFolders directly.
class_names = train_ds.dataset.classes if hasattr(train_ds, 'dataset') else train_ds.classes
num_classes = len(class_names)
print("Classes:", class_names)

# Pinned host memory only speeds up host-to-GPU copies; on CPU it merely triggers a warning.
# persistent_workers keeps worker processes alive across epochs instead of re-spawning them
# (it must be False when NUM_WORKERS == 0).
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=device.type == "cuda",
    persistent_workers=NUM_WORKERS > 0,
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

In [None]:
# ==== Model factory ==========================================================
def build_model(name: str, num_classes: int, pretrained: bool = True):
    """Create a torchvision classifier whose last layer outputs ``num_classes`` logits.

    Args:
        name: 'resnet18', 'resnet50', 'efficientnet_b0' or 'mobilenet_v3_small'
            (case-insensitive).
        num_classes: number of output classes of the replacement linear layer.
        pretrained: if True (default), load torchvision's ``DEFAULT`` weights (may
            require a download); pass False for random initialization, e.g. when the
            weights are about to be overwritten from a checkpoint.

    Returns:
        ``(model, backbone_names, head_names)``: the model plus the names of the
        top-level child modules forming the backbone and the classification head.

    Raises:
        ValueError: unknown ``name`` or ``num_classes < 1``.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    def _weights(weights_enum):
        # weights=None builds a randomly initialized network.
        return weights_enum.DEFAULT if pretrained else None

    name = name.lower()
    if name == 'resnet18':
        m = models.resnet18(weights=_weights(models.ResNet18_Weights))
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        backbone = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4']
        head = ['fc']
    elif name == 'resnet50':
        m = models.resnet50(weights=_weights(models.ResNet50_Weights))
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        backbone = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4']
        head = ['fc']
    elif name == 'efficientnet_b0':
        m = models.efficientnet_b0(weights=_weights(models.EfficientNet_B0_Weights))
        # classifier[1] is the final Linear of the classifier Sequential.
        m.classifier[1] = nn.Linear(m.classifier[1].in_features, num_classes)
        backbone = ['features']
        head = ['classifier']
    elif name == 'mobilenet_v3_small':
        m = models.mobilenet_v3_small(weights=_weights(models.MobileNet_V3_Small_Weights))
        # classifier[3] is the final Linear of the classifier Sequential.
        m.classifier[3] = nn.Linear(m.classifier[3].in_features, num_classes)
        backbone = ['features']
        head = ['classifier']
    else:
        raise ValueError(f"Unknown MODEL_NAME: {name}")
    return m, backbone, head

# Build the selected architecture with a new output layer sized to the discovered classes.
# backbone_names / head_names are top-level child-module names, used below for freezing
# and for per-group learning rates.
model, backbone_names, head_names = build_model(MODEL_NAME, num_classes)

# Freeze strategy
def set_freeze_strategy(model, backbone_names, head_names, strategy='partial', unfreeze_last_blocks=1):
    """Set ``requires_grad`` on ``model``'s parameters according to ``strategy``.

    Args:
        model: the network; head modules are looked up as attributes by name.
        backbone_names: accepted for signature symmetry with the other helpers; unused.
        head_names: attribute names of the classification-head module(s).
        strategy: 'all_frozen' (train only the head), 'partial' (head plus the last
            ``unfreeze_last_blocks`` backbone stages) or 'none' (train everything).
        unfreeze_last_blocks: trailing stages to unfreeze for 'partial'. Stages are
            ResNet ``layer1..layer4`` or the entries of ``model.features``. Values
            <= 0 unfreeze no backbone stage.

    Raises:
        ValueError: unknown ``strategy``; raised before any parameter is modified.

    NOTE: ``requires_grad=False`` does not stop BatchNorm layers from updating their
    running statistics while the model is in train mode.
    """
    if strategy not in ('all_frozen', 'partial', 'none'):
        raise ValueError("FREEZE_STRATEGY must be one of {'all_frozen','partial','none'}")

    def _set_trainable(module, flag):
        for p in module.parameters():
            p.requires_grad = flag

    if strategy == 'none':
        _set_trainable(model, True)
        return

    # Start from a fully frozen model, then re-enable the head.
    _set_trainable(model, False)
    for name in head_names:
        _set_trainable(getattr(model, name), True)

    # Guard explicitly: `stages[-0:]` would select every stage, not none.
    if strategy == 'all_frozen' or unfreeze_last_blocks <= 0:
        return

    # 'partial': unfreeze the last N stages (best-effort per architecture).
    if hasattr(model, 'layer4'):  # ResNet: stages layer1..layer4
        stages = [getattr(model, f'layer{i}') for i in range(1, 5)]
    elif hasattr(model, 'features'):  # EfficientNet / MobileNet: indexable `features`
        stages = list(model.features)
    else:
        stages = []
    for stage in stages[-unfreeze_last_blocks:]:
        _set_trainable(stage, True)

set_freeze_strategy(model, backbone_names, head_names, FREEZE_STRATEGY, UNFREEZE_LAST_BLOCKS)
model = model.to(device)

# Report how much of the network will actually be updated by the optimizer.
_param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
n_total = sum(size for size, _ in _param_sizes)
n_trainable = sum(size for size, is_trainable in _param_sizes if is_trainable)
print(f"Trainable params: {n_trainable:,} / {n_total:,}")

In [None]:
# ==== Optimizer, loss, scheduler ============================================
# Set different LRs for backbone vs head when applicable
def param_groups(model, backbone_names, head_names, lr_backbone, lr_classifier):
    """Split the trainable parameters into backbone and head optimizer groups.

    A parameter goes to the head group when the top-level child module owning it is
    named in ``head_names``. Every other trainable parameter goes to the backbone
    group, including parameters registered directly on ``model``. Frozen parameters
    (``requires_grad=False``) are left out entirely.

    Args:
        model: the network.
        backbone_names: accepted for signature symmetry; everything not in the head
            is treated as backbone.
        head_names: names of the top-level head module(s).
        lr_backbone: learning rate of the backbone group.
        lr_classifier: learning rate of the head group.

    Returns:
        A list of zero to two ``torch.optim`` parameter-group dicts: backbone first,
        then head. Empty groups are omitted.
    """
    head_set = set(head_names)
    backbone_params, head_params = [], []
    # named_parameters() yields each (possibly shared/tied) tensor only once, so the
    # same parameter can never land in two groups -- torch.optim rejects that.
    for qual_name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        top_level = qual_name.split('.', 1)[0]
        (head_params if top_level in head_set else backbone_params).append(p)

    groups = []
    if backbone_params:
        groups.append({'params': backbone_params, 'lr': lr_backbone})
    if head_params:
        groups.append({'params': head_params, 'lr': lr_classifier})
    return groups

# AdamW with separate learning rates for backbone and head (frozen params are excluded).
optimizer = optim.AdamW(param_groups(model, backbone_names, head_names, LR_BACKBONE, LR_CLASSIFIER), weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
# scheduler.step() runs once per epoch, after training, so epoch e uses the LR after e-1
# steps. With T_max = NUM_EPOCHS the last epoch still gets a positive LR; the previous
# T_max = NUM_EPOCHS - 1 reached eta_min (0) exactly at the last epoch, making it a no-op.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, NUM_EPOCHS))

In [None]:
# ==== Training & evaluation ==================================================
# Loss scaling only applies to CUDA AMP; a disabled scaler turns scale/step/update into
# plain pass-throughs, so CPU runs work without the "CUDA not available" warning.
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device.type == "cuda")

def train_one_epoch(model, loader, optimizer, scaler, device, loss_fn=None, use_amp=None):
    """Run one training pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: network to train; it is switched to train mode.
        loader: iterable of ``(images, targets)`` batches.
        optimizer: optimizer over the trainable parameters.
        scaler: AMP ``GradScaler``; when disabled, scale/step/update are pass-throughs.
        device: device the batches are moved to.
        loss_fn: loss callable; defaults to the module-level ``criterion``.
        use_amp: enable CUDA autocast. ``None`` (default) follows ``scaler.is_enabled()``,
            so autocast and loss scaling are switched on and off together. Autocast is
            never enabled on non-CUDA devices.

    Returns:
        Sample-weighted mean loss and accuracy over the epoch; ``(0.0, 0.0)`` for an
        empty loader.
    """
    import contextlib

    loss_fn = criterion if loss_fn is None else loss_fn
    if use_amp is None:
        use_amp = scaler.is_enabled()
    # CUDA autocast requested on a CPU setup would only warn and disable itself.
    amp_on = bool(use_amp) and torch.device(device).type == "cuda"

    model.train()
    running_loss, running_correct, n = 0.0, 0, 0
    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        amp_ctx = torch.autocast(device_type="cuda") if amp_on else contextlib.nullcontext()
        with amp_ctx:
            outputs = model(images)
            loss = loss_fn(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        running_correct += (outputs.argmax(1) == targets).sum().item()
        n += batch_size
    return running_loss / max(1, n), running_correct / max(1, n)

@torch.no_grad()
def evaluate(model, loader, device, loss_fn=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate; it is switched to eval mode.
        loader: iterable of ``(images, targets)`` batches.
        device: device the batches are moved to.
        loss_fn: loss callable; defaults to the module-level ``criterion``.

    Returns:
        ``(avg_loss, accuracy, y_true, y_pred)``: the sample-weighted mean loss, the
        accuracy, and the targets/argmax predictions as 1-D numpy arrays in loader
        order. An empty loader gives ``(0.0, 0.0, <empty>, <empty>)`` instead of
        failing in ``torch.cat``.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    model.eval()
    pred_chunks, target_chunks = [], []
    running_loss, running_correct, n = 0.0, 0, 0
    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        outputs = model(images)
        loss = loss_fn(outputs, targets)
        preds = outputs.argmax(1)
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        running_correct += (preds == targets).sum().item()
        n += batch_size
        pred_chunks.append(preds.cpu())
        target_chunks.append(targets.cpu())

    if n == 0:
        empty = np.empty(0, dtype=np.int64)
        return 0.0, 0.0, empty, empty.copy()

    all_preds = torch.cat(pred_chunks).numpy()
    all_targets = torch.cat(target_chunks).numpy()
    return running_loss / n, running_correct / n, all_targets, all_preds

history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
best_acc, best_state, best_epoch = -1.0, None, 0
best_y_true, best_y_pred = None, None

for epoch in range(1, NUM_EPOCHS + 1):
    t0 = time.time()
    tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, scaler, device)
    val_loss, val_acc, y_true, y_pred = evaluate(model, val_loader, device)
    scheduler.step()

    history['train_loss'].append(tr_loss)
    history['train_acc'].append(tr_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    if val_acc > best_acc:
        best_acc, best_epoch = val_acc, epoch
        # copy=True: on a CPU model `.cpu()` returns the very same tensors, so without an
        # explicit copy the "best" snapshot would keep tracking the live weights.
        best_state = {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}
        best_y_true, best_y_pred = y_true, y_pred
        torch.save({'model_state': best_state,
                    'class_names': class_names,
                    'model_name': MODEL_NAME}, str(CHECKPOINT_PATH))

    dt = time.time() - t0
    print(f"Epoch {epoch:03d}/{NUM_EPOCHS} | "
          f"train_loss={tr_loss:.4f} acc={tr_acc:.4f} | "
          f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
          f"time={dt:.1f}s")

# Make the following cells consistent with the saved checkpoint: restore the best
# weights and expose that epoch's validation predictions (not the last epoch's).
if best_state is not None:
    model.load_state_dict(best_state)
    y_true, y_pred = best_y_true, best_y_pred
    print(f"Best val acc={best_acc:.4f} at epoch {best_epoch}")

In [None]:
# ==== Plot training curves ===================================================
# Epochs are numbered from 1 in the training log, so use the same numbering on the x-axis.
epoch_axis = range(1, len(history['train_loss']) + 1)
for metric, ylabel, title in (('loss', 'loss', 'Loss'), ('acc', 'accuracy', 'Accuracy')):
    plt.figure()
    plt.plot(epoch_axis, history[f'train_{metric}'], label=f'train_{metric}')
    plt.plot(epoch_axis, history[f'val_{metric}'], label=f'val_{metric}')
    plt.xlabel('epoch'); plt.ylabel(ylabel); plt.title(title)
    plt.legend()
    plt.show()

In [None]:
# ==== Detailed metrics =======================================================
# Evaluate over every class index, even classes absent from this validation run.
labels = list(range(num_classes))
if HAVE_SKLEARN:
    print("\nClassification report (sklearn):\n")
    # labels= keeps the report aligned with class_names when a class appears in neither
    # y_true nor y_pred (otherwise sklearn raises on the target_names length mismatch);
    # zero_division=0 reports such undefined scores as 0 instead of warning.
    print(classification_report(y_true, y_pred, labels=labels, target_names=class_names,
                                digits=4, zero_division=0))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
else:
    print("\nscikit-learn is not available. Computing basic metrics manually.\n")
    # Confusion matrix: rows = true class, columns = predicted class.
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)
    tp = np.diag(cm).astype(float)
    predicted_totals = cm.sum(axis=0)  # tp + fp per class
    true_totals = cm.sum(axis=1)       # tp + fn per class
    # Undefined ratios (zero denominators) are reported as 0.
    precs = np.divide(tp, predicted_totals, out=np.zeros_like(tp), where=predicted_totals > 0)
    recs = np.divide(tp, true_totals, out=np.zeros_like(tp), where=true_totals > 0)
    pr_sum = precs + recs
    f1s = np.divide(2 * precs * recs, pr_sum, out=np.zeros_like(tp), where=pr_sum > 0)
    print("Per-class metrics:")
    for i, name in enumerate(class_names):
        print(f"{name:20s}  P={precs[i]:.4f}  R={recs[i]:.4f}  F1={f1s[i]:.4f}")
    accuracy = tp.sum() / max(1, cm.sum())
    print(f"\nAccuracy={accuracy:.4f}")
    print(f"\nMacro Precision={np.mean(precs):.4f}, Recall={np.mean(recs):.4f}, F1={np.mean(f1s):.4f}")

# Plot confusion matrix
# Heat-map of the confusion matrix: rows are true classes, columns predicted classes.
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest')
ax.set_title('Confusion Matrix')
fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
tick_marks = np.arange(num_classes)
ax.set_xticks(tick_marks)
ax.set_xticklabels(class_names, rotation=45, ha='right')
ax.set_yticks(tick_marks)
ax.set_yticklabels(class_names)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
fig.tight_layout()
plt.show()

In [None]:
# ==== Inference utility ======================================================
from PIL import Image

def load_checkpoint(path, device=device):
    """Rebuild the trained classifier from a checkpoint saved by the training loop.

    The checkpoint is a dict with 'model_state', 'class_names' and 'model_name'.
    Returns ``(model, class_names)``, with the model in eval mode on ``device``.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    names = checkpoint['class_names']
    net, _backbone, _head = build_model(checkpoint['model_name'], num_classes=len(names))
    net.load_state_dict(dict(checkpoint['model_state']))
    net.eval()
    net.to(device)
    return net, names

def predict_image(img_path, model, class_names, image_size=IMAGE_SIZE):
    """Classify one image file with ``model``.

    Uses the same deterministic preprocessing as validation: resize the short side to
    ``image_size + 32``, center-crop ``image_size``, then normalize with the ImageNet
    statistics. The model is used as-is, so call ``model.eval()`` first if it may
    still be in train mode.

    Returns:
        ``(class_name, probability)``: the top-1 class and its softmax probability.
    """
    tfm = transforms.Compose([
        transforms.Resize(image_size + 32),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Close the file handle deterministically; convert() returns an in-memory copy.
    with Image.open(img_path) as raw:
        img = raw.convert('RGB')
    # Send the input to wherever the model lives (it may have been loaded on a device
    # other than the notebook-wide `device`); fall back to `device` if it has no params.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    x = tfm(img).unsqueeze(0).to(target_device)
    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
    idx = int(np.argmax(probs))
    return class_names[idx], float(probs[idx])

# Show how to reuse the saved checkpoint for single-image inference.
_usage_lines = (
    "model, classes = load_checkpoint(CHECKPOINT_PATH)",
    "label, conf = predict_image('path/to/image.jpg', model, classes)",
    "print(label, conf)",
)
print("\nUsage example (after training):\n" + "\n".join(_usage_lines))