# CvT Fine-tuning untuk Klasifikasi Penyakit Tanaman Padi

Notebook ini akan membantu Anda melakukan fine-tuning model CvT-21 untuk klasifikasi penyakit tanaman padi dengan 10 kelas (9 penyakit + 1 normal).

## Persiapan Manual:
1. **Clone repository** ini ke Colab
2. **Download dataset** Anda dan letakkan di `/content/CvT/paddy_disease_dataset/`
3. **Download pretrained weights** `CvT-21-224x224-IN-1k.pth` dan letakkan di `/content/CvT/`

### Struktur yang diharapkan:
```
/content/CvT/
├── paddy_disease_dataset/
│   ├── train/
│   │   ├── class1/
│   │   └── class2/
│   └── val/
│       ├── class1/
│       └── class2/
└── CvT-21-224x224-IN-1k.pth
```

## 1. Clone Repository Fork dan Setup Environment

In [None]:
# Clone repository CvT
!git clone https://github.com/raviearjun/CvT.git
%cd CvT

# Install dependencies
!pip install -r requirements.txt

# Verify installation
import torch
import torchvision
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

print("\n‚ö†Ô∏è  MANUAL SETUP REQUIRED:")
print("1. Download your paddy disease dataset")
print("2. Extract to: /content/CvT/paddy_disease_dataset/")
print("3. Download CvT-21-224x224-IN-1k.pth to: /content/CvT/")
print("4. Then continue to next cell")

## 1. Clone Repository dan Setup Environment

In [None]:
# Clone repository CvT
!git clone https://github.com/microsoft/CvT.git
%cd CvT

# Install dependencies
!pip install -r requirements.txt

# Verify installation
import torch
import torchvision
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Verifikasi Dataset dan Weights

In [None]:
import os

# Buat direktori untuk output
!mkdir -p /content/output

# Verify dataset structure
print("üîç Checking dataset...")
dataset_path = "/content/CvT/paddy_disease_dataset"
weights_path = "/content/CvT/CvT-21-224x224-IN-1k.pth"

if os.path.exists(dataset_path):
    print("‚úÖ Dataset directory found")
    
    # Check train, val, and test directories
    for split in ['train', 'val', 'test']:
        split_path = os.path.join(dataset_path, split)
        if os.path.exists(split_path):
            classes = [d for d in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, d))]
            print(f"‚úÖ {split}: {len(classes)} classes found")
            if len(classes) > 0:
                print(f"   Classes: {classes}")
        else:
            print(f"‚ùå {split} directory not found!")
else:
    print("‚ùå Dataset directory not found!")
    print("   Please ensure dataset is placed at:", dataset_path)

# Verify weights file
if os.path.exists(weights_path):
    print("‚úÖ Pretrained weights found")
else:
    print("‚ùå Pretrained weights not found!")
    print("   Please ensure weights file is at:", weights_path)

print("\nüìä File structure:")
!ls -la /content/CvT/

## 3. Verifikasi Konfigurasi

In [None]:
# Print the configuration file that will be used
!cat /content/CvT/experiments/imagenet/cvt/cvt-21-224x224_paddy_dataset.yaml

## 4. Hitung Jumlah Data per Kelas

In [None]:
import os

