# 🌿 Plant Disease Classification với MambaTSR

## Google Colab Setup
Notebook này được tối ưu cho Google Colab với GPU miễn phí.

**Trước khi chạy:**
1. Runtime → Change runtime type → GPU (T4)
2. Upload dữ liệu PlantVillage hoặc mount Google Drive

In [None]:
# ============= COLAB SETUP =============
import sys

# Check if running on Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("‚úì Running on Google Colab")
    
    # Mount Google Drive (optional)
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Clone MambaTSR repository
    !git clone https://github.com/quoclam-doit/Plant_Disease.git
    %cd Plant_Disease/MambaTSR
    
    # Install dependencies
    !pip install -q timm einops fvcore tensorboard
    
    # Compile selective_scan CUDA kernel
    %cd kernels/selective_scan
    !pip install -e .
    %cd ../..
else:
    print("‚ö†Ô∏è  Not running on Colab - make sure environment is set up correctly")

In [None]:
# Import libraries
import os
import sys
import time
import math
import copy
import random
from pathlib import Path
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau

from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

# On Colab the repository is cloned under /content, so put its MambaTSR
# folder first on the import path before importing the model class.
if IN_COLAB:
    sys.path.insert(0, '/content/Plant_Disease/MambaTSR')

from models.VSSBlock_utils import Super_Mamba

# Report the PyTorch build and, when CUDA is usable, the first GPU's name and
# total memory (bytes converted to GiB).
has_cuda = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {gpu_name}")
    print(f"VRAM: {gpu_props.total_memory / 1024**3:.1f} GB")

## 📊 Upload Dataset

Bạn có 3 options:
1. Upload từ máy local (chậm, ~2GB)
2. Lưu trong Google Drive và mount (khuyến nghị)
3. Download từ link public

In [None]:
# Selects where the PlantVillage images live and checks that the folder is
# usable before any later cell tries to read it.

# Option 1: Upload from local (uncomment if needed)
# from google.colab import files
# uploaded = files.upload()

# Option 2: Use from Google Drive
if IN_COLAB:
    DATA_ROOT = '/content/drive/MyDrive/PlantVillage'  # Adjust path
else:
    DATA_ROOT = r'G:\Dataset\Data\PlantVillage\PlantVillage-Dataset-master'

print(f"Data root: {DATA_ROOT}")

# Verify dataset.
# is_dir() (rather than exists()) also rejects a path that points at a file,
# which would otherwise make iterdir() raise.
data_root_path = Path(DATA_ROOT)
if not data_root_path.is_dir():
    print("‚ö†Ô∏è  Dataset not found! Please upload or adjust DATA_ROOT path")
else:
    # Only sub-directories are counted as classes; stray files in the root
    # (readme, OS metadata, ...) must not inflate the number.
    n_class_dirs = sum(1 for entry in data_root_path.iterdir() if entry.is_dir())
    print(f"‚úì Dataset found with {n_class_dirs} classes")

In [None]:
# Device configuration: train on the GPU when PyTorch can use one, else CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Configuration: a single dict read by the data, model and training cells
CONFIG = dict(
    # --- model ---
    model_name='Super_Mamba',
    dims=3,              # passed to Super_Mamba(dims=...)
    depth=6,             # passed to Super_Mamba(depth=...)
    num_classes=39,      # size of the classifier output

    # --- data loading / optimisation ---
    batch_size=64,
    num_epochs=50,       # Reduced for Colab time limits
    learning_rate=1e-3,
    weight_decay=1e-4,
    num_workers=2,
    pin_memory=True,

    # --- preprocessing / augmentation ---
    image_size=32,       # images are resized to image_size x image_size
    brightness=0.8,      # ColorJitter brightness factor
    contrast=(1.0, 1.0),

    # --- schedule / early stopping ---
    scheduler='cosine',  # 'cosine' -> CosineAnnealingLR, anything else -> ReduceLROnPlateau
    min_lr=1e-6,
    patience=15,         # epochs without a new best val accuracy before stopping

    # --- paths ---
    data_root=Path(DATA_ROOT),
    save_dir=Path('/content/models' if IN_COLAB else 'G:/Dataset/models/MambaTSR'),
)

# Make sure the output folder exists before anything is written to it
CONFIG['save_dir'].mkdir(parents=True, exist_ok=True)
print(f"\n‚úì Models will be saved to: {CONFIG['save_dir']}")

