In [None]:
# -*- coding: utf-8 -*-
"""
Faster R-CNN Training/Validation/Testing on SVHN for Digit Recognition (Google Colab)
Task 1: BBox and Class (category_id) prediction
Task 2: Full number sequence recognition (post-processing)
"""

import os
import json
import time
import numpy as np
import torch
import torchvision
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
try:
    from torchvision.transforms.v2 import functional as F, ToTensor, Compose
    using_v2_transforms = True
except ImportError:
    print("torchvision.transforms.v2 not found, falling back to v1 transforms.")
    print("Consider upgrading torchvision: pip install --upgrade torchvision")
    from torchvision import transforms as T # Fallback to v1
    using_v2_transforms = False

from torch.utils.data import Dataset, DataLoader
from PIL import Image
from collections import defaultdict

try:
    from pycocotools.coco import COCO
except ImportError:
    print("pycocotools not found. Installing...")
    !pip install pycocotools --quiet
    from pycocotools.coco import COCO

# --- 1. Configuration ---

# Colab-only: mount Google Drive, where the dataset archive lives.
from google.colab import drive
drive.mount('/content/drive')

import zipfile

zip_path = '/content/drive/MyDrive/SVHN_dataset.zip'
BASE_DIR = 'data/your_dataset'
# Set to True to unzip again, e.g. after replacing the archive on Drive.
FORCE_REEXTRACT = False

# The code below expects the archive to unpack into train/valid/test image
# folders plus one COCO-format annotation JSON per split under BASE_DIR.
TRAIN_IMG_DIR = os.path.join(BASE_DIR, 'train')
VAL_IMG_DIR = os.path.join(BASE_DIR, 'valid')
TEST_IMG_DIR = os.path.join(BASE_DIR, 'test')

TRAIN_ANNOTATION_PATH = os.path.join(BASE_DIR, 'train.json')
VAL_ANNOTATION_PATH = os.path.join(BASE_DIR, 'valid.json')
TEST_ANNOTATION_PATH = os.path.join(BASE_DIR, 'test.json')

# Unpacking tens of thousands of images takes minutes, so skip it when a
# previous run of this cell already produced every expected split. Only the
# top-level paths are checked; use FORCE_REEXTRACT if an earlier extraction
# was interrupted.
_dataset_paths = (
    TRAIN_IMG_DIR, VAL_IMG_DIR, TEST_IMG_DIR,
    TRAIN_ANNOTATION_PATH, VAL_ANNOTATION_PATH, TEST_ANNOTATION_PATH,
)
if FORCE_REEXTRACT or not all(os.path.exists(p) for p in _dataset_paths):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(BASE_DIR)
else:
    print(f"Dataset already extracted under {BASE_DIR}; skipping unzip (set FORCE_REEXTRACT = True to redo).")

# --- Output Files ---
TASK1_OUTPUT_PATH = 'pred.json'  # Task 1 output file (bboxes + classes)
TASK2_OUTPUT_PATH = 'pred.csv'  # Task 2 output file (whole number per image)
MODEL_SAVE_PATH = 'fasterrcnn_svhn_best_model.pth'

# --- Model and training hyper-parameters ---
# Category ID Mapping:
# 1: '0', 2: '1', 3: '2', 4: '3', 5: '4',
# 6: '5', 7: '6', 8: '7', 9: '8', 10: '9'
NUM_CLASSES = 11  # 10 digits + 1 background class
NUM_EPOCHS = 10
BATCH_SIZE = 8
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0005
MOMENTUM = 0.9
# NOTE(review): LR_STEP_SIZE is not used by the ReduceLROnPlateau scheduler set up later.
LR_STEP_SIZE = 1
LR_GAMMA = 0.1
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
PRINT_FREQ = 100
# Running best validation loss; the training loop lowers it whenever it saves a checkpoint.
BEST_VAL_LOSS = float('inf')

# --- Task 2 Parameters ---
TASK2_SCORE_THRESHOLD = 0.5  # minimum detection score kept when assembling the number

print(f"Using device: {DEVICE}")
print(f"Using torchvision v2 transforms: {using_v2_transforms}")



