# ⚠️ IMPORTANT: Setup Requirements

## 🛠️ Prerequisites Before Running This Notebook

### ✅ Already Installed:
- ✅ PyTorch 2.6.0 with CUDA 12.4
- ✅ Virtual environment (`.venv`)
- ✅ MambaTSR repository cloned

### ⏳ Still Need to Install:

#### 1. Microsoft Visual C++ Build Tools (Required!)
**Why needed:** To compile the CUDA kernels for the selective scan operation.

**Download:** https://visualstudio.microsoft.com/visual-cpp-build-tools/

**Install workloads:**
- ✅ Desktop development with C++
- ✅ MSVC v143 - VS 2022 C++ x64/x86 build tools
- ✅ Windows 10 SDK

**Time needed:** ~30-40 minutes (download + install + compile)

#### 2. Compile selective_scan kernel
After installing Build Tools, run in terminal:

```powershell
cd G:\Dataset\MambaTSR\kernels\selective_scan
pip install --no-build-isolation -e .
```

---

## 📋 Full Setup Guide

**See:** `G:\Dataset\MAMBATSR_SETUP_GUIDE.md` for detailed instructions

---

## ⚡ Quick Start (After Setup)

Once Build Tools is installed and selective_scan is compiled:
1. Restart VS Code
2. Select kernel: `.venv` (Python 3.11)
3. Run all cells below
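
Before restarting, it can help to confirm that the prerequisites listed above are actually visible to Python. The following is a minimal sketch and not part of the original setup guide: the helper names `missing_modules` and `msvc_on_path` are hypothetical, and the module name `selective_scan_cuda_core` is only an assumption about what the kernel build installs, so replace it with whatever name the MambaTSR build produces.

```python
import importlib.util
import shutil


def missing_modules(names):
    """Return the subset of `names` that the import system cannot locate."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def msvc_on_path():
    """Return True if the MSVC compiler driver (cl.exe) is visible on PATH."""
    return shutil.which("cl") is not None


# Example: check a few packages this notebook imports.
# NOTE: "selective_scan_cuda_core" is an assumed module name, for illustration only.
required = ["torch", "timm", "einops", "selective_scan_cuda_core"]
print("Missing modules:", missing_modules(required))
print("MSVC compiler on PATH:", msvc_on_path())
```

`cl.exe` is usually only on PATH inside a Visual Studio developer prompt, so a `False` result in an ordinary terminal does not by itself mean the Build Tools are missing.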

---

**Current Status:** ⏳ Waiting for Build Tools installation

# MambaTSR for PlantVillage Disease Classification
## Applying the MambaTSR architecture (State Space Model) to plant disease recognition

**Model**: Super_Mamba with Vision State Space (VSS) Blocks  
**Dataset**: PlantVillage (39 classes)  
**Paper**: "MambaTSR: You Only Need 90k Parameters for Traffic Sign Recognition" (Neurocomputing, JCR Q1)  

### Main architecture:
- **ConvNet**: Initial embedding from the RGB image
- **PatchMerging2D + VSSBlock**: 6 stages (depth=6)
- **SS2D (Selective Scan 2D)**: The heart of the Mamba architecture
- **Classifier**: LayerNorm → AvgPool → Linear(num_classes=39)

## 1. Setup Environment & Dependencies

In [1]:
# Import standard libraries
import os
import sys
import time
import math
import copy
import random
import logging
from functools import partial
from typing import Optional, Callable, Any
from collections import OrderedDict
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Set environment variables for RTX 5060 Ti (sm_120) compatibility with PyTorch 2.6.0
# Force CUDA to use PTX JIT compilation for unsupported compute capability
os.environ['CUDA_FORCE_PTX_JIT'] = '1'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Import PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau

# Import torchvision
from torchvision import datasets, transforms
from torchvision.ops import Permute

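# Import timm (newer releases expose these helpers under timm.layers;
# fall back to the legacy timm.models.layers path on older installs)
import timm
try:
    from timm.layers import DropPath, trunc_normal_
except ImportError:
    from timm.models.layers import DropPath, trunc_normal_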

# Import other utilities
from einops import rearrange, repeat
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA version: 12.4
GPU: NVIDIA GeForce RTX 5060 Ti
GPU Memory: 15.93 GB


NVIDIA GeForce RTX 5060 Ti with CUDA capability sm_120 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 sm_90.
If you want to use the NVIDIA GeForce RTX 5060 Ti GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/
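
The warning above can be reproduced as a small check. This is a sketch under a simplifying assumption: a kernel binary built for `sm_XY` is treated as runnable only on GPUs with the same major version `X` and a minor version of at least `Y`, and PTX-only entries are ignored. The helper names `parse_sm` and `is_binary_compatible` are hypothetical and are not PyTorch APIs.

```python
def parse_sm(arch):
    """Turn an arch string such as 'sm_86' into a (major, minor) tuple."""
    digits = "".join(ch for ch in arch.split("_")[1] if ch.isdigit())
    return int(digits[:-1]), int(digits[-1])


def is_binary_compatible(capability, arch_list):
    """Simplified rule: same major version, and built minor <= device minor."""
    major, minor = capability
    for arch in arch_list:
        if not arch.startswith("sm_"):
            continue  # skip PTX-only entries such as 'compute_90'
        built_major, built_minor = parse_sm(arch)
        if built_major == major and built_minor <= minor:
            return True
    return False


# Architectures listed in the warning above, and the RTX 5060 Ti capability (12, 0).
supported = ["sm_50", "sm_60", "sm_61", "sm_70", "sm_75", "sm_80", "sm_86", "sm_90"]
print(is_binary_compatible((12, 0), supported))  # False
print(is_binary_compatible((8, 6), supported))   # True
```

In a live session the two inputs would come from `torch.cuda.get_device_capability(0)` and `torch.cuda.get_arch_list()`.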



## 2. Add MambaTSR to Python Path

In [2]:
# Add MambaTSR models to path
mamba_path = Path(r'G:\Dataset\MambaTSR')
if str(mamba_path) not in sys.path:
    sys.path.insert(0, str(mamba_path))
    print(f"✓ Added {mamba_path} to sys.path")

# Verify path
print("\nPython search paths:")
for p in sys.path[:3]:
    print(f"  - {p}")

✓ Added G:\Dataset\MambaTSR to sys.path

Python search paths:
  - G:\Dataset\MambaTSR
  - C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.11_3.11.2544.0_x64__qbz5n2kfra8p0\python311.zip
  - C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.11_3.11.2544.0_x64__qbz5n2kfra8p0\DLLs


## 3. Import MambaTSR Components

**Important**: This is the part the instructor requested - copy the components from MambaTSR

In [3]:
# Import MambaTSR core components
try:
    from models.ConvNet import ConvNet
    from models.VSSBlock import VSSBlock
    from models.vmamba import SS2D, Mlp
    print("✓ Successfully imported MambaTSR components")
except ImportError as e:
    print(f"❌ Error importing MambaTSR components: {e}")
    print("\nMake sure the selective_scan CUDA kernel is installed:")
    print("cd G:\\Dataset\\MambaTSR\\kernels\\selective_scan")
    print("pip install --no-build-isolation -e .")
    raise

✓ Successfully imported MambaTSR components




## 4. Define PatchMerging2D

Component from VSSBlock_utils.py - reduces the spatial dimensions and increases the channels

In [4]:
class PatchMerging2D(nn.Module):
    """Patch Merging Layer - halves H and W and increases the channel count"""
    def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False)
        self.norm = norm_layer(4 * dim)

    @staticmethod
    def _patch_merging_pad(x: torch.Tensor):
        H, W, _ = x.shape[-3:]
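        # F.pad lists pads from the last dim backwards: (C_lo, C_hi, W_lo, W_hi, H_lo, H_hi),
        # so this adds one trailing column/row only when W or H is odd.
        if (W % 2 != 0) or (H % 2 != 0):
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))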
        x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
        x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
        x2 = x[..., 0::2, 1::2, :]  # ... H/2 W/2 C
        x3 = x[..., 1::2, 1::2, :]  # ... H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # ... H/2 W/2 4*C
        return x

    def forward(self, x):
        x = self._patch_merging_pad(x)
        x = self.norm(x)
        x = self.reduction(x)
        return x

print("✓ PatchMerging2D defined")

✓ PatchMerging2D defined
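
The shape bookkeeping of one patch-merging step can be traced without any tensors. This is a sketch that mirrors the padding, the 4x channel concatenation, and the linear projection in the class above. The function name `patch_merge_shape` is hypothetical.

```python
def patch_merge_shape(h, w, c, out_dim=-1):
    """Return (shape after concatenation, shape after projection) for a (H, W, C) map."""
    # Odd sides are padded by one so that they split evenly into 2x2 patches.
    h_padded, w_padded = h + h % 2, w + w % 2
    # Four neighbouring pixels are concatenated on the channel axis (4*C),
    # then a linear layer projects to 2*C unless out_dim is given.
    merged_channels = 4 * c
    projected = 2 * c if out_dim < 0 else out_dim
    new_h, new_w = h_padded // 2, w_padded // 2
    return (new_h, new_w, merged_channels), (new_h, new_w, projected)


print(patch_merge_shape(32, 32, 3))    # ((16, 16, 12), (16, 16, 6))
print(patch_merge_shape(5, 7, 8, 10))  # ((3, 4, 32), (3, 4, 10))
```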


## 5. Define Super_Mamba Model

**This is line 59, as the instructor requested** - the Super_Mamba class from VSSBlock_utils.py  
**Modification**: `num_classes=39` for PlantVillage (instead of 43 for traffic signs)

In [5]:
class PermuteLayer(nn.Module):
    """Permute layer for dimension reordering"""
    def __init__(self, *args):
        super().__init__()
        self.args = args

    def forward(self, x: torch.Tensor):
        return x.permute(*self.args)


class Super_Mamba(nn.Module):
    """Super_Mamba model for PlantVillage classification
    
    Architecture:
    - ConvNet: Initial feature embedding
    - 6 stages: PatchMerging2D + VSSBlock (Selective Scan 2D)
    - Classifier: LayerNorm → Permute → AvgPool → Linear
    
    Args:
        dims: Initial number of channels (default=3 for RGB)
        depth: Number of VSSBlock stages (default=6)
        num_classes: Number of output classes (39 for PlantVillage)
    """
    def __init__(self, dims=3, depth=6, num_classes=39):
        super().__init__()
        self.depth = depth
        self.preembd = ConvNet()  # Embedding layer
        
        # Calculate dimensions for each layer
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.depth+1)]
        self.num_features = dims[-1]
        self.dims = dims
        
        # Build layers: PatchMerging + VSSBlock repeated depth times
        self.layers = nn.ModuleList()
        for i_layer in range(self.depth):
            downsample = PatchMerging2D(
                self.dims[i_layer],
                self.dims[i_layer + 1],
                norm_layer=nn.LayerNorm,
            )
            vss_block = VSSBlock(hidden_dim=self.dims[i_layer+1])
            self.layers.append(downsample)
            self.layers.append(vss_block)

        # Classifier head
        self.classifier = nn.Sequential(OrderedDict(
            norm=nn.LayerNorm(self.num_features),  # B,H,W,C
            permute=PermuteLayer(0, 3, 1, 2),
            avgpool=nn.AdaptiveAvgPool2d(1),
            flatten=nn.Flatten(1),
            head=nn.Linear(self.num_features, num_classes),
        ))

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """Initialize weights"""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.preembd(x)  # ConvNet embedding
        x = x.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
        for layers in self.layers:
            x = layers(x)
        x = self.classifier(x)
        return x

print("✓ Super_Mamba model defined with num_classes=39 for PlantVillage")

✓ Super_Mamba model defined with num_classes=39 for PlantVillage
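
The channel widths follow directly from the `dims * 2 ** i_layer` rule in the constructor, and the spatial side halves at every stage. The sketch below traces both. The function names are hypothetical, and the starting side of 32 assumes that `ConvNet` preserves the 32x32 input resolution, which is not confirmed by the code shown here.

```python
def channel_schedule(dims=3, depth=6):
    """Channel widths used by Super_Mamba: dims * 2**i for i = 0..depth."""
    return [int(dims * 2 ** i) for i in range(depth + 1)]


def spatial_schedule(size, depth=6):
    """Side length after each patch-merging stage (odd sides are padded up by one)."""
    sizes = [size]
    for _ in range(depth):
        size = (size + size % 2) // 2
        sizes.append(size)
    return sizes


print(channel_schedule())    # [3, 6, 12, 24, 48, 96, 192]
# ASSUMPTION: ConvNet keeps the 32x32 resolution; its definition is not shown here.
print(spatial_schedule(32))  # [32, 16, 8, 4, 2, 1, 1]
```

The last channel width, 192, is the `num_features` value that feeds the final `Linear` head.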


## 6. Configuration & Hyperparameters

In [6]:
# Device configuration
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✓ Using GPU: {torch.cuda.get_device_name(0)}")
    print(f"  - VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print(f"  - Compute Capability: {'.'.join(map(str, torch.cuda.get_device_capability(0)))}")
else:
    device = torch.device('cpu')
    print("⚠️  CUDA not available, using CPU")
print(f"Using device: {device}")

# Hyperparameters
CONFIG = {
    # Model
    'model_name': 'Super_Mamba',
    'dims': 3,  # Initial channels
    'depth': 6,  # Number of VSSBlock stages
    'num_classes': 39,  # PlantVillage classes
    
    # Training
    'batch_size': 64 if torch.cuda.is_available() else 32,  # 64 for GPU, 32 for CPU
    'num_epochs': 100,
    'learning_rate': 1e-3,
    'weight_decay': 1e-4,
    'num_workers': 0,  # Windows compatibility
    'pin_memory': torch.cuda.is_available(),  # True for GPU, False for CPU
    
    # Data augmentation
    'image_size': 32,  # MambaTSR uses 32x32
    'brightness': 0.8,
    'contrast': (1.0, 1.0),  # (min, max) factor range; (1.0, 1.0) leaves contrast unchanged
    
    # Learning rate scheduler
    'scheduler': 'cosine',  # 'cosine' or 'plateau'
    'min_lr': 1e-6,
    
    # Early stopping
    'patience': 15,
    
    # Paths
    'data_root': Path(r'G:\Dataset\Data\PlantVillage\PlantVillage-Dataset-master'),
    'save_dir': Path(r'G:\Dataset\models\MambaTSR'),
}

# Create save directory
CONFIG['save_dir'].mkdir(parents=True, exist_ok=True)
print(f"\n✓ Models will be saved to: {CONFIG['save_dir']}")
print(f"\n📋 Configuration:")
for key, value in CONFIG.items():
    if not isinstance(value, Path):
        print(f"  {key}: {value}")

✓ Using GPU: NVIDIA GeForce RTX 5060 Ti
  - VRAM: 15.9 GB
  - Compute Capability: 12.0
Using device: cuda

✓ Models will be saved to: G:\Dataset\models\MambaTSR

📋 Configuration:
  model_name: Super_Mamba
  dims: 3
  depth: 6
  num_classes: 39
  batch_size: 64
  num_epochs: 100
  learning_rate: 0.001
  weight_decay: 0.0001
  num_workers: 0
  pin_memory: True
  image_size: 32
  brightness: 0.8
  contrast: (1.0, 1.0)
  scheduler: cosine
  min_lr: 1e-06
  patience: 15


## 7. Data Preparation

In [7]:
# Data transforms - following the MambaTSR structure
transform_train = transforms.Compose([
    transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
    transforms.ColorJitter(brightness=CONFIG['brightness'], contrast=CONFIG['contrast']),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

transform_val = transforms.Compose([
    transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

print("✓ Data transforms defined")
print(f"  - Train: ColorJitter + Flip + Rotation + Normalize")
print(f"  - Val/Test: Resize + Normalize")

✓ Data transforms defined
  - Train: ColorJitter + Flip + Rotation + Normalize
  - Val/Test: Resize + Normalize
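
`Normalize` computes `(x - mean) / std` per channel, so the value range of a normalized batch is fixed by the constants above (the standard ImageNet statistics). The sketch below works out those bounds for inputs in `[0, 1]`, which is what `ToTensor` produces. The function name `normalized_bounds` is hypothetical.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalized_bounds(mean, std):
    """Per-channel (min, max) after Normalize, for inputs in [0, 1]."""
    return [((0.0 - m) / s, (1.0 - m) / s) for m, s in zip(mean, std)]


bounds = normalized_bounds(MEAN, STD)
lowest = min(lo for lo, _ in bounds)
highest = max(hi for _, hi in bounds)
print(f"{lowest:.3f} {highest:.3f}")  # -2.118 2.640
```

The image range printed by the DataLoader smoke test later in the notebook should fall inside these bounds.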


In [8]:
# Load datasets
print("Loading PlantVillage dataset...")

full_dataset = datasets.ImageFolder(root=CONFIG['data_root'])
class_names = full_dataset.classes
num_classes = len(class_names)

print(f"✓ Found {num_classes} classes: {class_names[:5]}...")
print(f"✓ Total images: {len(full_dataset):,}")

# Split dataset: 72% train, 18% val, 10% test
total_size = len(full_dataset)
train_size = int(0.72 * total_size)
val_size = int(0.18 * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

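# Apply transforms. All three subsets returned by random_split share the same
# underlying ImageFolder, so assigning `.dataset.transform` three times would
# leave every split (including train) with transform_val. Instead, re-wrap the
# split indices around separate ImageFolder views, one per transform.
train_view = datasets.ImageFolder(root=CONFIG['data_root'], transform=transform_train)
eval_view = datasets.ImageFolder(root=CONFIG['data_root'], transform=transform_val)
train_dataset = torch.utils.data.Subset(train_view, train_dataset.indices)
val_dataset = torch.utils.data.Subset(eval_view, val_dataset.indices)
test_dataset = torch.utils.data.Subset(eval_view, test_dataset.indices)

print(f"\n📊 Dataset split:")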
print(f"  - Train: {len(train_dataset):,} images ({len(train_dataset)/total_size*100:.1f}%)")
print(f"  - Val:   {len(val_dataset):,} images ({len(val_dataset)/total_size*100:.1f}%)")
print(f"  - Test:  {len(test_dataset):,} images ({len(test_dataset)/total_size*100:.1f}%)")

Loading PlantVillage dataset...
✓ Found 39 classes: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy']...
✓ Total images: 54,305

📊 Dataset split:
  - Train: 39,099 images (72.0%)
  - Val:   9,774 images (18.0%)
  - Test:  5,432 images (10.0%)


In [9]:
# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    pin_memory=CONFIG['pin_memory']
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=CONFIG['pin_memory']
)

test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=CONFIG['pin_memory']
)

print("\n✓ DataLoaders created")
print(f"  - Train batches: {len(train_loader)}")
print(f"  - Val batches: {len(val_loader)}")
print(f"  - Test batches: {len(test_loader)}")


✓ DataLoaders created
  - Train batches: 611
  - Val batches: 153
  - Test batches: 85
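
The split sizes and batch counts reported above follow from simple arithmetic, which is a useful sanity check when the dataset or batch size changes. The sketch below redoes that arithmetic for the 54,305 images and batch size 64 used here. The function names are hypothetical.

```python
import math


def split_sizes(total, train_frac=0.72, val_frac=0.18):
    """Train/val sizes are truncated with int(); the test split takes the remainder."""
    train = int(train_frac * total)
    val = int(val_frac * total)
    return train, val, total - train - val


def num_batches(num_samples, batch_size):
    """Batches per epoch for a DataLoader with drop_last=False (the default)."""
    return math.ceil(num_samples / batch_size)


sizes = split_sizes(54305)
print(sizes)                                # (39099, 9774, 5432)
print([num_batches(n, 64) for n in sizes])  # [611, 153, 85]
```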


## 8. DataLoader Smoke Test

In [10]:
# Test DataLoader
print("Testing DataLoader...")
test_batch = next(iter(train_loader))
images, labels = test_batch

print(f"✓ Batch shape: {images.shape}")
print(f"✓ Labels shape: {labels.shape}")
print(f"✓ Image range: [{images.min():.3f}, {images.max():.3f}]")
print(f"✓ Device: {images.device}")

if torch.cuda.is_available():
    images_gpu = images.to(device)
    print(f"✓ Successfully moved batch to GPU: {images_gpu.device}")

Testing DataLoader...
✓ Batch shape: torch.Size([64, 3, 32, 32])
✓ Labels shape: torch.Size([64])
✓ Image range: [-1.964, 2.448]
✓ Device: cpu
✓ Successfully moved batch to GPU: cuda:0


## 9. Initialize Model

In [11]:
# Create Super_Mamba model
print("Initializing Super_Mamba model...")
model = Super_Mamba(
    dims=CONFIG['dims'],
    depth=CONFIG['depth'],
    num_classes=CONFIG['num_classes']
).to(device)

print(f"✓ Model created and moved to {device}")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n📊 Model Statistics:")
print(f"  - Total parameters: {total_params:,}")
print(f"  - Trainable parameters: {trainable_params:,}")
print(f"  - Model size: {total_params * 4 / 1024**2:.2f} MB (float32)")

Initializing Super_Mamba model...
✓ Model created and moved to cuda

📊 Model Statistics:
  - Total parameters: 1,333,636
  - Trainable parameters: 1,333,636
  - Model size: 5.09 MB (float32)


In [12]:
# Test forward pass
print("\nTesting forward pass...")
model.eval()
with torch.no_grad():
    test_input = torch.randn(2, 3, CONFIG['image_size'], CONFIG['image_size']).to(device)
    test_output = model(test_input)
    print(f"✓ Input shape: {test_input.shape}")
    print(f"✓ Output shape: {test_output.shape}")
    print(f"✓ Output range: [{test_output.min():.3f}, {test_output.max():.3f}]")

model.train()
print("✓ Forward pass successful!")


Testing forward pass...


RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


## 10. Training Setup

In [None]:
# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer - following the MambaTSR paper
optimizer = AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay']
)

# Learning rate scheduler
if CONFIG['scheduler'] == 'cosine':
    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=CONFIG['num_epochs'],
        eta_min=CONFIG['min_lr']
    )
    print("✓ Using CosineAnnealingLR scheduler")
else:
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=5,
        min_lr=CONFIG['min_lr']
    )
    print("✓ Using ReduceLROnPlateau scheduler")

print(f"✓ Optimizer: AdamW (lr={CONFIG['learning_rate']}, weight_decay={CONFIG['weight_decay']})")
print(f"✓ Loss function: CrossEntropyLoss")
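
To see what the cosine schedule does to the learning rate over the run, the documented closed form of cosine annealing can be evaluated directly. The sketch below uses the `learning_rate`, `min_lr`, and `num_epochs` values from `CONFIG`. The function name `cosine_lr` is hypothetical, and the scheduler itself updates the rate step by step rather than calling a function like this.

```python
import math


def cosine_lr(epoch, base_lr=1e-3, min_lr=1e-6, t_max=100):
    """Closed-form cosine annealing: the rate after `epoch` scheduler steps."""
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * epoch / t_max))


# Print the rate at a few points of the 100-epoch schedule.
for epoch in (0, 25, 50, 75, 100):
    print(epoch, f"{cosine_lr(epoch):.3e}")
```

The rate starts at `base_lr`, passes the midpoint between the two limits at epoch 50, and reaches `min_lr` at epoch 100.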

## 11. Training Functions

In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]")
    for batch_idx, (images, labels) in enumerate(pbar):
        images, labels = images.to(device), labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        # Update progress bar
        pbar.set_postfix({
            'loss': f"{running_loss/(batch_idx+1):.4f}",
            'acc': f"{100.*correct/total:.2f}%"
        })
    
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc


def validate(model, val_loader, criterion, device, epoch, phase='Val'):
    """Validate model"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc=f"Epoch {epoch+1} [{phase}]")
        for batch_idx, (images, labels) in enumerate(pbar):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            pbar.set_postfix({
                'loss': f"{running_loss/(batch_idx+1):.4f}",
                'acc': f"{100.*correct/total:.2f}%"
            })
    
    epoch_loss = running_loss / len(val_loader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

print("✓ Training functions defined")
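
Both functions above report `running_loss / len(loader)`, which averages the per-batch means and therefore gives a smaller final batch the same weight as a full one. The sketch below contrasts that with a per-sample average. The function names and the numbers are illustrative only.

```python
def batch_mean(batch_losses):
    """Average of per-batch mean losses, as in running_loss / len(loader)."""
    return sum(batch_losses) / len(batch_losses)


def sample_weighted_mean(batch_losses, batch_sizes):
    """Average loss per sample: each batch mean is weighted by its batch size."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


losses = [1.0, 3.0]  # per-batch mean losses (illustrative numbers)
sizes = [64, 16]     # the last batch is smaller
print(batch_mean(losses))                   # 2.0
print(sample_weighted_mean(losses, sizes))  # 1.4
```

With 39,099 training images and a batch size of 64, only the last batch is short (59 images), so the two averages should differ very little in this notebook.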

## 12. Training Loop

In [None]:
# Training history
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'lr': []
}

# Best model tracking
best_val_acc = 0.0
best_epoch = 0
patience_counter = 0

print("\n" + "="*80)
print(f"🚀 Starting Training: Super_Mamba on PlantVillage")
print("="*80)
print(f"Epochs: {CONFIG['num_epochs']} | Batch size: {CONFIG['batch_size']} | LR: {CONFIG['learning_rate']}")
print(f"Device: {device} | Early stopping patience: {CONFIG['patience']}")
print("="*80 + "\n")

start_time = time.time()

for epoch in range(CONFIG['num_epochs']):
    # Train
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device, epoch
    )
    
    # Validate
    val_loss, val_acc = validate(
        model, val_loader, criterion, device, epoch, phase='Val'
    )
    
    # Update learning rate
    if CONFIG['scheduler'] == 'cosine':
        scheduler.step()
    else:
        scheduler.step(val_acc)
    
    current_lr = optimizer.param_groups[0]['lr']
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['lr'].append(current_lr)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']} Summary:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
    print(f"  LR: {current_lr:.6f}")
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        patience_counter = 0
        
        # Save checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'config': CONFIG
        }
        
        torch.save(checkpoint, CONFIG['save_dir'] / 'super_mamba_best.pth')
        print(f"  ✓ New best model saved! Val Acc: {val_acc:.2f}%")
    else:
        patience_counter += 1
        print(f"  No improvement ({patience_counter}/{CONFIG['patience']})")
    
    # Early stopping
    if patience_counter >= CONFIG['patience']:
        print(f"\n⚠️ Early stopping triggered after {epoch+1} epochs")
        print(f"Best Val Acc: {best_val_acc:.2f}% at epoch {best_epoch}")
        break
    
    print("-" * 80)

# Training complete
elapsed_time = time.time() - start_time
print("\n" + "="*80)
print("✓ Training Complete!")
print("="*80)
print(f"Total time: {elapsed_time/3600:.2f} hours")
print(f"Best Val Acc: {best_val_acc:.2f}% at epoch {best_epoch}")
print(f"Model saved to: {CONFIG['save_dir'] / 'super_mamba_best.pth'}")
print("="*80)
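
The best-model tracking and patience counter in the loop above can be isolated into a small helper, which makes the stopping rule easy to test without training anything. This is a sketch of the same rule: a strict improvement resets the counter, anything else increments it. The class name `EarlyStopper` is hypothetical.

```python
class EarlyStopper:
    """Tracks the best validation accuracy and counts epochs without improvement."""

    def __init__(self, patience):
        self.patience = patience
        self.best = 0.0
        self.best_epoch = 0
        self.counter = 0

    def update(self, val_acc, epoch):
        """Record one epoch; return True when training should stop."""
        if val_acc > self.best:
            self.best, self.best_epoch, self.counter = val_acc, epoch, 0
        else:
            self.counter += 1
        return self.counter >= self.patience


stopper = EarlyStopper(patience=2)
for epoch, acc in enumerate([70.0, 72.5, 72.0, 72.5, 80.0], start=1):
    if stopper.update(acc, epoch):
        print(f"stop at epoch {epoch}, best {stopper.best} at epoch {stopper.best_epoch}")
        break
```

Note that a tie with the best accuracy counts as no improvement, which matches the `val_acc > best_val_acc` comparison in the loop.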

## 13. Save Final Model

In [None]:
# Save final model
final_checkpoint = {
    'epoch': epoch + 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'val_acc': val_acc,
    'val_loss': val_loss,
    'history': history,
    'config': CONFIG
}

torch.save(final_checkpoint, CONFIG['save_dir'] / 'super_mamba_final.pth')
print(f"✓ Final model saved to: {CONFIG['save_dir'] / 'super_mamba_final.pth'}")

## 14. Plot Training History

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss
axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training & Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(history['train_acc'], label='Train Acc', linewidth=2)
axes[1].plot(history['val_acc'], label='Val Acc', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training & Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Learning Rate
axes[2].plot(history['lr'], linewidth=2, color='green')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Learning Rate')
axes[2].set_title('Learning Rate Schedule')
axes[2].set_yscale('log')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(CONFIG['save_dir'] / 'training_history.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Training curves saved to: {CONFIG['save_dir'] / 'training_history.png'}")

## 15. Load Best Model & Test

In [None]:
# Load best model
print("Loading best model for testing...")
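# PyTorch 2.6 defaults torch.load to weights_only=True, which rejects the
# pathlib.Path objects stored in CONFIG; this file is self-written and trusted.
checkpoint = torch.load(CONFIG['save_dir'] / 'super_mamba_best.pth',
                        map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✓ Loaded best model from epoch {checkpoint['epoch']}")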
print(f"  Val Acc: {checkpoint['val_acc']:.2f}%")

In [None]:
# Test evaluation
print("\n" + "="*80)
print("🧪 Testing on Test Set")
print("="*80)

model.eval()
test_loss = 0.0
correct = 0
total = 0
all_preds = []
all_labels = []

with torch.no_grad():
    pbar = tqdm(test_loader, desc="Testing")
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        test_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        
        pbar.set_postfix({
            'acc': f"{100.*correct/total:.2f}%"
        })

test_loss /= len(test_loader)
test_acc = 100. * correct / total

print("\n" + "="*80)
print("📊 Test Results:")
print("="*80)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")
print(f"Correct: {correct:,} / {total:,}")
print("="*80)

## 16. Detailed Classification Report

In [None]:
# Classification report
from sklearn.metrics import classification_report

print("\n📋 Classification Report:\n")
print(classification_report(
    all_labels, 
    all_preds, 
    target_names=class_names,
    digits=4
))

## 17. Confusion Matrix

In [None]:
# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)

plt.figure(figsize=(20, 18))
sns.heatmap(
    cm, 
    annot=True, 
    fmt='d', 
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    cbar_kws={'label': 'Count'}
)
plt.title('Confusion Matrix - Super_Mamba on PlantVillage', fontsize=16, pad=20)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(CONFIG['save_dir'] / 'confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Confusion matrix saved to: {CONFIG['save_dir'] / 'confusion_matrix.png'}")
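
For reading the heatmap, it helps to remember the layout: rows are true labels, columns are predicted labels, and per-class recall is the diagonal entry divided by its row total. The sketch below rebuilds that on a toy three-class example in plain Python. The function names and labels are illustrative and are not taken from the PlantVillage results.

```python
def confusion_counts(y_true, y_pred, num_classes):
    """Rows are true labels, columns are predicted labels."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def per_class_recall(matrix):
    """Diagonal count divided by the row total, or None for classes with no samples."""
    return [row[i] / sum(row) if sum(row) else None for i, row in enumerate(matrix)]


y_true = [0, 0, 1, 1, 1, 2]
y_pred = [0, 1, 1, 1, 0, 2]
toy_cm = confusion_counts(y_true, y_pred, num_classes=3)
print(toy_cm)                    # [[1, 1, 0], [1, 2, 0], [0, 0, 1]]
print(per_class_recall(toy_cm))  # [0.5, 0.6666666666666666, 1.0]
```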

## 18. Save Results Summary

In [None]:
# Save results to JSON
import json

results = {
    'model_name': 'Super_Mamba',
    'dataset': 'PlantVillage',
    'num_classes': CONFIG['num_classes'],
    'total_params': total_params,
    'best_epoch': best_epoch,
    'best_val_acc': best_val_acc,
    'test_acc': test_acc,
    'test_loss': test_loss,
    'training_time_hours': elapsed_time / 3600,
    'config': {k: str(v) if isinstance(v, Path) else v for k, v in CONFIG.items()}
}

with open(CONFIG['save_dir'] / 'results_summary.json', 'w') as f:
    json.dump(results, f, indent=4)

print(f"✓ Results saved to: {CONFIG['save_dir'] / 'results_summary.json'}")
print("\n✅ All done! MambaTSR training completed successfully.")
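
The cell above converts `Path` values to strings by hand because `json` cannot encode them. An alternative is the `default` hook of `json.dumps`, sketched below on an illustrative subset of `CONFIG`. `PureWindowsPath` is used only so that the example behaves the same on any platform.

```python
import json
from pathlib import PureWindowsPath

# Illustrative subset of CONFIG, not the full dictionary.
config = {
    "num_classes": 39,
    "contrast": (1.0, 1.0),
    "save_dir": PureWindowsPath(r"G:\Dataset\models\MambaTSR"),
}

# default=str is called only for objects json cannot encode natively (here, the path).
text = json.dumps(config, indent=4, default=str)
restored = json.loads(text)
print(restored["save_dir"])  # G:\Dataset\models\MambaTSR
print(restored["contrast"])  # [1.0, 1.0]  (tuples come back as lists)
```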