In [None]:
# Builds the image transforms, loads the image folder, splits it 72/18/10
# into train/val/test with a fixed seed, and wraps each split in a DataLoader.

# Data transforms.
# Training: resize + brightness jitter + horizontal flip + small rotation.
# Validation/test: resize only. Both normalise with the same mean/std.
transform_train = transforms.Compose([
    transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
    transforms.ColorJitter(brightness=CONFIG['brightness']),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

transform_val = transforms.Compose([
    transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Load dataset (single directory scan; transforms are attached further down)
full_dataset = datasets.ImageFolder(root=CONFIG['data_root'])
class_names = full_dataset.classes
num_classes = len(class_names)

print(f"‚úì Found {num_classes} classes")
print(f"‚úì Total images: {len(full_dataset):,}")

# The model is built with CONFIG['num_classes'], so surface a mismatch with
# what is actually on disk instead of letting it show up later in training.
if num_classes != CONFIG['num_classes']:
    print(f"WARNING: dataset has {num_classes} classes but "
          f"CONFIG['num_classes'] is {CONFIG['num_classes']}")

# Split dataset (72% / 18% / remainder), reproducible through the fixed seed
total_size = len(full_dataset)
train_size = int(0.72 * total_size)
val_size = int(0.18 * total_size)
test_size = total_size - train_size - val_size

train_split, val_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# All subsets returned by random_split reference the SAME underlying dataset
# object, so setting `.dataset.transform` once per split only keeps the last
# assignment and training would silently run without augmentation.
# Instead: the shared dataset carries the evaluation transform (used by the
# val and test subsets), and the training subset is re-pointed at a shallow
# copy that carries the training transform. The copy shares the sample list,
# so the directory is not scanned a second time.
full_dataset.transform = transform_val
train_view = copy.copy(full_dataset)
train_view.transform = transform_train
train_dataset = torch.utils.data.Subset(train_view, train_split.indices)

# Create dataloaders (only the training loader shuffles)
loader_kwargs = dict(
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    pin_memory=CONFIG['pin_memory'],
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"\nüìä Dataset split:")
print(f"  Train: {len(train_dataset):,} ({len(train_dataset)/total_size*100:.1f}%)")
print(f"  Val:   {len(val_dataset):,} ({len(val_dataset)/total_size*100:.1f}%)")
print(f"  Test:  {len(test_dataset):,} ({len(test_dataset)/total_size*100:.1f}%)")

In [None]:
# Build the network from the CONFIG hyper-parameters and move it to `device`
model = Super_Mamba(
    dims=CONFIG['dims'],
    depth=CONFIG['depth'],
    num_classes=CONFIG['num_classes'],
)
model = model.to(device)

# Count parameters: all of them, and the subset that receives gradients
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements

print(f"‚úì Model: {CONFIG['model_name']}")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

In [None]:
# Smoke test: push a random batch of two RGB images of the configured size
# through the model once, in eval mode and without tracking gradients.
print("Testing forward pass...")
model.eval()

side = CONFIG['image_size']
test_input = torch.randn(2, 3, side, side).to(device)
with torch.no_grad():
    test_output = model(test_input)

print(f"‚úì Input shape: {test_input.shape}")
print(f"‚úì Output shape: {test_output.shape}")
print(f"‚úì Model is ready for training!")

## 🚀 Training

**Note:** Training sẽ mất ~2-4 giờ trên Colab T4 GPU

In [None]:
# Training setup: loss, optimizer, learning-rate schedule and the bookkeeping
# variables used by the training loop.
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=CONFIG['learning_rate'],
                  weight_decay=CONFIG['weight_decay'])

if CONFIG['scheduler'] == 'cosine':
    # Anneal from the initial LR down to min_lr over the full epoch budget
    scheduler = CosineAnnealingLR(optimizer, T_max=CONFIG['num_epochs'],
                                  eta_min=CONFIG['min_lr'])
else:
    # Any other value selects plateau-based decay driven by validation loss.
    # `verbose` is deprecated on PyTorch LR schedulers, so it is not passed;
    # the training loop prints the learning rate every epoch anyway.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                  patience=5)

# Training history (one entry appended per finished epoch)
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'lr': []
}

best_val_acc = 0.0      # best validation accuracy seen so far (percent)
patience_counter = 0    # epochs since the last improvement

print("Starting training...")
print(f"Total epochs: {CONFIG['num_epochs']}")
print(f"Batch size: {CONFIG['batch_size']}")

In [None]:
# Training loop.
# Per epoch: one optimisation pass over train_loader, one evaluation pass over
# val_loader, a scheduler step, history/logging, a checkpoint whenever the
# validation accuracy improves, and early stopping after CONFIG['patience']
# epochs without improvement.
for epoch in range(CONFIG['num_epochs']):
    # Train
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['num_epochs']}")
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight by batch size so the epoch
        # average stays exact when the last batch is smaller
        train_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        train_total += labels.size(0)
        train_correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'acc': f"{100.*train_correct/train_total:.2f}%"
        })

    train_loss /= len(train_dataset)
    train_acc = 100. * train_correct / train_total

    # Validation
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

    val_loss /= len(val_dataset)
    val_acc = 100. * val_correct / val_total

    # Read the learning rate BEFORE stepping the scheduler so that the value
    # logged for this epoch is the one the optimizer actually used.
    current_lr = optimizer.param_groups[0]['lr']

    # Update scheduler (the plateau scheduler needs the monitored metric)
    if CONFIG['scheduler'] == 'cosine':
        scheduler.step()
    else:
        scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['lr'].append(current_lr)

    # Print epoch results
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
    print(f"  LR: {current_lr:.6f}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Store paths as plain strings: pickled pathlib objects are tied to
        # the OS flavour they were created on and are rejected by restricted
        # (weights_only) loading, which would make a checkpoint written on
        # Colab hard to open on another machine.
        portable_config = {
            key: str(value) if isinstance(value, Path) else value
            for key, value in CONFIG.items()
        }
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'config': portable_config
        }, CONFIG['save_dir'] / 'super_mamba_best.pth')
        print(f"  ‚úì Saved best model (Val Acc: {val_acc:.2f}%)")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= CONFIG['patience']:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break