# --- 2. Dataset class (SVHNDataset) ---
class SVHNDataset(Dataset):
    """SVHN digit-detection dataset backed by a COCO-format annotation file.

    Each item is ``(image, target)``. ``target`` is a dict with:

    * ``image_id``: 1-element int64 tensor.
    * ``boxes``: float32, ``[xmin, ymin, xmax, ymax]``.
    * ``labels``: int64 category ids.
    * ``area`` and ``iscrowd``: present on training/validation items only.

    For ``is_test=True`` no annotations are read and ``boxes``/``labels`` are
    empty.

    Args:
        img_dir: Directory containing the image files named in the JSON.
        annotation_path: Path to the COCO-format JSON file.
        transforms: Optional callable. With torchvision v2 it is called as
            ``transforms(img, target)``; with the v1 fallback only the image
            is transformed.
        is_test: Skip reading annotations (unlabelled split).
    """

    def __init__(self, img_dir, annotation_path, transforms=None, is_test=False):
        self.img_dir = img_dir
        self.coco = COCO(annotation_path)
        # Sorted so the index -> image id mapping is deterministic.
        self.ids = list(sorted(self.coco.imgs.keys()))
        self.transforms = transforms
        self.is_test = is_test
        print(f"Loaded {len(self.ids)} images from {annotation_path}. Is test set: {self.is_test}")

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``, or ``None`` if the image cannot be read."""
        coco = self.coco
        img_id = self.ids[index]
        img_info = coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.img_dir, img_info['file_name'])

        try:
            img = Image.open(img_path).convert("RGB")
        except FileNotFoundError:
            print(f"Error: Image file not found at {img_path}")
            return None
        except OSError as e:
            # Corrupt or truncated files (e.g. PIL.UnidentifiedImageError, an
            # OSError subclass) are skipped like missing ones instead of
            # crashing the DataLoader worker.
            print(f"Error: Could not read image at {img_path}: {e}")
            return None

        target = {}
        target["image_id"] = torch.tensor([img_id])

        if not self.is_test:
            img_w, img_h = img.size  # read once, not once per annotation
            target.update(self._annotation_target(img_id, img_w, img_h))
        else:
            target["boxes"] = torch.empty((0, 4), dtype=torch.float32)
            target["labels"] = torch.empty((0,), dtype=torch.int64)

        if self.transforms is not None:
            if using_v2_transforms:
                img, target = self.transforms(img, target)
            else:
                img = self.transforms(img)

        return img, target

    def _annotation_target(self, img_id, img_w, img_h):
        """Build the boxes/labels/area/iscrowd tensors for ``img_id``.

        COCO ``[x, y, w, h]`` boxes are converted to corner format and clipped
        to the ``img_w`` x ``img_h`` image. Boxes that become empty are
        dropped. Empty results get the ``(0, 4)`` / ``(0,)`` shapes.
        """
        coco = self.coco
        boxes, labels, areas, iscrowd = [], [], [], []
        for ann in coco.loadAnns(coco.getAnnIds(imgIds=img_id)):
            xmin, ymin, w, h = ann['bbox']
            x0 = max(0, xmin)
            y0 = max(0, ymin)
            x1 = min(img_w, xmin + w)
            y1 = min(img_h, ymin + h)
            if x1 <= x0 or y1 <= y0:
                continue
            boxes.append([x0, y0, x1, y1])
            labels.append(ann['category_id'])
            # NOTE(review): the stored (unclipped) area is kept, as in the annotation file.
            areas.append(ann.get('area', w * h))
            iscrowd.append(ann.get('iscrowd', 0))

        return {
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "area": torch.as_tensor(areas, dtype=torch.float32),
            "iscrowd": torch.as_tensor(iscrowd, dtype=torch.int64),
        }

    def __len__(self):
        return len(self.ids)


# --- 3. Data transforms and DataLoader ---
def get_transform(train):
    """Build the image transform pipeline (tensor conversion only).

    ``train`` is accepted so split-specific augmentation could be added, but
    it is currently ignored: every split gets the same pipeline. The
    torchvision v2 API is used when available, otherwise the v1 API.
    """
    if not using_v2_transforms:
        return T.Compose([T.ToTensor()])
    return Compose([ToTensor()])

def collate_fn(batch):
    """Collate a detection batch into ``(images, targets)`` tuples.

    Samples the dataset returned as ``None`` (unreadable images) are dropped
    first. If nothing remains, ``(None, None)`` is returned so the calling
    loops can skip the batch.
    """
    kept = [sample for sample in batch if sample is not None]
    if len(kept) == 0:
        return None, None
    # Transpose [(img, tgt), ...] into ((img, ...), (tgt, ...)).
    return tuple(zip(*kept))

print("Creating Datasets...")
dataset_train = SVHNDataset(TRAIN_IMG_DIR, TRAIN_ANNOTATION_PATH, transforms=get_transform(train=True))
dataset_val = SVHNDataset(VAL_IMG_DIR, VAL_ANNOTATION_PATH, transforms=get_transform(train=False))
dataset_test = SVHNDataset(TEST_IMG_DIR, TEST_ANNOTATION_PATH, transforms=get_transform(train=False), is_test=True)

print("Creating DataLoaders...")
# Colab VMs often expose only a few CPU cores; asking for more workers than
# that makes DataLoader warn about oversubscription and can slow or stall it.
NUM_WORKERS = min(12, os.cpu_count() or 1)
# Pinned host memory only speeds up host -> GPU copies.
PIN_MEMORY = DEVICE.type == 'cuda'
_loader_kwargs = dict(num_workers=NUM_WORKERS, collate_fn=collate_fn, pin_memory=PIN_MEMORY)
data_loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, **_loader_kwargs)
data_loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False, **_loader_kwargs)
data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, **_loader_kwargs)
print("DataLoaders created.")

# --- 4. Model definition ---
def get_model(num_classes):
    """Return a COCO-pretrained Faster R-CNN (ResNet-50 FPN v2) with a fresh predictor.

    The box predictor is replaced so it scores ``num_classes`` classes
    (background included) while keeping the original feature width.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")
    roi_heads = detector.roi_heads
    feature_dim = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)
    return detector

