In [3]:
import os
import sys
import torch
import random
import numpy as np
import csv
from PIL import Image
from torch import nn
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# 設定隨機種子以確保結果可重現
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

# 定義圖片的目標大小，所有圖片將縮放到此大小
IMAGE_SIZE = (128, 32)  # 根據驗證碼圖片調整大小

# 定義圖片所在的資料夾
IMAGE_DIR = "./archive_3"    # 原始圖片資料夾
NEW_IMAGE_DIR = "./img-new"  # 新增圖片資料夾

# 定義資料保存的資料夾
DATA_SAVE_DIR = "./data"    # 所有 .pt 文件和模型將保存於此

# 定義字母數字列表
ALPHA_NUMS = "abcdefghijklmnopqrstuvwxyz0123456789"
NUM_CLASSES = len(ALPHA_NUMS) + 1  # 加1是為了CTC的blank字符

# 定義驗證碼的位數範圍
MIN_DIGITS = 4
MAX_DIGITS = 6

# 檢查設備是否有GPU可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")

# 建立字符與索引的映射
char_to_idx = {c: i + 1 for i, c in enumerate(ALPHA_NUMS)}  # 從1開始，0留給blank
idx_to_char = {i + 1: c for i, c in enumerate(ALPHA_NUMS)}
idx_to_char[0] = ''  # blank字符

def image_to_tensor(img):
    """
    將圖片轉換為Tensor
    :param img: PIL Image物件
    :return: 經過處理的Tensor
    """
    in_img = img.resize(IMAGE_SIZE)
    in_img = in_img.convert("L")  # 轉換為灰階圖
    arr = np.array(in_img)
    t = torch.from_numpy(arr).float()
    t = t.unsqueeze(0)  # 增加通道維度
    t = t / 255.0  # 進行歸一化
    return t

def code_to_indices(code):
    """
    將驗證碼字串轉換為索引列表
    :param code: 驗證碼字串
    :return: 索引列表
    """
    return [char_to_idx[c] for c in code.lower() if c in char_to_idx]

def prepare_data(image_dir, save_prefix):
    """
    準備訓練資料
    :param image_dir: 圖片所在的資料夾
    :param save_prefix: 保存資料的前綴（例如 'new'）
    :return: (train_path, valid_path, test_path) 或 None
    """
    # 創建保存資料的資料夾
    if not os.path.isdir(DATA_SAVE_DIR):
        os.makedirs(DATA_SAVE_DIR)
        print(f"建立資料保存資料夾: {DATA_SAVE_DIR}")

    print(f"掃描圖片資料夾: {image_dir}")
    image_paths = []
    labels = []
    for root, dirs, files in os.walk(image_dir):
        for filename in files:
            if not (filename.lower().endswith(".jpg") or filename.lower().endswith(".png")):
                continue
            path = os.path.join(root, filename)
            code = filename.split(".")[0]
            if MIN_DIGITS <= len(code) <= MAX_DIGITS:
                if all(c in ALPHA_NUMS for c in code.lower()):
                    image_paths.append(path)
                    labels.append(code)
                else:
                    print(f"跳過包含無效字符的檔案: {filename}")

    print(f"找到 {len(image_paths)} 張符合條件的圖片。")

    # 檢查是否有資料
    if len(image_paths) == 0:
        print("未找到符合條件的圖片。請檢查圖片資料夾和命名格式。")
        return None

    # 打亂資料
    data = list(zip(image_paths, labels))
    random.shuffle(data)

    # 划分資料集
    total_samples = len(data)
    train_size = int(0.8 * total_samples)
    valid_size = int(0.1 * total_samples)
    test_size = total_samples - train_size - valid_size

    train_data = data[:train_size]
    valid_data = data[train_size:train_size + valid_size]
    test_data = data[train_size + valid_size:]

    print(f"資料集划分: 訓練集 {len(train_data)}，驗證集 {len(valid_data)}，測試集 {len(test_data)}")

    # 保存資料
    train_path = os.path.join(DATA_SAVE_DIR, f"train_data_{save_prefix}.pt")
    valid_path = os.path.join(DATA_SAVE_DIR, f"valid_data_{save_prefix}.pt")
    test_path = os.path.join(DATA_SAVE_DIR, f"test_data_{save_prefix}.pt")

    torch.save(train_data, train_path)
    torch.save(valid_data, valid_path)
    torch.save(test_data, test_path)

    print(f"資料已保存為: {train_path}, {valid_path}, {test_path}")
    return train_path, valid_path, test_path