print("\n‚úì Training completed!")
print(f"Best validation accuracy: {best_val_acc:.2f}%")

## 📈 Results & Evaluation

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training & Validation Loss')
axes[0].legend()
axes[0].grid(True)

# Accuracy
axes[1].plot(history['train_acc'], label='Train Acc', marker='o')
axes[1].plot(history['val_acc'], label='Val Acc', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training & Validation Accuracy')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.savefig(CONFIG['save_dir'] / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"‚úì Training curves saved to {CONFIG['save_dir'] / 'training_curves.png'}")

In [None]:
# Load best model and evaluate on test set.
# Restores the checkpoint written by the training loop, predicts every test
# batch, then prints overall accuracy and a per-class report.
#
# map_location=device lets a checkpoint saved on a GPU be restored on a
# CPU-only machine. weights_only=False is needed because the checkpoint also
# carries non-tensor entries (the config); that is acceptable here because
# the file is produced by this notebook. Do not point it at untrusted files.
checkpoint = torch.load(CONFIG['save_dir'] / 'super_mamba_best.pth',
                        map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

test_preds = []
test_labels = []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing"):
        images = images.to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)

        # Predictions come back from the device; labels never left the CPU
        test_preds.extend(predicted.cpu().numpy())
        test_labels.extend(labels.numpy())

test_acc = accuracy_score(test_labels, test_preds)
print(f"\nüéØ Test Accuracy: {test_acc*100:.2f}%")

# Classification report.
# Passing `labels` explicitly keeps the report aligned with `class_names`
# even when some class never occurs in the test split or the predictions
# (without it scikit-learn raises on the length mismatch); zero_division=0
# reports such classes as 0 instead of emitting warnings.
print("\n" + "="*80)
print("Classification Report:")
print("="*80)
print(classification_report(test_labels, test_preds,
                            labels=list(range(len(class_names))),
                            target_names=class_names,
                            zero_division=0))

## 💾 Download Model

Để download model về máy local:

In [None]:
# On Colab, hand the best checkpoint and the training-curve image to the
# browser as downloads; outside Colab this cell does nothing.
if IN_COLAB:
    from google.colab import files

    model_path = str(CONFIG['save_dir'] / 'super_mamba_best.pth')   # best model
    curve_path = str(CONFIG['save_dir'] / 'training_curves.png')    # training curves

    for artifact_path in (model_path, curve_path):
        files.download(artifact_path)
        print(f"‚úì Downloaded: {artifact_path}")