print("Loading model...")
model = get_model(NUM_CLASSES)
model.to(DEVICE)
print("Model loaded.")

# Resume from a previously saved checkpoint (a state_dict), if present.
if os.path.exists(MODEL_SAVE_PATH):
    # weights_only=True limits unpickling to tensors/containers, which is all
    # a state_dict needs, and avoids executing arbitrary pickled objects.
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE, weights_only=True))
    print(f"Model weights loaded from {MODEL_SAVE_PATH}")
else:
    # get_model() already loads COCO-pretrained weights, so this is
    # fine-tuning rather than training from scratch.
    print(f"Warning: Checkpoint {MODEL_SAVE_PATH} not found. Fine-tuning from COCO-pretrained weights.")

print("Model ready.")

# --- 5. Optimizer and learning-rate scheduler ---
print("Setting up optimizer and scheduler...")
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# Multiply the LR by LR_GAMMA after 2 epochs without validation-loss improvement.
# `verbose` is deprecated for PyTorch LR schedulers; the training loop prints
# the current learning rate every epoch instead.
lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=LR_GAMMA, patience=2)
print("Optimizer and scheduler ready.")

# --- 6. Training and validation ---


def _set_batchnorm_eval(module):
    """Put every BatchNorm layer inside ``module`` into eval mode.

    torchvision detection models only return losses in train mode, so
    validation runs under ``model.train()``. Any BatchNorm layer in train
    mode normalises with per-batch statistics and overwrites its running
    statistics, even under ``torch.no_grad()``. Switching them to eval mode
    keeps validation from mutating the model, and makes the loss use the
    same normalisation as inference.
    """
    bn_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm)
    for sub in module.modules():
        if isinstance(sub, bn_types):
            sub.eval()


print(f"Starting training for {NUM_EPOCHS} epochs...")
start_time_train = time.time()