def count_images_per_class(dataset_path):
    """Print how many image files each class folder holds, split by split.

    For each of the ``train``, ``val`` and ``test`` sub-directories of
    ``dataset_path`` this prints one line per class directory (sorted by
    name) followed by a total, or a "NOT FOUND" line when the split
    directory does not exist. Only ``.png``, ``.jpg`` and ``.jpeg`` files
    (case-insensitive) are counted. Nothing is returned.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')

    for split in ('train', 'val', 'test'):
        split_dir = os.path.join(dataset_path, split)

        # Missing split: report it and move on to the next one.
        if not os.path.exists(split_dir):
            print(f"\n{split.upper()} Dataset: NOT FOUND")
            continue

        print(f"\n{split.upper()} Dataset:")
        print("-" * 30)

        running_total = 0
        for entry in sorted(os.listdir(split_dir)):
            class_dir = os.path.join(split_dir, entry)
            # Only directories are treated as classes; stray files are skipped.
            if not os.path.isdir(class_dir):
                continue
            n_images = sum(
                1
                for fname in os.listdir(class_dir)
                if fname.lower().endswith(image_suffixes)
            )
            print(f"{entry}: {n_images} images")
            running_total += n_images

        print(f"Total: {running_total} images")


count_images_per_class('/content/CvT/paddy_disease_dataset')

## 6. Mulai Training

In [None]:
# Start training with the cvt-21-224x224.yaml configuration that has already been edited
# NOTE(review): other cells in this notebook display and archive
# cvt-21-224x224_paddy_dataset.yaml instead - confirm which config is intended.
# The script and config paths are relative, so they are resolved against the
# current working directory.
!python tools/train.py \
    --cfg experiments/imagenet/cvt/cvt-21-224x224.yaml \
    --output /content/output \
    --log-dir /content/output

## 7. Monitor Training Progress

In [None]:
# Show the last 50 lines of the training log
!tail -n 50 /content/output/log.txt

## 8. Evaluasi Model Terbaik

In [None]:
# Cari model dengan accuracy terbaik
!ls -la /content/output/

    "# Evaluasi model terbaik
",
    "!python tools/test.py 
",
    "    --cfg experiments/imagenet/cvt/cvt-21-224x224.yaml 
",
    "    --model-file /content/output/best.pth 
",
    "    --output /content/output/test_results"

## 9. Visualisasi Hasil Training

In [None]:
import matplotlib.pyplot as plt
import re

def parse_log_file(log_path):
    """Parse a training log into per-epoch loss and validation accuracy.

    Lines containing both ``Epoch:`` and ``Loss`` contribute a training
    loss; only the first such line seen for each epoch number is kept.
    Lines containing both ``Test:`` and ``Acc@1`` contribute a validation
    accuracy; at most one accuracy is kept per recorded epoch.

    Args:
        log_path: Path of the text log file to read.

    Returns:
        A tuple ``(epochs, train_losses, val_accs)`` of lists. ``epochs``
        and ``train_losses`` always have the same length; ``val_accs`` is
        never longer than ``epochs``. All three are empty when the file
        does not exist (a message is printed in that case).
    """
    epoch_re = re.compile(r'Epoch: \[(\d+)\]')
    loss_re = re.compile(r'Loss ([\d\.]+)')
    acc_re = re.compile(r'Acc@1 ([\d\.]+)')

    epochs = []
    train_losses = []
    val_accs = []
    # Epoch numbers already recorded. The previous duplicate check zipped
    # `epochs` with `val_accs`, which is shorter while an epoch is still
    # training, so repeated log lines for the same epoch were appended again.
    seen_epochs = set()

    try:
        with open(log_path, 'r') as f:
            for line in f:
                # Training loss: keep the first reported loss of each epoch.
                if 'Epoch:' in line and 'Loss' in line:
                    epoch_match = epoch_re.search(line)
                    loss_match = loss_re.search(line)
                    if epoch_match and loss_match:
                        epoch = int(epoch_match.group(1))
                        if epoch not in seen_epochs:
                            seen_epochs.add(epoch)
                            epochs.append(epoch)
                            train_losses.append(float(loss_match.group(1)))

                # Validation accuracy: one value per recorded epoch.
                if 'Test:' in line and 'Acc@1' in line:
                    acc_match = acc_re.search(line)
                    if acc_match and len(val_accs) < len(epochs):
                        val_accs.append(float(acc_match.group(1)))

    except FileNotFoundError:
        print("Log file not found. Training might still be in progress.")
        return [], [], []

    return epochs, train_losses, val_accs

# Parse the training log, then draw the loss and accuracy curves side by side.
epochs, train_losses, val_accs = parse_log_file('/content/output/log.txt')

if not epochs:
    print("No training data found in log file.")
else:
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # One entry per panel: (axes, series, line format, legend label, y label, title).
    # The accuracy panel is only drawn when accuracies were parsed.
    panels = [
        (loss_ax, train_losses, 'b-', 'Training Loss', 'Loss', 'Training Loss'),
    ]
    if val_accs:
        panels.append(
            (acc_ax, val_accs, 'r-', 'Validation Accuracy', 'Accuracy (%)', 'Validation Accuracy')
        )

    for axis, series, line_fmt, legend_label, y_label, title in panels:
        # The x values are trimmed to the series length so both sequences match.
        axis.plot(epochs[:len(series)], series, line_fmt, label=legend_label)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.legend()
        axis.grid(True)

    plt.tight_layout()
    plt.show()

    print(f"Training completed for {len(epochs)} epochs")
    if val_accs:
        print(f"Best validation accuracy: {max(val_accs):.2f}%")

## 10. Backup Model: Unduh Hasil Training sebagai Arsip ZIP

In [None]:
# Buat archive untuk download hasil training
import shutil
from google.colab import files

print("üì¶ Preparing results for download...")

# Create archive of training results
archive_name = "cvt_paddy_training_results"
shutil.make_archive(f"/content/{archive_name}", 'zip', '/content/output')

# Copy config file to output for archiving
!cp /content/CvT/experiments/imagenet/cvt/cvt-21-224x224_paddy_dataset.yaml /content/output/

print("‚úÖ Results archived successfully!")
print("\nüìÅ Training results summary:")
!ls -la /content/output/

print(f"\nüíæ Download your results:")
print("1. Training archive (semua file):")
files.download(f"/content/{archive_name}.zip")

print("\n2. Download individual files:")
print("   - Best model: /content/output/best.pth")
print("   - Latest model: /content/output/latest.pth") 
print("   - Training log: /content/output/log.txt")

# Optionally download individual important files
download_individual = input("\nDownload individual files? (y/n): ")
if download_individual.lower() == 'y':
    try:
        files.download("/content/output/best.pth")
        files.download("/content/output/log.txt")
        print("‚úÖ Individual files downloaded!")
    except:
        print("‚ÑπÔ∏è  Some files may not exist yet or download was cancelled")

## 11. Inference pada Gambar Baru

In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

# List of paddy disease classes (adjust to match your dataset).
# The entries are listed in alphabetical order.
# NOTE(review): presumably this order must match the label indices used
# during training - confirm against the dataset loader.
class_names = [
    'bacterial_leaf_blight',
    'bacterial_leaf_streak', 
    'bacterial_panicle_blight',
    'blast',
    'brown_spot',
    'dead_heart',
    'downy_mildew',
    'hispa',
    'normal',
    'tungro'
]  # Change to match the classes of your dataset

def predict_image(image_path, model_path):
    """Print the command suggested for running inference on a single image.

    This is a placeholder: it loads no model and makes no prediction. It
    only prints a ``tools/test.py`` command line with the given paths
    filled in.

    Args:
        image_path: Path of the image to classify; inserted into the
            printed command.
        model_path: Path of the checkpoint file; inserted into the printed
            command.

    Returns:
        None.
    """
    # Real single-image inference requires modifying tools/test.py; for now
    # tools/test.py is meant to be used for batch evaluation.
    # The two print calls below were previously stored as raw notebook-JSON
    # string fragments, which made the function body a syntax error.
    print("Untuk melakukan inference pada gambar tunggal, gunakan:")
    print(f"python tools/test.py --cfg experiments/imagenet/cvt/cvt-21-224x224.yaml --model-file {model_path} --image {image_path}")

# Contoh penggunaan (uncomment jika Anda memiliki gambar test)
# predict_image('/content/test_image.jpg', '/content/output/best.pth')