class CaptchaDataset(torch.utils.data.Dataset):
    """
    驗證碼資料集
    """
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        path, code = self.data[idx]
        try:
            with Image.open(path) as img:
                image = image_to_tensor(img)
        except Exception as e:
            print(f"無法打開圖片 {path}: {e}")
            # 跳過這個樣本，重新隨機選擇一個樣本
            return self.__getitem__(random.randint(0, len(self.data) - 1))
        label = code_to_indices(code)
        return image, torch.tensor(label, dtype=torch.long)

def collate_fn(batch):
    """
    用於DataLoader的collate_fn，處理不同長度的序列
    :param batch: 一批資料
    :return: 圖片Tensor，標籤Tensor，標籤長度Tensor
    """
    images, labels = zip(*batch)
    images = torch.stack(images)

    label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
    labels = torch.cat(labels)

    return images, labels, label_lengths

class CRNN(nn.Module):
    """
    CNN + LSTM 模型
    """
    def __init__(self, num_classes):
        super(CRNN, self).__init__()
        self.cnn = nn.Sequential(
            # 第一個卷積區塊
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),

            # 第二個卷積區塊
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),

            # 第三個卷積區塊
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # 第四個卷積區塊
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),

            # 第五個卷積區塊
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            # 第六個卷積區塊
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),

            # 第七個卷積區塊
            nn.Conv2d(512, 512, kernel_size=(2, 2), stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 輸出高度現在為1
        )

        # 定義LSTM和全連接層
        self.lstm = nn.LSTM(
            input_size=512,
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        conv = self.cnn(x)  # [batch, channels, height, width]
        b, c, h, w = conv.size()
        assert h == 1, "卷積輸出的高度必須為1"
        conv = conv.squeeze(2)  # [batch, channels, width]
        conv = conv.permute(0, 2, 1)  # [batch, width, channels]

        recurrent, _ = self.lstm(conv)
        output = self.fc(recurrent)
        output = output.permute(1, 0, 2)  # [width, batch, num_classes]
        return output  # CTC Loss需要輸入形狀為 (seq_len, batch_size, num_classes)

def decode_predictions(preds):
    """
    將模型預測的輸出轉換為字串
    :param preds: 模型預測結果
    :return: 預測的字串列表
    """
    preds = preds.permute(1, 0, 2)  # [batch, seq_len, num_classes]
    preds = torch.argmax(preds, dim=2)  # [batch, seq_len]
    preds = preds.cpu().numpy()

    decoded_strings = []
    for pred in preds:
        chars = []
        prev_char_idx = None
        for idx in pred:
            if idx != prev_char_idx and idx != 0:
                chars.append(idx_to_char.get(idx, ''))
            prev_char_idx = idx
        decoded_strings.append(''.join(chars))
    return decoded_strings

def calculate_accuracy(model, data_loader):
    """
    計算模型在資料集上的準確率
    :param model: 訓練好的模型
    :param data_loader: 資料加載器
    :return: 準確率
    """
    model.eval()
    total_count = 0
    correct_count = 0

    with torch.no_grad():
        for images, labels, label_lengths in data_loader:
            images = images.to(device)
            outputs = model(images)  # [seq_len, batch, num_classes]
            outputs = outputs.log_softmax(2)
            preds = outputs.detach().cpu()
            pred_strings = decode_predictions(preds)

            labels = labels.cpu().numpy()
            label_lengths = label_lengths.cpu().numpy()

            batch_size = images.size(0)
            total_count += batch_size

            label_strings = []
            index = 0
            for length in label_lengths:
                label = labels[index:index + length]
                label_str = ''.join([idx_to_char.get(idx, '') for idx in label])
                label_strings.append(label_str)
                index += length

            for pred_str, label_str in zip(pred_strings, label_strings):
                if pred_str == label_str:
                    correct_count += 1

    accuracy = correct_count / total_count
    return accuracy

def evaluate_model(save_model_path, test_data_path):
    """
    評估模型並生成混淆矩陣
    :param save_model_path: 模型檔案路徑
    :param test_data_path: 測試資料集路徑
    """
    # 載入測試資料
    if not os.path.exists(test_data_path):
        print(f"測試資料檔案 '{test_data_path}' 不存在。請先準備資料。")
        return

    test_data = torch.load(test_data_path)
    test_dataset = CaptchaDataset(test_data)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn
    )

    # 載入模型
    model = CRNN(num_classes=NUM_CLASSES).to(device)
    try:
        model.load_state_dict(torch.load(save_model_path, map_location=device))
        print(f"已載入模型: {save_model_path}")
    except FileNotFoundError:
        print(f"模型檔案 '{save_model_path}' 未找到。請先訓練模型。")
        return
    model.eval()

    total_sequences = 0  # 總序列數
    correct_sequences = 0  # 完全正確的序列數

    total_chars = 0  # 總字符數
    correct_chars = 0  # 正確字符數

    all_true_chars = []
    all_pred_chars = []

    with torch.no_grad():
        for images, labels, label_lengths in test_loader:
            images = images.to(device)
            outputs = model(images)  # [seq_len, batch, num_classes]
            outputs = outputs.log_softmax(2)
            preds = outputs.detach().cpu()
            pred_strings = decode_predictions(preds)

            labels = labels.cpu().numpy()
            label_lengths = label_lengths.cpu().numpy()

            batch_size = images.size(0)
            total_sequences += batch_size  # 增加序列計數

            label_strings = []
            index = 0
            for length in label_lengths:
                label = labels[index:index + length]
                label_str = ''.join([idx_to_char.get(idx, '') for idx in label])
                label_strings.append(label_str)
                total_chars += length  # 增加字符計數
                index += length

            for pred_str, label_str in zip(pred_strings, label_strings):
                if pred_str == label_str:
                    correct_sequences += 1  # 完全正確的序列計數

                # 逐字符比較，計算字符準確率
                min_len = min(len(pred_str), len(label_str))
                for i in range(min_len):
                    if pred_str[i] == label_str[i]:
                        correct_chars += 1  # 增加正確字符計數
                    all_true_chars.append(label_str[i])
                    all_pred_chars.append(pred_str[i])

    # 計算準確率
    char_accuracy = correct_chars / total_chars if total_chars > 0 else 0
    seq_accuracy = correct_sequences / total_sequences if total_sequences > 0 else 0

    print(f"字符級準確率: {char_accuracy * 100:.2f}% (正確: {correct_chars}/{total_chars})")
    print(f"序列級準確率: {seq_accuracy * 100:.2f}% (正確: {correct_sequences}/{total_sequences})")

    # 生成混淆矩陣並轉換為百分比
    cm = confusion_matrix(all_true_chars, all_pred_chars, labels=list(ALPHA_NUMS))
    cm_percentage = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100  # 將每行轉換為百分比

    # 繪製混淆矩陣
    plt.figure(figsize=(20, 14))  # 增大圖形大小
    sns.heatmap(
        cm_percentage,
        annot=True,
        fmt='.2f',
        cmap='Blues',
        xticklabels=list(ALPHA_NUMS),
        yticklabels=list(ALPHA_NUMS),
        annot_kws={"size": 6},  # 調整註釋字體大小
        cbar_kws={"shrink": 0.8}  # 縮小顏色條
    )
    plt.xlabel('Predicted Characters ', fontsize=16)  # 增大字體
    plt.ylabel('True Characters', fontsize=16)
    plt.title('confusion matrix(pecentage)', fontsize=18)
    plt.xticks(fontsize=14, rotation=0)  # 旋轉x軸標籤並調整字體大小
    plt.yticks(fontsize=14, rotation=0)   # 旋轉y軸標籤並調整字體大小
    plt.tight_layout()
    cm_filename = os.path.join(DATA_SAVE_DIR, f"confusion_matrix_percentage_{os.path.splitext(os.path.basename(save_model_path))[0]}.png")
    plt.show()
    print(f"混淆矩陣已保存為 {cm_filename}")

