<a href="https://colab.research.google.com/github/tien2204/Finetue-DNN-and-Transformers-Model-for-plant-disease-classification/blob/main/Vit_and_E0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# file: train_evaluate.py

# ==============================================================================
# MODULE 1: IMPORTS VÀ CÀI ĐẶT BAN ĐẦU
# ==============================================================================
import time
import os
import random
import traceback
import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
# !!! Thay đổi import cho AMP !!!
import torch.amp # Import cấp cao hơn
# --------------------------------
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision import datasets
from tqdm import tqdm

# --- Basic setup: seed every RNG in use and detect the compute device ---
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

use_gpu = torch.cuda.is_available()
if use_gpu:
    torch.cuda.manual_seed_all(seed)
    print(f"CUDA available. Using {torch.cuda.device_count()} GPU(s).")
    # Report the name of every visible GPU.
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    # Device type string passed to torch.amp (autocast / GradScaler).
    amp_device_type = 'cuda'
else:
    print("CUDA not available. Using CPU.")
    amp_device_type = 'cpu'


# --- Model and training configuration ---
models_to_test = ['efficientnet_b0', 'vit_b_16']
# Spatial input size used when probing each model.
input_sizes = dict(
    efficientnet_b0=(224, 224),
    vit_b_16=(224, 224),
)
batch_size = 32
# Speed hint: number of DataLoader worker processes (try 4 or 8).
num_workers_loader = 4
# Gradient-accumulation steps per model.
accumulation_settings = dict(
    efficientnet_b0=1,
    vit_b_16=4,
)
epochs_to_train = 10
# Make sure this path points at the pre-processed dataset.
PROCESSED_DATA_DIR = '/content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed'
# Filled in by the main script once the dataset classes have been loaded.
num_classes = None