for epoch in range(NUM_EPOCHS):
    start_time_epoch = time.time()

    # --- Training Phase ---
    model.train()  # also resets BatchNorm layers left in eval mode by validation
    train_loss = 0
    train_batch_count = 0
    print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} --- Training ---")
    optimizer.zero_grad()

    for i, batch_data in enumerate(data_loader_train):
        # collate_fn yields (None, None) when every sample in the batch failed to load.
        if batch_data is None or batch_data[0] is None or batch_data[1] is None:
            print(f"Warning: Skipping empty or invalid batch at training iteration {i}.")
            continue

        images, targets = batch_data
        if not images or not targets:
            print(f"Warning: Skipping training batch {i} due to empty images or targets after collate_fn.")
            continue

        try:
            images = list(image.to(DEVICE) for image in images)
            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
        except Exception as e:
            print(f"Error moving batch {i} data to device {DEVICE}: {e}")
            continue

        try:
            # Forward pass: in train mode the detector returns a dict of losses.
            loss_dict = model(images, targets)

            # Total loss is the sum of all returned loss terms.
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.item()  # convert the 0-d tensor to a Python float

            # Skip the update on NaN/Inf so one bad batch cannot corrupt the weights.
            if not np.isfinite(loss_value):
                print(f"\nWarning: Non-finite loss detected during training: {loss_value}. Skipping batch {i}.")
                print("Loss dict:", {k: v.item() if torch.is_tensor(v) else v for k, v in loss_dict.items()})
                optimizer.zero_grad()
                continue

            train_loss += loss_value
            train_batch_count += 1

            # Backward pass and parameter update.
            losses.backward()
            optimizer.step()
            optimizer.zero_grad()

            if (i + 1) % PRINT_FREQ == 0 or i == len(data_loader_train) - 1:
                current_avg_loss = train_loss / train_batch_count if train_batch_count > 0 else 0
                print(f"Batch [{i+1}/{len(data_loader_train)}], Current Avg Loss: {current_avg_loss:.4f} (Last Batch Loss: {loss_value:.4f})")

        except Exception as e:
            print(f"\nError during training batch {i}: {e}")
            print("Image IDs in this batch:", [t['image_id'].item() for t in targets if 'image_id' in t])
            optimizer.zero_grad()
            continue

    # Epoch-average training loss.
    avg_train_loss = train_loss / train_batch_count if train_batch_count > 0 else 0

    # --- Validation Phase ---
    val_loss = 0
    val_batch_count = 0
    print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} --- Validation ---")

    # Train mode is required for the detector to return losses, but BatchNorm
    # must stay in eval mode so validation data does not leak into its running
    # statistics. Proposal sampling is still random in train mode, so the
    # validation loss is slightly stochastic.
    model.train()
    _set_batchnorm_eval(model)

    with torch.no_grad():
        for i, batch_data in enumerate(data_loader_val):
            if batch_data is None or batch_data[0] is None or batch_data[1] is None:
                continue

            images, targets = batch_data
            if not images or not targets or not isinstance(targets[0], dict):
                print(f"=Skipping batch {i}: targets are not valid.")
                continue

            try:
                images = [img.to(DEVICE) for img in images]
                targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
            except Exception as e:
                print(f"Error moving validation batch {i} data to device {DEVICE}: {e}")
                continue

            try:
                loss_dict = model(images, targets)

                if not isinstance(loss_dict, dict):
                    raise TypeError(f"Expected dict, got {type(loss_dict)}")

                losses = sum(loss for loss in loss_dict.values())
                loss_value = losses.item()

                if not np.isfinite(loss_value):
                    print(f"\nWarning: Non-finite validation loss detected: {loss_value} at batch {i}. Skipping.")
                    continue

                val_loss += loss_value
                val_batch_count += 1

            except Exception as e:
                print(f"\nError during validation batch {i}: {e}")
                print("Image IDs in this batch:", [t['image_id'].item() for t in targets if 'image_id' in t])
                continue

    # Epoch-average validation loss; inf when no validation batch succeeded.
    avg_val_loss = val_loss / val_batch_count if val_batch_count > 0 else float('inf')

    # --- Epoch Summary & Model Saving ---
    end_time_epoch = time.time()
    elapsed_time = end_time_epoch - start_time_epoch
    print(f"\n--- Epoch {epoch+1} Summary ---")
    print(f"Average Training Loss: {avg_train_loss:.4f}")

    if val_batch_count > 0:
        print(f"Average Validation Loss: {avg_val_loss:.4f}")
    else:
        print("Validation Loss: N/A (No valid validation batches processed)")
    print(f"Time Elapsed: {elapsed_time:.2f} seconds")
    print(f"Current Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

    # Save only when the validation loss is valid and improved.
    # NOTE(review): BEST_VAL_LOSS starts at inf even when resuming from a
    # checkpoint, so the first epoch always overwrites the existing file.
    if val_batch_count > 0 and avg_val_loss < BEST_VAL_LOSS:
        BEST_VAL_LOSS = avg_val_loss
        try:
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"Validation loss improved to {avg_val_loss:.4f}. Model saved to {MODEL_SAVE_PATH}")
        except Exception as e:
            print(f"Error saving model: {e}")
    elif val_batch_count > 0:
        print(f"Validation loss ({avg_val_loss:.4f}) did not improve from best ({BEST_VAL_LOSS:.4f}).")

    print("-" * 30)

    # --- Learning-rate update (plateau-based, driven by validation loss) ---
    lr_scheduler.step(avg_val_loss)

    # Optional: download the current checkpoint after each epoch (Colab only).
    # from google.colab import files
    # if os.path.exists(MODEL_SAVE_PATH):
    #     files.download(MODEL_SAVE_PATH)
    # else:
    #     print(f"Warning: Skipped download. Model file {MODEL_SAVE_PATH} does not exist.")