def train_model(train_loader, valid_loader, save_model_path, num_epochs=50, pretrained_model_path=None):
    """
    通用訓練函數
    :param train_loader: 訓練資料加載器
    :param valid_loader: 驗證資料加載器
    :param save_model_path: 模型保存路徑
    :param num_epochs: 訓練輪數
    :param pretrained_model_path: 預訓練模型路徑（可選）
    """
    # 建立模型
    model = CRNN(num_classes=NUM_CLASSES).to(device)

    # 如果提供了預訓練模型路徑，且檔案存在，則載入模型權重
    if pretrained_model_path and os.path.exists(pretrained_model_path):
        model.load_state_dict(torch.load(pretrained_model_path, map_location=device))
        print(f"載入預訓練模型: {pretrained_model_path}")
    else:
        print("從頭開始訓練模型。")

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, min_lr=1e-9)
    criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

    best_valid_accuracy = 0.0  # 初始最佳驗證準確率
    patience = 6  # Early Stopping的耐心值，即容忍多少個epoch沒有提升
    patience_counter = 0  # 記錄驗證準確率沒有改善的epoch次數

    # 記錄損失和準確率
    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []

    # 打開CSV檔案以寫入模式
    csv_filename = os.path.join(DATA_SAVE_DIR, f'train_val_results_{os.path.splitext(os.path.basename(save_model_path))[0]}.csv')
    with open(csv_filename, 'w', newline='', encoding='utf-8') as csvfile:
        csvwriter = csv.writer(csvfile)
        # 寫入表頭
        csvwriter.writerow(['Epoch', 'train loss', 'val loss', 'train accu', 'val accu'])

        for epoch in range(1, num_epochs + 1):
            model.train()
            total_loss = 0.0
            for images, labels, label_lengths in train_loader:
                images = images.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                outputs = model(images)  # [seq_len, batch, num_classes]

                input_lengths = torch.full(
                    size=(images.size(0),),
                    fill_value=outputs.size(0),
                    dtype=torch.long
                ).to(device)
                loss = criterion(outputs, labels, input_lengths, label_lengths)

                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            avg_train_loss = total_loss / len(train_loader)

            # 計算訓練集準確率
            train_accuracy = calculate_accuracy(model, train_loader)
            # 計算驗證集損失和準確率
            model.eval()
            total_loss = 0.0
            with torch.no_grad():
                for images, labels, label_lengths in valid_loader:
                    images = images.to(device)
                    labels = labels.to(device)

                    outputs = model(images)
                    input_lengths = torch.full(
                        size=(images.size(0),),
                        fill_value=outputs.size(0),
                        dtype=torch.long
                    ).to(device)
                    loss = criterion(outputs, labels, input_lengths, label_lengths)

                    total_loss += loss.item()

            avg_valid_loss = total_loss / len(valid_loader)
            valid_accuracy = calculate_accuracy(model, valid_loader)

            # 記錄每個epoch的損失和準確率
            train_losses.append(avg_train_loss)
            valid_losses.append(avg_valid_loss)
            train_accuracies.append(train_accuracy)
            valid_accuracies.append(valid_accuracy)

            # 將結果寫入CSV檔案
            csvwriter.writerow([epoch, avg_train_loss, avg_valid_loss, train_accuracy, valid_accuracy])

            print(f"Epoch [{epoch}/{num_epochs}]")
            print(f"train loss: {avg_train_loss:.4f}, train accu: {train_accuracy * 100:.2f}%")
            print(f"val loss: {avg_valid_loss:.4f}, val accu: {valid_accuracy * 100:.2f}%")
            scheduler.step(avg_valid_loss)

            # Early Stopping檢查（基於驗證集準確率）
            if valid_accuracy > best_valid_accuracy:
                best_valid_accuracy = valid_accuracy
                patience_counter = 0  # 重置耐心計數器
                # 確保保存目錄存在
                os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
                torch.save(model.state_dict(), save_model_path)
                print("驗證準確率提升，模型已保存。")
            else:
                patience_counter += 1
                print(f"驗證準確率未提升，耐心計數器: {patience_counter}/{patience}")

            if patience_counter >= patience:
                print("因為驗證準確率長期未提升，提前停止訓練。")
                break  # 提前停止訓練

    # 繪製訓練和驗證的損失曲線
    plt.figure(figsize=(12, 5))

    # 準確率曲線
    plt.subplot(1, 2, 2)
    plt.plot(range(1, len(train_accuracies) + 1), [acc * 100 for acc in train_accuracies], label='train accu', color='green')
    plt.plot(range(1, len(valid_accuracies) + 1), [acc * 100 for acc in valid_accuracies], label='val accu', color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('tain and val accu')
    plt.legend()
    plt.grid(True)

    # 損失曲線
    plt.subplot(1, 2, 1)
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='train loss', color='blue')
    plt.plot(range(1, len(valid_losses) + 1), valid_losses, label='val loss', color='orange')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('train and val loss')
    plt.legend()
    plt.grid(True)



    plt.tight_layout()

    # 顯示圖像
    plt.show()