# ==============================================================================
# MODULE 2: ĐỊNH NGHĨA MODEL (ExtendedModel) - Giữ nguyên
# ==============================================================================
class ExtendedModel(nn.Module):
    """Wrap a backbone and append a linear classification layer on its output.

    The width of the backbone output is discovered with a single dummy
    forward pass, so the backbone must return a tensor whose second
    dimension is the feature dimension.
    """

    def __init__(self, base_model, num_classes, input_size=(224,224)):
        """Probe ``base_model`` and create the final ``nn.Linear`` layer.

        Args:
            base_model: Backbone module; it is moved to the probe device
                (CUDA when ``use_gpu`` is true, otherwise CPU) and stays there.
            num_classes: Output size of the appended linear layer.
            input_size: ``(height, width)`` of the dummy probe image.
        """
        super().__init__()
        self.base_model = base_model
        probe_device = torch.device("cuda" if use_gpu else "cpu")
        # Move the backbone to the probe device so the dummy pass can run there.
        self.base_model.to(probe_device)
        dummy = torch.zeros(1, 3, input_size[0], input_size[1], device=probe_device)

        # Probe in eval mode: torch.no_grad() does NOT stop BatchNorm layers
        # from updating their running statistics, so a training-mode pass on
        # an all-zeros image would pollute the pre-trained statistics (and
        # batch-of-one BatchNorm1d layers would raise). The previous mode is
        # restored afterwards so callers see no change.
        was_training = self.base_model.training
        self.base_model.eval()
        try:
            with torch.no_grad():
                dummy_out = self.base_model(dummy)
        finally:
            self.base_model.train(was_training)

        in_features = dummy_out.shape[1]
        print(f"ModelSetup: Determined in_features for final layer: {in_features}")
        self.extension = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Run the backbone, then map its output to ``num_classes`` logits."""
        base_out = self.base_model(x)
        out = self.extension(base_out)
        return out

# ==============================================================================
# MODULE 3: HÀM HỖ TRỢ (Load model pre-trained) - Sửa lỗi Warning
# ==============================================================================
def load_pretrained_core_model(name):
    """Load a pre-trained torchvision model, preferring the weights-enum API.

    Args:
        name: Name of a model builder in ``torchvision.models``
            (e.g. ``'efficientnet_b0'``, ``'vit_b_16'``).

    Returns:
        The model instantiated with pre-trained weights.

    Raises:
        ValueError: If every loading path finished without producing a model.
        Exception: Any error raised while building the model is logged and
            re-raised unchanged.
    """
    print(f"ModelSetup: Loading pre-trained core model: {name}")
    model = None
    try:
        # torchvision names each weights enum after its builder but in mixed
        # case (efficientnet_b0 -> EfficientNet_B0_Weights, vit_b_16 ->
        # ViT_B_16_Weights). The exact casing cannot be derived from the
        # builder name, so match case-insensitively instead of guessing it.
        wanted = f"{name}_weights".lower()
        weights_enum = next(
            (getattr(models, attr) for attr in dir(models) if attr.lower() == wanted),
            None,
        )

        if weights_enum is not None:
            # DEFAULT is the recommended weights entry of the enum.
            weights = weights_enum.DEFAULT
            print(f"Using weights API: {weights}")
            model = getattr(models, name)(weights=weights)
        else:
            # No enum for this name: use the legacy flag (may emit a warning).
            print(f"No weights enum found for '{name}'. Trying legacy pretrained=True (may show warning).")
            model = getattr(models, name)(pretrained=True)

    except AttributeError as e:
        # Reached when getattr(models, name) (or an attribute access above) fails.
        print(f"[ERROR] Could not load model '{name}' using standard methods: {e}")
        print("Trying models.__dict__ fallback...")
        try:
            # Very old fallback, kept for backward compatibility.
            model = models.__dict__[name](pretrained=True)
        except Exception as e2:
            print(f"[ERROR] Failed loading model '{name}' completely: {e2}")
            raise  # propagate: the model cannot be loaded
    except Exception as e:
        print(f"[ERROR] An unexpected error occurred loading model '{name}': {e}")
        traceback.print_exc()
        raise

    if model is None:
        raise ValueError(f"Could not load model for name: {name}")
    return model


# ==============================================================================
# MODULE 4: HÀM TẢI DỮ LIỆU (load_preprocessed_data) - Giữ nguyên logic
# ==============================================================================
def load_preprocessed_data(processed_data_dir, batch_size, num_workers=4, pin_memory=False):
    """Create train/val/test DataLoaders over already pre-processed image folders.

    Args:
        processed_data_dir: Directory containing ``train``, ``val`` and
            ``test`` sub-directories readable by ``ImageFolder``.
        batch_size: Batch size used for all three loaders.
        num_workers: Number of worker processes per loader.
        pin_memory: Passed through to every ``DataLoader``.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)`` where
        ``class_names`` comes from the training split.

    Raises:
        FileNotFoundError: If a split directory is missing.
        Exception: Errors from ``ImageFolder`` are logged and re-raised.
    """
    print(f"\nLoading preprocessed data from: {processed_data_dir}")
    split_names = ('train', 'val', 'test')

    # Only light online transforms are applied here; only the training split
    # gets a random horizontal flip.
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    train_pipeline = transforms.Compose(
        [transforms.RandomHorizontalFlip(), *to_normalized_tensor]
    )
    eval_pipeline = transforms.Compose(list(to_normalized_tensor))

    image_datasets = {}
    for split in split_names:
        split_path = os.path.join(processed_data_dir, split)
        if not os.path.isdir(split_path):
            raise FileNotFoundError(f"Dir not found: {split_path}")
        pipeline = train_pipeline if split == 'train' else eval_pipeline
        try:
            dataset = datasets.ImageFolder(split_path, pipeline)
            image_datasets[split] = dataset
            print(f"Loaded {len(dataset)} images from {split_path}")
            if len(dataset) == 0:
                print(f"[Warning] No images loaded for split '{split}'.")
        except Exception as e:
            print(f"[ERROR] Load failed for '{split}': {e}")
            raise

    dset_classes = image_datasets['train'].classes
    print(f"Classes loaded ({len(dset_classes)}): {dset_classes}")

    # Only the training loader shuffles.
    loaders = []
    for split in split_names:
        loaders.append(
            DataLoader(
                image_datasets[split],
                batch_size=batch_size,
                shuffle=(split == 'train'),
                num_workers=num_workers,
                pin_memory=pin_memory,
            )
        )
    print("Created DataLoaders.")
    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader, dset_classes

# ==============================================================================
# MODULE 5: HÀM ĐÁNH GIÁ MODEL (evaluate_model) - Sửa lỗi Warning AMP
# ==============================================================================
def evaluate_model(net, dataloader, criterion):
    """Evaluate ``net`` on ``dataloader`` and return ``(avg_loss, accuracy)``.

    The network is switched to eval mode and left there. ``avg_loss`` is the
    per-sample mean of ``criterion`` (NaN when ``criterion`` is falsy or no
    samples were seen); ``accuracy`` is the fraction of correct top-1
    predictions (0.0 when no samples were seen).
    """
    net.eval()
    device = next(net.parameters()).device
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    batches = tqdm(dataloader, desc="Evaluating", leave=False, unit="batch")
    with torch.no_grad():
        for inputs, labels in batches:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            # Mixed precision is only enabled when running on GPU.
            with torch.amp.autocast(device_type=amp_device_type, enabled=use_gpu):
                outputs = net(inputs)
                if criterion:
                    # Weight the batch loss by batch size to get a per-sample mean later.
                    loss_sum += criterion(outputs, labels).item() * inputs.size(0)
            predicted = outputs.max(dim=1).indices
            n_correct += (predicted == labels).sum().item()
            n_seen += labels.size(0)

    if criterion and n_seen > 0:
        avg_loss = loss_sum / n_seen
    else:
        avg_loss = float('nan')
    if n_seen > 0:
        accuracy = n_correct / n_seen
    else:
        accuracy = 0.0
    return avg_loss, accuracy

# ==============================================================================
# MODULE 6: HÀM HUẤN LUYỆN VÀ VALIDATE (train_and_validate) - Sửa lỗi Warning AMP
# ==============================================================================
def train_and_validate(net, trainloader, valloader, criterion, epochs, accumulation_steps):
    """Train ``net`` with Adam and evaluate on the validation set every epoch.

    Uses mixed precision (autocast + GradScaler) when ``use_gpu`` is true and
    gradient accumulation over ``accumulation_steps`` micro-batches.

    Args:
        net: Model to train (modified in place).
        trainloader: DataLoader yielding ``(inputs, labels)`` training batches.
        valloader: DataLoader passed to ``evaluate_model`` after each epoch.
        criterion: Loss callable ``criterion(outputs, labels)``.
        epochs: Number of passes over ``trainloader``.
        accumulation_steps: Micro-batches accumulated per optimizer step
            (must be >= 1).

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        # 0 would divide by zero below; a negative value would flip the loss sign.
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    optimizer = optim.Adam(net.parameters(), lr=0.001)
    scaler = torch.amp.GradScaler(device=amp_device_type, enabled=use_gpu)
    print(f"\n--- Starting Training ({epochs} epochs) ---")
    print(f"Effective Batch Size: {trainloader.batch_size * accumulation_steps}")
    print(f"Using Mixed Precision (AMP): {use_gpu}")
    device = next(net.parameters()).device
    num_batches = len(trainloader)

    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        total_train = 0
        correct_train = 0
        optimizer.zero_grad()

        progress_bar = tqdm(enumerate(trainloader), total=num_batches, desc=f"Epoch {epoch+1}/{epochs} Train", unit="batch")
        for i, (inputs, labels) in progress_bar:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # Number of micro-batches in the accumulation window this batch
            # belongs to. The last window of an epoch may be shorter than
            # accumulation_steps; dividing by its real size keeps the
            # accumulated gradient an average rather than an under-scaled sum.
            window_start = (i // accumulation_steps) * accumulation_steps
            window_size = min(accumulation_steps, num_batches - window_start)

            with torch.amp.autocast(device_type=amp_device_type, enabled=use_gpu):
                outputs = net(inputs)
                raw_loss = criterion(outputs, labels)
                loss = raw_loss / window_size

            scaler.scale(loss).backward()

            # Step at the end of each window and at the end of the epoch.
            if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Log the unscaled loss of this micro-batch.
            batch_loss = raw_loss.item()
            running_loss += batch_loss * inputs.size(0)
            total_train += labels.size(0)

            with torch.no_grad():
                _, predicted = torch.max(outputs, 1)
                correct_train_batch = (predicted == labels).sum().item()
                batch_acc = correct_train_batch / inputs.size(0) if inputs.size(0) > 0 else 0
                correct_train += correct_train_batch
                progress_bar.set_postfix(loss=f"{batch_loss:.4f}", acc=f"{batch_acc:.3f}")

        train_loss = running_loss / total_train if total_train > 0 else float('nan')
        train_acc = correct_train / total_train if total_train > 0 else 0.0

        val_loss, val_acc = evaluate_model(net, valloader, criterion)

        print(f"\nEpoch {epoch+1}/{epochs} Summary -> "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    print("--- Training Finished ---")

# ==============================================================================
# MODULE 7: SCRIPT CHÍNH ĐỂ THỰC THI
# ==============================================================================
# Entry point: for every configured model, load the data, build the model,
# train/validate it, run the final test, and print a summary at the end.
if __name__ == "__main__":
    print("="*60)
    print(" PLANT DISEASE CLASSIFICATION TRAINING SCRIPT (PREPROCESSED DATA) ".center(60, "="))
    print("="*60)

    if not os.path.isdir(PROCESSED_DATA_DIR):
        print(f"[FATAL ERROR] Preprocessed data directory not found: {PROCESSED_DATA_DIR}")
        print("Please run the preprocessing script ('preprocess_data*.py') first.")
        # SystemExit works without the site-provided exit() helper and
        # reports a non-zero status for this fatal condition.
        raise SystemExit(1)

    results = {}

    for model_name in models_to_test:
        print("\n" + "="*50)
        print(f"Processing Model: {model_name}")
        print("="*50 + "\n")

        acc_steps = accumulation_settings.get(model_name, 1)
        model_input_resize = input_sizes[model_name]
        print(f"Config -> Accumulation: {acc_steps}, Model Input Size: {model_input_resize}")

        # Pre-bind every per-model resource so the cleanup in `finally` runs
        # whichever stage fails; otherwise a failed model would keep its GPU
        # memory while the next model is being processed.
        model = base_model = criterion = None
        train_loader = val_loader = test_loader = None
        try:
            # --- Load the pre-processed data ---
            print("\n--- Loading Preprocessed DataLoaders ---")
            try:
                train_loader, val_loader, test_loader, dset_classes = load_preprocessed_data(
                    processed_data_dir=PROCESSED_DATA_DIR,
                    batch_size=batch_size,
                    num_workers=num_workers_loader,
                    pin_memory=use_gpu
                )
                if num_classes is None:
                    num_classes = len(dset_classes)
                elif num_classes != len(dset_classes):
                    print(f"[Warning] Class count mismatch. Resetting num_classes to {len(dset_classes)}")
                    num_classes = len(dset_classes)
                print(f"Number of classes for model: {num_classes}")
            except Exception as e:
                print(f"[ERROR] Load data failed: {e}")
                traceback.print_exc()
                continue

            # --- Build the model ---
            try:
                print(f"\nSetting up {model_name}...")
                base_model = load_pretrained_core_model(model_name)
                model = ExtendedModel(base_model, num_classes, input_size=model_input_resize)

                # Speed hint: with PyTorch >= 2.0 the model could be wrapped
                # with torch.compile(model) at this point.

                if use_gpu:
                    if torch.cuda.device_count() > 1:
                        model = nn.DataParallel(model)
                    model = model.cuda()
            except Exception as e:
                print(f"[ERROR] Model setup failed: {e}")
                traceback.print_exc()
                continue

            # --- Train and validate ---
            criterion = nn.CrossEntropyLoss()
            if use_gpu:
                criterion = criterion.cuda()

            try:
                train_and_validate(
                    net=model, trainloader=train_loader, valloader=val_loader,
                    criterion=criterion, epochs=epochs_to_train, accumulation_steps=acc_steps
                )
            except Exception as e:
                print(f"[ERROR] Train/Val failed: {e}")
                traceback.print_exc()
                print("Skipping testing.")
                continue

            # --- Final test ---
            print("\n--- Final Testing Phase ---")
            try:
                test_loss, test_acc = evaluate_model(model, test_loader, criterion)
                print(f"\n>>> Test Results for {model_name}: Loss = {test_loss:.4f}, Accuracy = {test_acc:.4f}")
                results[model_name] = {'test_loss': test_loss, 'test_accuracy': test_acc}
            except Exception as e:
                print(f"[ERROR] Test failed: {e}")
                traceback.print_exc()
                results[model_name] = {'test_loss': float('nan'), 'test_accuracy': float('nan')}

            print("-"*50)
        finally:
            # --- Free memory (also reached through the `continue` paths above) ---
            print(f"Cleaning up resources for {model_name}...")
            del model, base_model, train_loader, val_loader, test_loader, criterion
            if use_gpu:
                torch.cuda.empty_cache()

    # --- Print the summary ---
    print("\n========================================")
    print(" TRAINING & TESTING COMPLETE ".center(40, "="))
    print("========================================")
    print("Final Test Results Summary:")
    if results:
        for model_name, metrics in results.items():
            print(f"- {model_name}: Test Loss={metrics['test_loss']:.4f}, Test Accuracy={metrics['test_accuracy']:.4f}")
    else:
        print("No models were successfully processed.")
    print("="*40)
    print("SCRIPT FINISHED")

CUDA available. Using 1 GPU(s).
  GPU 0: Tesla T4
 PLANT DISEASE CLASSIFICATION TRAINING SCRIPT (PREPROCESSED DATA) 

Processing Model: efficientnet_b0

Config -> Accumulation: 1, Model Input Size: (224, 224)

--- Loading Preprocessed DataLoaders ---

Loading preprocessed data from: /content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed
Loaded 44018 images from /content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed/train
Loaded 5502 images from /content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed/val
Loaded 5502 images from /content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed/test
Classes loaded (39): ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black

Epoch 1/10 Train: 100%|██████████| 1376/1376 [42:17<00:00,  1.84s/batch, acc=0.944, loss=0.1419]
                                                                


Epoch 1/10 Summary -> Train Loss: 0.2657, Train Acc: 0.9229 | Val Loss: 0.1754, Val Acc: 0.9573


Epoch 2/10 Train: 100%|██████████| 1376/1376 [03:37<00:00,  6.33batch/s, acc=0.889, loss=0.3835]
                                                                


Epoch 2/10 Summary -> Train Loss: 0.1456, Train Acc: 0.9574 | Val Loss: 0.0846, Val Acc: 0.9775


Epoch 3/10 Train: 100%|██████████| 1376/1376 [03:40<00:00,  6.25batch/s, acc=1.000, loss=0.0062]
                                                                


Epoch 3/10 Summary -> Train Loss: 0.1054, Train Acc: 0.9702 | Val Loss: 0.0660, Val Acc: 0.9807


Epoch 4/10 Train: 100%|██████████| 1376/1376 [03:40<00:00,  6.24batch/s, acc=1.000, loss=0.0655]
                                                                


Epoch 4/10 Summary -> Train Loss: 0.1093, Train Acc: 0.9693 | Val Loss: 0.7189, Val Acc: 0.9589


Epoch 5/10 Train: 100%|██████████| 1376/1376 [03:39<00:00,  6.28batch/s, acc=0.944, loss=0.1807]
                                                                


Epoch 5/10 Summary -> Train Loss: 0.0835, Train Acc: 0.9765 | Val Loss: 0.0928, Val Acc: 0.9784


Epoch 6/10 Train: 100%|██████████| 1376/1376 [03:56<00:00,  5.83batch/s, acc=0.944, loss=0.2174]
                                                                


Epoch 6/10 Summary -> Train Loss: 0.0964, Train Acc: 0.9745 | Val Loss: 0.1033, Val Acc: 0.9742


Epoch 7/10 Train: 100%|██████████| 1376/1376 [03:44<00:00,  6.13batch/s, acc=1.000, loss=0.0216]
                                                                


Epoch 7/10 Summary -> Train Loss: 0.0735, Train Acc: 0.9796 | Val Loss: 0.2685, Val Acc: 0.9824


Epoch 8/10 Train: 100%|██████████| 1376/1376 [03:40<00:00,  6.24batch/s, acc=0.944, loss=0.0523]
                                                                


Epoch 8/10 Summary -> Train Loss: 0.0681, Train Acc: 0.9816 | Val Loss: 0.1079, Val Acc: 0.9727


Epoch 9/10 Train: 100%|██████████| 1376/1376 [03:41<00:00,  6.22batch/s, acc=1.000, loss=0.0020]
                                                                


Epoch 9/10 Summary -> Train Loss: 0.0709, Train Acc: 0.9809 | Val Loss: 0.0643, Val Acc: 0.9816


Epoch 10/10 Train: 100%|██████████| 1376/1376 [03:40<00:00,  6.23batch/s, acc=1.000, loss=0.0005]



Epoch 10/10 Summary -> Train Loss: 0.0534, Train Acc: 0.9858 | Val Loss: 0.0533, Val Acc: 0.9884
--- Training Finished ---

--- Final Testing Phase ---





>>> Test Results for efficientnet_b0: Loss = 0.0383, Accuracy = 0.9893
--------------------------------------------------
Cleaning up resources for efficientnet_b0...

Processing Model: vit_b_16

Config -> Accumulation: 4, Model Input Size: (224, 224)

--- Loading Preprocessed DataLoaders ---

Loading preprocessed data from: /content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed
Loaded 44018 images from /content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed/train
Loaded 5502 images from /content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed/val
Loaded 5502 images from /content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed/test
Classes loaded (39): ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Bligh

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:02<00:00, 133MB/s]


ModelSetup: Determined in_features for final layer: 1000

--- Starting Training (10 epochs) ---
Effective Batch Size: 128
Using Mixed Precision (AMP): True


Epoch 1/10 Train: 100%|██████████| 1376/1376 [06:48<00:00,  3.37batch/s, acc=0.833, loss=0.4682]
                                                                


Epoch 1/10 Summary -> Train Loss: 2.0486, Train Acc: 0.4363 | Val Loss: 0.9081, Val Acc: 0.7203


Epoch 2/10 Train: 100%|██████████| 1376/1376 [06:48<00:00,  3.37batch/s, acc=0.778, loss=0.4479]
                                                                


Epoch 2/10 Summary -> Train Loss: 0.6326, Train Acc: 0.8029 | Val Loss: 0.4580, Val Acc: 0.8490


Epoch 3/10 Train: 100%|██████████| 1376/1376 [06:45<00:00,  3.39batch/s, acc=1.000, loss=0.0982]
                                                                


Epoch 3/10 Summary -> Train Loss: 0.3767, Train Acc: 0.8786 | Val Loss: 0.3210, Val Acc: 0.8957


Epoch 4/10 Train: 100%|██████████| 1376/1376 [06:44<00:00,  3.40batch/s, acc=0.944, loss=0.1332]
                                                                


Epoch 4/10 Summary -> Train Loss: 0.2845, Train Acc: 0.9092 | Val Loss: 0.2856, Val Acc: 0.9057


Epoch 5/10 Train: 100%|██████████| 1376/1376 [06:43<00:00,  3.41batch/s, acc=0.833, loss=0.9200]
                                                                


Epoch 5/10 Summary -> Train Loss: 0.2297, Train Acc: 0.9240 | Val Loss: 0.2320, Val Acc: 0.9280


Epoch 6/10 Train:  39%|███▉      | 538/1376 [02:37<04:05,  3.41batch/s, acc=0.812, loss=0.6906]

[ERROR] Train/Val failed: Caught FileNotFoundError in DataLoader worker process 2.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torchvision/datasets/folder.py", line 245, in __getitem__
    sample = self.loader(path)
             ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torchvision/dat


Traceback (most recent call last):
  File "<ipython-input-2-1bf675081f98>", line 323, in <cell line: 0>
    train_and_validate(
  File "<ipython-input-2-1bf675081f98>", line 214, in train_and_validate
    for i, (inputs, labels) in progress_bar:
  File "/usr/local/lib/python3.11/dist-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 708, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1480, in _next_data
    return self._process_data(data)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1505, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.11/dist-packages/torch/_utils.py", line 733, in reraise
    raise exception
FileNotFoundError: Caught FileNotFoundError in DataLoader worker proc

In [None]:
# file: train_evaluate.py

# ==============================================================================
# MODULE 1: IMPORTS VÀ CÀI ĐẶT BAN ĐẦU
# ==============================================================================
import time
import os
import random
import traceback
import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
# !!! Thay đổi import cho AMP !!!
import torch.amp # Import cấp cao hơn
# --------------------------------
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision import datasets
from tqdm import tqdm

# --- Basic setup: seed every RNG in use and detect the compute device ---
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

use_gpu = torch.cuda.is_available()
if use_gpu:
    torch.cuda.manual_seed_all(seed)
    print(f"CUDA available. Using {torch.cuda.device_count()} GPU(s).")
    # Report the name of every visible GPU.
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    # Device type string passed to torch.amp (autocast / GradScaler).
    amp_device_type = 'cuda'
else:
    print("CUDA not available. Using CPU.")
    amp_device_type = 'cpu'


# --- Model and training configuration ---
models_to_test = ['vit_b_16']
# Spatial input size used when probing each model.
input_sizes = dict(vit_b_16=(224, 224))
batch_size = 32
# Speed hint: number of DataLoader worker processes (try 4 or 8).
num_workers_loader = 4
# Gradient-accumulation steps per model.
accumulation_settings = dict(vit_b_16=4)
epochs_to_train = 10
# Make sure this path points at the pre-processed dataset.
PROCESSED_DATA_DIR = '/content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed'
# Class count; starts unset.
num_classes = None

# ==============================================================================
# MODULE 2: ĐỊNH NGHĨA MODEL (ExtendedModel) - Giữ nguyên
# ==============================================================================
class ExtendedModel(nn.Module):
    """Wrap a backbone and append a linear classification layer on its output.

    The width of the backbone output is discovered with a single dummy
    forward pass, so the backbone must return a tensor whose second
    dimension is the feature dimension.
    """

    def __init__(self, base_model, num_classes, input_size=(224,224)):
        """Probe ``base_model`` and create the final ``nn.Linear`` layer.

        Args:
            base_model: Backbone module; it is moved to the probe device
                (CUDA when ``use_gpu`` is true, otherwise CPU) and stays there.
            num_classes: Output size of the appended linear layer.
            input_size: ``(height, width)`` of the dummy probe image.
        """
        super().__init__()
        self.base_model = base_model
        probe_device = torch.device("cuda" if use_gpu else "cpu")
        # Move the backbone to the probe device so the dummy pass can run there.
        self.base_model.to(probe_device)
        dummy = torch.zeros(1, 3, input_size[0], input_size[1], device=probe_device)

        # Probe in eval mode: torch.no_grad() does NOT stop BatchNorm layers
        # from updating their running statistics, so a training-mode pass on
        # an all-zeros image would pollute the pre-trained statistics (and
        # batch-of-one BatchNorm1d layers would raise). The previous mode is
        # restored afterwards so callers see no change.
        was_training = self.base_model.training
        self.base_model.eval()
        try:
            with torch.no_grad():
                dummy_out = self.base_model(dummy)
        finally:
            self.base_model.train(was_training)

        in_features = dummy_out.shape[1]
        print(f"ModelSetup: Determined in_features for final layer: {in_features}")
        self.extension = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Run the backbone, then map its output to ``num_classes`` logits."""
        base_out = self.base_model(x)
        out = self.extension(base_out)
        return out

# ==============================================================================
# MODULE 3: HÀM HỖ TRỢ (Load model pre-trained) - Sửa lỗi Warning
# ==============================================================================
def load_pretrained_core_model(name):
    """Load a pre-trained torchvision model, preferring the weights-enum API.

    Args:
        name: Name of a model builder in ``torchvision.models``
            (e.g. ``'vit_b_16'``).

    Returns:
        The model instantiated with pre-trained weights.

    Raises:
        ValueError: If every loading path finished without producing a model.
        Exception: Any error raised while building the model is logged and
            re-raised unchanged.
    """
    print(f"ModelSetup: Loading pre-trained core model: {name}")
    model = None
    try:
        # torchvision names each weights enum after its builder but in mixed
        # case (efficientnet_b0 -> EfficientNet_B0_Weights, vit_b_16 ->
        # ViT_B_16_Weights). The exact casing cannot be derived from the
        # builder name, so match case-insensitively instead of guessing it.
        wanted = f"{name}_weights".lower()
        weights_enum = next(
            (getattr(models, attr) for attr in dir(models) if attr.lower() == wanted),
            None,
        )

        if weights_enum is not None:
            # DEFAULT is the recommended weights entry of the enum.
            weights = weights_enum.DEFAULT
            print(f"Using weights API: {weights}")
            model = getattr(models, name)(weights=weights)
        else:
            # No enum for this name: use the legacy flag (may emit a warning).
            print(f"No weights enum found for '{name}'. Trying legacy pretrained=True (may show warning).")
            model = getattr(models, name)(pretrained=True)

    except AttributeError as e:
        # Reached when getattr(models, name) (or an attribute access above) fails.
        print(f"[ERROR] Could not load model '{name}' using standard methods: {e}")
        print("Trying models.__dict__ fallback...")
        try:
            # Very old fallback, kept for backward compatibility.
            model = models.__dict__[name](pretrained=True)
        except Exception as e2:
            print(f"[ERROR] Failed loading model '{name}' completely: {e2}")
            raise  # propagate: the model cannot be loaded
    except Exception as e:
        print(f"[ERROR] An unexpected error occurred loading model '{name}': {e}")
        traceback.print_exc()
        raise

    if model is None:
        raise ValueError(f"Could not load model for name: {name}")
    return model


# ==============================================================================
# MODULE 4: HÀM TẢI DỮ LIỆU (load_preprocessed_data) - Giữ nguyên logic
# ==============================================================================
def load_preprocessed_data(processed_data_dir, batch_size, num_workers=4, pin_memory=False):
    """Create train/val/test DataLoaders over already pre-processed image folders.

    Args:
        processed_data_dir: Directory containing ``train``, ``val`` and
            ``test`` sub-directories readable by ``ImageFolder``.
        batch_size: Batch size used for all three loaders.
        num_workers: Number of worker processes per loader.
        pin_memory: Passed through to every ``DataLoader``.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)`` where
        ``class_names`` comes from the training split.

    Raises:
        FileNotFoundError: If a split directory is missing.
        Exception: Errors from ``ImageFolder`` are logged and re-raised.
    """
    print(f"\nLoading preprocessed data from: {processed_data_dir}")
    split_names = ('train', 'val', 'test')

    # Only light online transforms are applied here; only the training split
    # gets a random horizontal flip.
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    train_pipeline = transforms.Compose(
        [transforms.RandomHorizontalFlip(), *to_normalized_tensor]
    )
    eval_pipeline = transforms.Compose(list(to_normalized_tensor))

    image_datasets = {}
    for split in split_names:
        split_path = os.path.join(processed_data_dir, split)
        if not os.path.isdir(split_path):
            raise FileNotFoundError(f"Dir not found: {split_path}")
        pipeline = train_pipeline if split == 'train' else eval_pipeline
        try:
            dataset = datasets.ImageFolder(split_path, pipeline)
            image_datasets[split] = dataset
            print(f"Loaded {len(dataset)} images from {split_path}")
            if len(dataset) == 0:
                print(f"[Warning] No images loaded for split '{split}'.")
        except Exception as e:
            print(f"[ERROR] Load failed for '{split}': {e}")
            raise

    dset_classes = image_datasets['train'].classes
    print(f"Classes loaded ({len(dset_classes)}): {dset_classes}")

    # Only the training loader shuffles.
    loaders = []
    for split in split_names:
        loaders.append(
            DataLoader(
                image_datasets[split],
                batch_size=batch_size,
                shuffle=(split == 'train'),
                num_workers=num_workers,
                pin_memory=pin_memory,
            )
        )
    print("Created DataLoaders.")
    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader, dset_classes

# ==============================================================================
# MODULE 5: HÀM ĐÁNH GIÁ MODEL (evaluate_model) - Sửa lỗi Warning AMP
# ==============================================================================
def evaluate_model(net, dataloader, criterion):
    """Evaluate ``net`` on every batch of ``dataloader``.

    The model is switched to eval mode and is left in eval mode on return.
    Autocast is enabled only when the model's parameters live on a CUDA
    device.

    Args:
        net: Model to evaluate.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable, or ``None`` to skip the loss computation.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the per-batch loss weighted
        by batch size and divided by the sample count (this assumes the
        criterion returns a per-batch mean); it is ``nan`` when ``criterion``
        is ``None`` or no samples were seen. ``accuracy`` is the fraction of
        correct top-1 predictions, ``0.0`` when no samples were seen.
    """
    net.eval()

    # A model without parameters has no device of its own; fall back to CPU
    # instead of failing with a bare StopIteration.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Derive the AMP settings from where the model actually lives rather than
    # from module-level flags, so the function is correct for any placement.
    amp_enabled = device.type == 'cuda'
    autocast_device = 'cuda' if amp_enabled else 'cpu'

    # Explicit None check: truthiness of a module can depend on __len__.
    compute_loss = criterion is not None

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating", leave=False, unit="batch")
        for inputs, labels in progress_bar:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with torch.amp.autocast(device_type=autocast_device, enabled=amp_enabled):
                outputs = net(inputs)
                if compute_loss:
                    loss = criterion(outputs, labels)
                    total_loss += loss.item() * inputs.size(0)

            # Index of the largest logit along dim 1 is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)

    if compute_loss and total_samples > 0:
        avg_loss = total_loss / total_samples
    else:
        avg_loss = float('nan')
    accuracy = total_correct / total_samples if total_samples > 0 else 0.0
    return avg_loss, accuracy

# ==============================================================================
# MODULE 6: HÀM HUẤN LUYỆN VÀ VALIDATE (train_and_validate) - Sửa lỗi Warning AMP
# ==============================================================================
def train_and_validate(net, trainloader, valloader, criterion, epochs, accumulation_steps, lr=0.001):
    """Train ``net`` with Adam and evaluate on the validation set every epoch.

    Gradients are accumulated over ``accumulation_steps`` consecutive batches
    before each optimizer step. Mixed precision (autocast + GradScaler) is
    enabled only when the model's parameters live on a CUDA device.

    Args:
        net: Model to train; must already be on its target device.
        trainloader: Sized DataLoader yielding ``(inputs, labels)`` batches.
        valloader: DataLoader passed to ``evaluate_model`` after each epoch.
        criterion: Loss callable applied to ``(outputs, labels)``.
        epochs: Number of passes over ``trainloader``.
        accumulation_steps: Batches per optimizer step; must be >= 1.
        lr: Adam learning rate (defaults to the previously hard-coded 0.001).

    Returns:
        None. Per-epoch metrics are printed.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """
    # Fail early: 0 would otherwise raise ZeroDivisionError mid-epoch and
    # negative values would silently produce a meaningless schedule.
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    device = next(net.parameters()).device
    # Tie AMP to the device the model is really on; a GradScaler enabled for
    # CUDA cannot scale losses that were computed on the CPU.
    amp_enabled = device.type == 'cuda'
    amp_device = 'cuda' if amp_enabled else 'cpu'

    optimizer = optim.Adam(net.parameters(), lr=lr)
    scaler = torch.amp.GradScaler(device=amp_device, enabled=amp_enabled)

    print(f"\n--- Starting Training ({epochs} epochs) ---")
    print(f"Effective Batch Size: {trainloader.batch_size * accumulation_steps}")
    print(f"Using Mixed Precision (AMP): {amp_enabled}")

    num_batches = len(trainloader)
    # Size of the trailing, incomplete accumulation group (0 if none).
    tail = num_batches % accumulation_steps

    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        total_train = 0
        correct_train = 0
        optimizer.zero_grad()

        progress_bar = tqdm(enumerate(trainloader), total=num_batches,
                            desc=f"Epoch {epoch+1}/{epochs} Train", unit="batch")
        for i, (inputs, labels) in progress_bar:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # Batches in the trailing partial group are averaged over the
            # group's real size, so the last optimizer step of the epoch is
            # not under-scaled.
            in_tail = tail != 0 and i >= num_batches - tail
            group_size = tail if in_tail else accumulation_steps

            with torch.amp.autocast(device_type=amp_device, enabled=amp_enabled):
                outputs = net(inputs)
                raw_loss = criterion(outputs, labels)
                loss = raw_loss / group_size if group_size > 1 else raw_loss

            scaler.scale(loss).backward()

            # Step at the end of every full group and after the final batch.
            if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Report the unscaled loss so it is comparable across settings.
            batch_loss = raw_loss.item()
            running_loss += batch_loss * inputs.size(0)
            total_train += labels.size(0)

            with torch.no_grad():
                _, predicted = torch.max(outputs, 1)
                correct_train_batch = (predicted == labels).sum().item()
                batch_acc = correct_train_batch / inputs.size(0) if inputs.size(0) > 0 else 0
                correct_train += correct_train_batch
                progress_bar.set_postfix(loss=f"{batch_loss:.4f}", acc=f"{batch_acc:.3f}")

        train_loss = running_loss / total_train if total_train > 0 else float('nan')
        train_acc = correct_train / total_train if total_train > 0 else 0.0

        val_loss, val_acc = evaluate_model(net, valloader, criterion)

        print(f"\nEpoch {epoch+1}/{epochs} Summary -> "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    print("--- Training Finished ---")

# ==============================================================================
# MODULE 7: SCRIPT CHÍNH ĐỂ THỰC THI
# ==============================================================================
if __name__ == "__main__":
    # Script entry point: for every configured model, load the data, build the
    # model, train/validate it, test it, then print a summary of test metrics.
    print("="*60)
    print(" PLANT DISEASE CLASSIFICATION TRAINING SCRIPT (PREPROCESSED DATA) ".center(60, "="))
    print("="*60)

    if not os.path.isdir(PROCESSED_DATA_DIR):
        print(f"[FATAL ERROR] Preprocessed data directory not found: {PROCESSED_DATA_DIR}")
        print("Please run the preprocessing script ('preprocess_data*.py') first.")
        # `exit()` is a helper injected by the `site` module for interactive
        # sessions and would report success (status 0); signal failure instead.
        raise SystemExit(1)

    results = {}

    for model_name in models_to_test:
        print("\n" + "="*50)
        print(f"Processing Model: {model_name}")
        print("="*50 + "\n")

        acc_steps = accumulation_settings.get(model_name, 1)
        model_input_resize = input_sizes[model_name]
        print(f"Config -> Accumulation: {acc_steps}, Model Input Size: {model_input_resize}")

        # Pre-bind every per-model resource so the cleanup in `finally` can
        # always `del` them, whichever stage fails.
        model = base_model = criterion = None
        train_loader = val_loader = test_loader = None

        try:
            # --- Load the preprocessed data ---
            print("\n--- Loading Preprocessed DataLoaders ---")
            try:
                train_loader, val_loader, test_loader, dset_classes = load_preprocessed_data(
                    processed_data_dir=PROCESSED_DATA_DIR,
                    batch_size=batch_size,
                    num_workers=num_workers_loader,
                    pin_memory=use_gpu
                )
                # NOTE(review): `num_classes` is not assigned anywhere in this
                # section; it is presumably initialised earlier in the file.
                if num_classes is None:
                    num_classes = len(dset_classes)
                elif num_classes != len(dset_classes):
                    print(f"[Warning] Class count mismatch. Resetting num_classes to {len(dset_classes)}")
                    num_classes = len(dset_classes)
                print(f"Number of classes for model: {num_classes}")
            except Exception as e:
                print(f"[ERROR] Load data failed: {e}")
                traceback.print_exc()
                continue

            # --- Build the model ---
            try:
                print(f"\nSetting up {model_name}...")
                base_model = load_pretrained_core_model(model_name)
                model = ExtendedModel(base_model, num_classes, input_size=model_input_resize)

                # Optional speed-up (PyTorch >= 2.0): wrap the model here with
                # `model = torch.compile(model)`.

                if use_gpu:
                    if torch.cuda.device_count() > 1:
                        model = nn.DataParallel(model)
                    model = model.cuda()
            except Exception as e:
                print(f"[ERROR] Model setup failed: {e}")
                traceback.print_exc()
                continue

            # --- Train and validate ---
            criterion = nn.CrossEntropyLoss()
            if use_gpu:
                criterion = criterion.cuda()

            try:
                train_and_validate(
                    net=model, trainloader=train_loader, valloader=val_loader,
                    criterion=criterion, epochs=epochs_to_train, accumulation_steps=acc_steps
                )
            except Exception as e:
                print(f"[ERROR] Train/Val failed: {e}")
                traceback.print_exc()
                print("Skipping testing.")
                continue

            # --- Final test ---
            print("\n--- Final Testing Phase ---")
            try:
                test_loss, test_acc = evaluate_model(model, test_loader, criterion)
                print(f"\n>>> Test Results for {model_name}: Loss = {test_loss:.4f}, Accuracy = {test_acc:.4f}")
                results[model_name] = {'test_loss': test_loss, 'test_accuracy': test_acc}
            except Exception as e:
                print(f"[ERROR] Test failed: {e}")
                traceback.print_exc()
                # Record the failure so the model still shows up in the summary.
                results[model_name] = {'test_loss': float('nan'), 'test_accuracy': float('nan')}

            print("-"*50)
        finally:
            # --- Free memory ---
            # Runs on success and on every `continue` above, so a failed model
            # does not stay resident while the next one is being loaded.
            print(f"Cleaning up resources for {model_name}...")
            del model, base_model, train_loader, val_loader, test_loader, criterion
            if use_gpu:
                torch.cuda.empty_cache()

    # --- Print the summary ---
    print("\n========================================")
    print(" TRAINING & TESTING COMPLETE ".center(40, "="))
    print("========================================")
    print("Final Test Results Summary:")
    if results:
        for model_name, metrics in results.items():
            print(f"- {model_name}: Test Loss={metrics['test_loss']:.4f}, Test Accuracy={metrics['test_accuracy']:.4f}")
    else:
        print("No models were successfully processed.")
    print("="*40)
    print("SCRIPT FINISHED")

CUDA not available. Using CPU.
 PLANT DISEASE CLASSIFICATION TRAINING SCRIPT (PREPROCESSED DATA) 

Processing Model: vit_b_16

Config -> Accumulation: 4, Model Input Size: (224, 224)

--- Loading Preprocessed DataLoaders ---

Loading preprocessed data from: /content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed
Loaded 44018 images from /content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed/train
Loaded 5502 images from /content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed/val
Loaded 5502 images from /content/drive/MyDrive/Colab Notebooks/PlantVillage_Processed/test
Classes loaded (39): ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Blac

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:04<00:00, 80.8MB/s]


ModelSetup: Determined in_features for final layer: 1000

--- Starting Training (10 epochs) ---
Effective Batch Size: 128
Using Mixed Precision (AMP): False


Epoch 1/10 Train:   0%|          | 4/1376 [06:09<35:10:29, 92.30s/batch, acc=0.062, loss=3.9003]


KeyboardInterrupt: 