# --- Training finished ---
total_training_time = time.time() - start_time_train
print(f"\n--- Training Finished ---")
print(f"Total Training Time: {total_training_time:.2f} seconds ({total_training_time/60:.2f} minutes)")
print(f"Best Validation Loss Achieved: {BEST_VAL_LOSS:.4f}")
# -------------------------------------------------------------

# --- Load the best checkpoint for final prediction ---
print("\n--- Loading best model for final prediction ---")
if os.path.exists(MODEL_SAVE_PATH):
    try:
        # The checkpoint is a plain state_dict, so weights_only loading suffices.
        model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE, weights_only=True))
        print(f"Loaded best model from {MODEL_SAVE_PATH} with validation loss {BEST_VAL_LOSS:.4f}")
    except Exception as e:
        print(f"Warning: Could not load best model from {MODEL_SAVE_PATH}: {e}. Using model from the last epoch.")
else:
    print(f"Warning: Best model file {MODEL_SAVE_PATH} not found. Using model from the last epoch.")

model.eval()

# --- 7. Task 1 predictions on the test set ---
print(f"\n--- Generating Task 1 predictions ({TASK1_OUTPUT_PATH}) ---")
task1_results = []
with torch.no_grad():
    for i, batch_data in enumerate(data_loader_test):
        if batch_data is None or batch_data[0] is None:
            continue
        images, targets = batch_data
        if not targets or targets[0] is None or 'image_id' not in targets[0]:
            continue

        if (i + 1) % 1000 == 0:
            print(f"Task 1: Processing test image {i+1}/{len(data_loader_test)}")

        images = [img.to(DEVICE) for img in images]
        outputs = model(images)

        # Pair every output with its own target so each detection gets the
        # right image_id whatever the test loader's batch size is.
        for output, target in zip(outputs, targets):
            original_image_id = target['image_id'].item()
            boxes = output['boxes'].cpu().numpy()    # [xmin, ymin, xmax, ymax]
            labels = output['labels'].cpu().numpy()  # category_id
            scores = output['scores'].cpu().numpy()

            for (xmin, ymin, xmax, ymax), label, score in zip(boxes, labels, scores):
                width = xmax - xmin
                height = ymax - ymin
                if width <= 0 or height <= 0:
                    continue

                # COCO-style result entry: bbox is [x, y, width, height].
                task1_results.append({
                    "image_id": original_image_id,
                    "bbox": [float(xmin), float(ymin), float(width), float(height)],
                    "score": float(score),
                    "category_id": int(label),  # class prediction
                })

# --- 8. Save Task 1 predictions ---
print(f"\nSaving Task 1 predictions to {TASK1_OUTPUT_PATH}...")
with open(TASK1_OUTPUT_PATH, 'w') as f:
    json.dump(task1_results, f, indent=4)
print("Task 1 prediction file saved successfully.")


# --- 9. Task 2 predictions (post-processing of the Task 1 results) ---
print(f"\n--- Generating Task 2 predictions ({TASK2_OUTPUT_PATH}) ---")