def predict():
    """
    使用訓練好的模型對單張圖片進行預測
    """
    # 可用的模型列表
    # 將所有crnn_model*.pt文件視為模型文件
    available_models = [f for f in os.listdir(DATA_SAVE_DIR) if f.startswith("crnn_model") and f.endswith(".pt")]
    available_models = [os.path.join(DATA_SAVE_DIR, model) for model in available_models]

    if not available_models:
        print("沒有可用的模型。請先訓練模型。")
        return

    # 顯示可用的模型
    print("可用的模型:")
    for idx, path in enumerate(available_models, 1):
        print(f"{idx}. {path}")
    try:
        model_choice = int(input(f"請選擇要使用的模型編號（1-{len(available_models)}）： ").strip()) - 1
        if model_choice < 0 or model_choice >= len(available_models):
            raise ValueError
        selected_model_path = available_models[model_choice]
    except (ValueError, IndexError):
        print("無效的選擇。")
        return

    # 載入模型
    model = CRNN(num_classes=NUM_CLASSES).to(device)
    try:
        model.load_state_dict(torch.load(selected_model_path, map_location=device))
        print(f"已載入模型: {selected_model_path}")
    except FileNotFoundError:
        print(f"模型檔案 '{selected_model_path}' 未找到。")
        return
    model.eval()

    while True:
        try:
            # 輸入圖片路徑
            image_path = input("請輸入圖片路徑（或按 Enter 退出）： ").strip()
            if not image_path:
                print("退出預測模式。")
                break
            if not os.path.isfile(image_path):
                print(f"檔案 '{image_path}' 不存在，請重新輸入。")
                continue
            with Image.open(image_path) as img:
                tensor_in = image_to_tensor(img).to(device).unsqueeze(0)  # 增加批次維度

            # 預測輸出
            tensor_out = model(tensor_in)  # [seq_len, batch, num_classes]
            preds = tensor_out.log_softmax(2)
            pred_strings = decode_predictions(preds)

            code = pred_strings[0]
            print(f"預測驗證碼: {code}\n")
        except Exception as e:
            print("發生錯誤:", e)

