# CvT Fine-tuning untuk Klasifikasi Penyakit Tanaman Padi

Notebook ini akan membantu Anda melakukan fine-tuning model CvT-21 untuk klasifikasi penyakit tanaman padi dengan 10 kelas (9 penyakit + 1 normal).

## Persiapan:
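1. Dataset dalam format ImageFolder (folder `train/` dan `val/`, masing-masing berisi satu subfolder per kelas) sudah tersedia di Google Drive
2. File pretrained weights `CvT-21-224x224-IN-1k.pth` sudah tersedia di Google Drive
3. Repository CvT akan di-clone ke Colab pada langkah 2; file konfigurasi `experiments/imagenet/cvt/cvt-21-224x224.yaml` perlu disesuaikan dengan dataset ini sebelum training (langkah 6)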

## 1. Mount Google Drive

In [None]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# Verify mount
print("Google Drive mounted successfully!")
print("Available files in Drive:")
!ls /content/drive/MyDrive/

## 2. Clone Repository dan Setup Environment

In [None]:
# Clone repository CvT
!git clone https://github.com/microsoft/CvT.git
%cd CvT

# Install dependencies
!pip install -r requirements.txt

# Verify installation
import torch
import torchvision
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 3. Setup Dataset dan Model Weights

In [None]:
    "# Buat direktori untuk output
",
    "!mkdir -p /content/output
",
    "
",
    "# Copy dataset dari Google Drive ke repo CvT
",
    "# Sesuaikan path ini dengan lokasi dataset Anda di Google Drive
",
    "dataset_path = "/content/drive/MyDrive/paddy_disease_classification"  # Ubah sesuai lokasi dataset Anda
",
    "!cp -r "$dataset_path" /content/CvT/
",
    "
",
    "# Copy pretrained weights ke root repo CvT
",
    "weights_path = "/content/drive/MyDrive/CvT-21-224x224-IN-1k.pth"  # Ubah sesuai lokasi weights Anda
",
    "!cp "$weights_path" /content/CvT/
",
    "
",
    "# Verify dataset structure
",
    "print("Dataset structure:")
",
    "!ls -la /content/CvT/paddy_disease_classification/
",
    "print("
Train classes:")
",
    "!ls /content/CvT/paddy_disease_classification/train/
",
    "print("
Validation classes:")
",
    "!ls /content/CvT/paddy_disease_classification/val/
",
    "
",
    "# Verify weights file
",
    "print("
Pretrained weights:")
",
    "!ls -la /content/CvT/*.pth"

## 4. Verifikasi Konfigurasi

In [None]:
    "# Tampilkan konfigurasi yang akan digunakan
",
    "!cat /content/CvT/experiments/imagenet/cvt/cvt-21-224x224.yaml"

## 5. Hitung Jumlah Data per Kelas

In [None]:
import os


def count_images_per_class(dataset_path, splits=('train', 'val'),
                           extensions=('.png', '.jpg', '.jpeg')):
    """Print and return the number of images per class for each split.

    Expects an ImageFolder-style layout ``dataset_path/<split>/<class>/<file>``.
    A file is counted when its lower-cased name ends with one of
    ``extensions``; non-directory entries directly under a split are skipped.

    Args:
        dataset_path: Root directory of the dataset.
        splits: Split sub-directories to inspect, in order.
        extensions: Filename suffixes (case-insensitive) counted as images.

    Returns:
        Dict mapping each split found on disk to ``{class_name: count}``
        (classes in sorted order). Splits whose directory is missing are
        reported and left out of the result.
    """
    extensions = tuple(ext.lower() for ext in extensions)
    counts = {}
    for split in splits:
        split_path = os.path.join(dataset_path, split)
        # isdir (not exists): a stray *file* named like a split would
        # otherwise make os.listdir raise NotADirectoryError.
        if not os.path.isdir(split_path):
            print(f"\n{split.upper()} Dataset: not found at {split_path}")
            continue

        print(f"\n{split.upper()} Dataset:")
        print("-" * 30)
        per_class = {}
        for class_name in sorted(os.listdir(split_path)):
            class_path = os.path.join(split_path, class_name)
            if not os.path.isdir(class_path):
                continue
            count = sum(1 for f in os.listdir(class_path)
                        if f.lower().endswith(extensions))
            per_class[class_name] = count
            print(f"{class_name}: {count} images")
        print(f"Total: {sum(per_class.values())} images")
        counts[split] = per_class
    return counts


# Assigned so the notebook doesn't echo the returned dict as cell output.
image_counts = count_images_per_class('/content/CvT/paddy_disease_classification')

## 6. Mulai Training

In [None]:
# Start training with the already-edited cvt-21-224x224.yaml configuration.
# NOTE(review): the relative paths (tools/train.py, experiments/...) assume the
# working directory is still the CvT checkout set by the %cd in the setup cell — confirm.
# NOTE(review): confirm that this checkout's tools/train.py accepts --output and --log-dir.
!python tools/train.py \
    --cfg experiments/imagenet/cvt/cvt-21-224x224.yaml \
    --output /content/output \
    --log-dir /content/output

## 7. Monitor Training Progress

In [None]:
# Show the last 50 lines of the training log.
# NOTE(review): assumes training writes its log to /content/output/log.txt — confirm the file name.
!tail -n 50 /content/output/log.txt

## 8. Evaluasi Model Terbaik

In [None]:
# Cari model dengan accuracy terbaik
!ls -la /content/output/

    "# Evaluasi model terbaik
",
    "!python tools/test.py 
",
    "    --cfg experiments/imagenet/cvt/cvt-21-224x224.yaml 
",
    "    --model-file /content/output/best.pth 
",
    "    --output /content/output/test_results"

## 9. Visualisasi Hasil Training

In [None]:
import matplotlib.pyplot as plt
import re

def parse_log_file(log_path):
    """Parse a training log into per-epoch training loss and validation accuracy.

    A line containing both ``Epoch:`` and ``Loss`` contributes a training loss.
    Only the first such line seen for each epoch number is kept. A line
    containing both ``Test:`` and ``Acc@1`` contributes a validation top-1
    accuracy. It is ignored when there are already as many accuracies as
    epochs, so at most one accuracy is kept per recorded epoch.

    NOTE(review): the patterns expect text like ``Epoch: [3] ... Loss 0.1234``
    and ``Test: ... Acc@1 87.5``; confirm they match what tools/train.py in
    this CvT checkout actually writes.

    Args:
        log_path: Path to the text log file.

    Returns:
        Tuple ``(epochs, train_losses, val_accs)`` of lists. ``epochs`` and
        ``train_losses`` have equal length; ``val_accs`` may be shorter.
        All three are empty (and a message is printed) if the file is missing.
    """
    # Compiled once instead of per line. `\d*\.?\d+` avoids grabbing a
    # trailing period (e.g. "Loss 1.5."), which made float() raise with the
    # previous `[\d\.]+` pattern.
    epoch_re = re.compile(r'Epoch: \[(\d+)\]')
    loss_re = re.compile(r'Loss (\d*\.?\d+)')
    acc_re = re.compile(r'Acc@1 (\d*\.?\d+)')

    epochs = []
    train_losses = []
    val_accs = []
    seen_epochs = set()

    try:
        with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
            lines = f.readlines()
    except FileNotFoundError:
        print("Log file not found. Training might still be in progress.")
        return [], [], []

    for line in lines:
        # Training loss.
        if 'Epoch:' in line and 'Loss' in line:
            epoch_match = epoch_re.search(line)
            loss_match = loss_re.search(line)
            if epoch_match and loss_match:
                epoch = int(epoch_match.group(1))
                # Record each epoch once. The old check zipped epochs with
                # val_accs, which stops at the shorter list, so repeated loss
                # lines of the same epoch were appended again as duplicates.
                if epoch not in seen_epochs:
                    seen_epochs.add(epoch)
                    epochs.append(epoch)
                    train_losses.append(float(loss_match.group(1)))

        # Validation accuracy.
        if 'Test:' in line and 'Acc@1' in line:
            acc_match = acc_re.search(line)
            if acc_match and len(val_accs) < len(epochs):
                val_accs.append(float(acc_match.group(1)))

    return epochs, train_losses, val_accs

# Parse the training log, then plot the loss/accuracy curves if any epoch was found.
epochs, train_losses, val_accs = parse_log_file('/content/output/log.txt')

if not epochs:
    print("No training data found in log file.")
else:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: training loss against the epochs it was parsed for.
    loss_x = epochs[:len(train_losses)]
    ax1.plot(loss_x, train_losses, 'b-', label='Training Loss')
    ax1.set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
    ax1.legend()
    ax1.grid(True)

    # Right panel: validation accuracy; left blank when none was parsed.
    has_val = len(val_accs) > 0
    if has_val:
        ax2.plot(epochs[:len(val_accs)], val_accs, 'r-', label='Validation Accuracy')
        ax2.set(xlabel='Epoch', ylabel='Accuracy (%)', title='Validation Accuracy')
        ax2.legend()
        ax2.grid(True)

    plt.tight_layout()
    plt.show()

    print(f"Training completed for {len(epochs)} epochs")
    if has_val:
        print(f"Best validation accuracy: {max(val_accs):.2f}%")

## 10. Backup Model ke Google Drive

In [None]:
# Buat direktori backup di Google Drive
!mkdir -p "/content/drive/MyDrive/CvT_Paddy_Results"

# Copy hasil training ke Google Drive
!cp -r /content/output/* "/content/drive/MyDrive/CvT_Paddy_Results/"

# Copy config file yang digunakan
    "# Copy config file yang digunakan
",
    "!cp /content/CvT/experiments/imagenet/cvt/cvt-21-224x224.yaml "/content/drive/MyDrive/CvT_Paddy_Results/""

print("Model dan hasil training berhasil disimpan ke Google Drive!")
print("Lokasi: /content/drive/MyDrive/CvT_Paddy_Results/")

# Tampilkan file yang disimpan
!ls -la "/content/drive/MyDrive/CvT_Paddy_Results/"

## 11. Inference pada Gambar Baru

In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

def load_class_names(train_dir, fallback):
    """Return class names in the label order used by torchvision's ImageFolder.

    ImageFolder gives label ``i`` to the ``i``-th class folder in sorted name
    order. Reading the folder names therefore keeps ``class_names[i]`` aligned
    with the model's output index ``i``. This assumes training loaded the data
    with ImageFolder, as the setup notes at the top of this notebook state.

    Args:
        train_dir: Directory whose sub-directories are the class folders.
        fallback: Class names to use when ``train_dir`` is missing or has no
            sub-directories.

    Returns:
        A new list of class names.
    """
    import os

    if not os.path.isdir(train_dir):
        return list(fallback)
    with os.scandir(train_dir) as entries:
        found = sorted(entry.name for entry in entries if entry.is_dir())
    return found if found else list(fallback)


# Rice-disease class names, index-aligned with the model outputs. They are
# read from the local dataset copy when it exists. The hard-coded list (already
# in sorted order) is only a fallback; adjust it to match your dataset.
class_names = load_class_names(
    '/content/CvT/paddy_disease_classification/train',
    fallback=[
        'bacterial_leaf_blight',
        'bacterial_leaf_streak',
        'bacterial_panicle_blight',
        'blast',
        'brown_spot',
        'dead_heart',
        'downy_mildew',
        'hispa',
        'normal',
        'tungro',
    ],
)

def predict_image(image_path, model_path):
    """Print and return a suggested command for single-image inference.

    No model is loaded here. Single-image inference would require modifying
    tools/test.py, so for now use tools/test.py for batch evaluation
    (section 8).

    Args:
        image_path: Path of the image to classify.
        model_path: Path of the trained checkpoint.

    Returns:
        The printed command string. Both paths are shell-quoted, so paths
        containing spaces survive copy-pasting into a shell.
    """
    import shlex

    # NOTE(review): `--image` is not known to be supported by the stock
    # tools/test.py; expect this command to work only after that script is
    # extended for single-image inference.
    command = (
        "python tools/test.py --cfg experiments/imagenet/cvt/cvt-21-224x224.yaml"
        f" --model-file {shlex.quote(str(model_path))}"
        f" --image {shlex.quote(str(image_path))}"
    )
    print("Untuk melakukan inference pada gambar tunggal, gunakan:")
    print(command)
    return command

# Example usage (uncomment if you have a test image):
# predict_image('/content/test_image.jpg', '/content/output/best.pth')