# Map category_id to the digit character it represents.
category_id_to_digit = {
    1: '0', 2: '1', 3: '2', 4: '3', 5: '4',
    6: '5', 7: '6', 8: '7', 9: '8', 10: '9'
}

# Group Task 1 detections by image_id.
detections_by_image = defaultdict(list)
for det in task1_results:
    detections_by_image[det['image_id']].append(det)

task2_results = {}  # {image_id: number_string}; "-1" means no number recognised

print(f"Processing Task 1 results for Task 2 (Score Threshold: {TASK2_SCORE_THRESHOLD})...")
processed_images = 0
for image_id, detections in detections_by_image.items():
    # 1. Drop low-confidence detections.
    high_conf_detections = [d for d in detections if d['score'] >= TASK2_SCORE_THRESHOLD]

    if not high_conf_detections:
        task2_results[image_id] = "-1"
        continue

    # 2. Read digits left to right by bbox centre x
    #    (bbox is [xmin, ymin, width, height], so centre_x = xmin + width / 2).
    sorted_detections = sorted(
        high_conf_detections,
        key=lambda d: d['bbox'][0] + d['bbox'][2] / 2
    )

    # 3. Map category ids to digits and join them.
    digits = []
    for d in sorted_detections:
        digit = category_id_to_digit.get(d['category_id'])
        if digit is None:
            print(f"Warning: Unknown category_id {d['category_id']} found for image {image_id}. Skipping.")
            continue
        digits.append(digit)

    # If every kept detection had an unknown category the string would be
    # empty; report "-1" like the other "no number" cases instead.
    task2_results[image_id] = "".join(digits) or "-1"
    processed_images += 1
    if processed_images % 1000 == 0:
        print(f"Task 2: Processed detections for {processed_images} images...")

print(f"Task 2 processing finished for {processed_images} images.")

# Test images with no Task 1 detection at all never appear above; mark them "-1".
with open(TEST_ANNOTATION_PATH) as f:
    test_json = json.load(f)
all_image_ids = set(img['id'] for img in test_json['images'])
missing_ids = all_image_ids - set(task2_results)
for missing_id in sorted(missing_ids):
    task2_results[missing_id] = "-1"
    print(f"Added missing image_id {missing_id} with prediction -1")

print(f"Final total task2 results: {len(task2_results)} (should be {len(all_image_ids)})")

# --- 10. Save Task 2 predictions ---
print(f"\nSaving Task 2 predictions to {TASK2_OUTPUT_PATH}...")

import csv

# Write to TASK2_OUTPUT_PATH so the saved file is the one announced above and
# the one downloaded afterwards. utf-8-sig writes a UTF-8 BOM at the start of
# the file (presumably for spreadsheet tools) — confirm the grader tolerates it.
with open(TASK2_OUTPUT_PATH, 'w', newline='', encoding='utf-8-sig') as f:
    writer = csv.writer(f)
    writer.writerow(['image_id', 'pred_label'])
    writer.writerows([image_id, str(task2_results[image_id])] for image_id in sorted(task2_results))

print("Task 2 prediction file saved successfully.")
print("Script finished.")


# --- (Optional) Download prediction files (Colab only) ---
from google.colab import files
for _path in (TASK1_OUTPUT_PATH, TASK2_OUTPUT_PATH, MODEL_SAVE_PATH):
    # The checkpoint exists only if it was present before training or some
    # epoch improved the validation loss; downloading a missing path fails.
    if os.path.exists(_path):
        files.download(_path)
    else:
        print(f"Warning: Skipped download. File {_path} does not exist.")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
Using torchvision v2 transforms: True
Creating Datasets...
loading annotations into memory...
Done (t=0.38s)
creating index...
index created!
Loaded 30062 images from data/your_dataset/train.json. Is test set: False
loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
Loaded 3340 images from data/your_dataset/valid.json. Is test set: False
loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
Loaded 13068 images from data/your_dataset/test.json. Is test set: True
Creating DataLoaders...
DataLoaders created.
Loading model...
Model loaded.
Model weights loaded from fasterrcnn_svhn_best_model.pth
Model ready.
Setting up optimizer and scheduler...
Optimizer and scheduler ready.
Starting training for 10 epochs...

--- Epoch 1/10 --- Training ---
Batch [100/3758], Current Avg Loss: 0.1620

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>