def main():
    """
    主函數，處理不同的操作模式
    """
    operation = input("請輸入模式 (prepare|train|evaluate|predict): ").strip().lower()

    if operation == "prepare":
        prepare_result = prepare_data(IMAGE_DIR, "original")
        if prepare_result:
            print("資料準備完成。")
    elif operation == "train":
        # 選擇是否使用預訓練模型
        use_pretrained = input("是否使用預訓練模型進行訓練？ (y/n)： ").strip().lower()
        if use_pretrained == 'y':
            # 使用預訓練模型，資料來源為 img-new
            data_source = "new"
            image_dir = NEW_IMAGE_DIR
            print("使用預訓練模型，資料來源設為 img-new。")
        else:
            # 從頭訓練，資料來源為 archive_3
            data_source = "original"
            image_dir = IMAGE_DIR
            print("從頭開始訓練模型，資料來源設為 archive_3。")

        # 準備資料
        prepare_result = prepare_data(image_dir, data_source)
        if prepare_result is None:
            print("資料準備失敗，訓練終止。")
            return

        train_path, valid_path, _ = prepare_result  # 測試集路徑在訓練中不需要

        # 創建資料集和資料加載器
        train_data = torch.load(train_path)
        valid_data = torch.load(valid_path)

        train_dataset = CaptchaDataset(train_data)
        valid_dataset = CaptchaDataset(valid_data)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn
        )
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn
        )

        # 如果使用預訓練模型，選擇模型文件
        if use_pretrained == 'y':
            # 列出所有模型文件
            model_files = [f for f in os.listdir(DATA_SAVE_DIR) if f.startswith("crnn_model") and f.endswith(".pt")]
            if not model_files:
                print("沒有找到任何預訓練模型文件。請先訓練模型。")
                return
            # 顯示可用的模型
            print("可用的預訓練模型:")
            for idx, file in enumerate(model_files, 1):
                print(f"{idx}. {file}")
            try:
                model_choice = int(input(f"請選擇要載入的預訓練模型編號（1-{len(model_files)}）： ").strip()) - 1
                if model_choice < 0 or model_choice >= len(model_files):
                    raise ValueError
                selected_model_file = model_files[model_choice]
                pretrained_model_path = os.path.join(DATA_SAVE_DIR, selected_model_file)
            except (ValueError, IndexError):
                print("無效的選擇。")
                return
        else:
            pretrained_model_path = None

        # 定義模型保存路徑
        if data_source == "original":
            model_save_filename = "crnn_model_original.pt"
        else:
            model_save_filename = "crnn_model_new.pt"
        save_model_path = os.path.join(DATA_SAVE_DIR, model_save_filename)

        # 訓練模型
        train_model(
            train_loader,
            valid_loader,
            save_model_path=save_model_path,
            num_epochs=50,
            pretrained_model_path=pretrained_model_path
        )
        print(f"模型已保存為 '{save_model_path}'")
    elif operation == "evaluate":
        # 列出所有模型文件
        model_files = [f for f in os.listdir(DATA_SAVE_DIR) if f.startswith("crnn_model") and f.endswith(".pt")]
        if not model_files:
            print("沒有找到任何模型文件。請先訓練模型。")
            return
        # 顯示可用的模型
        print("可用的模型:")
        for idx, file in enumerate(model_files, 1):
            print(f"{idx}. {file}")
        try:
            model_choice = int(input(f"請選擇要使用的模型編號（1-{len(model_files)}）： ").strip()) - 1
            if model_choice < 0 or model_choice >= len(model_files):
                raise ValueError
            selected_model_file = model_files[model_choice]
            selected_model_path = os.path.join(DATA_SAVE_DIR, selected_model_file)
        except (ValueError, IndexError):
            print("無效的選擇。")
            return

        # 列出所有 test_data*.pt 文件，包括 test_data.pt
        test_data_files = [f for f in os.listdir(DATA_SAVE_DIR) if f.startswith("test_data") and f.endswith(".pt")]
        if not test_data_files:
            print("沒有找到 test_data*.pt 檔案。請先準備資料。")
            return

        # 顯示可用的測試資料集
        print("可用的測試資料集:")
        for idx, file in enumerate(test_data_files, 1):
            print(f"{idx}. {file}")
        try:
            test_choice = int(input(f"請選擇要使用的測試資料集編號（1-{len(test_data_files)}）： ").strip()) - 1
            if test_choice < 0 or test_choice >= len(test_data_files):
                raise ValueError
            selected_test_data_file = test_data_files[test_choice]
            selected_test_data_path = os.path.join(DATA_SAVE_DIR, selected_test_data_file)
        except (ValueError, IndexError):
            print("無效的選擇。")
            return

        evaluate_model(selected_model_path, selected_test_data_path)
    elif operation == "predict":
        predict()
    else:
        print(f"不支援的操作: {operation}")
        sys.exit(1)

if __name__ == "__main__":
    main()


使用設備: cuda
從頭開始訓練模型，資料來源設為 archive_3。
掃描圖片資料夾: ./archive_3
找到 113062 張符合條件的圖片。
資料集划分: 訓練集 90449，驗證集 11306，測試集 11307
資料已保存為: ./data\train_data_original.pt, ./data\valid_data_original.pt, ./data\test_data_original.pt


  train_data = torch.load(train_path)
  valid_data = torch.load(valid_path)


從頭開始訓練模型。
Epoch [1/50]
train loss: 3.8346, train accu: 0.00%
val loss: 3.6462, val accu: 0.00%
驗證準確率未提升，耐心計數器: 1/6


KeyboardInterrupt: 