In [1]:
"""
Train Hybrid CNN+ViT Model for Wheat Disease Detection
=====================================================

This script trains a hybrid model that combines ConvNeXt (CNN) and ViT (Vision Transformer) features, fused by concatenation followed by an MLP classification head, matching the workflow of train_convnext.ipynb.

- Uses torchvision ConvNeXt and timm ViT
- Fusion of CNN and ViT features by concatenation + MLP head (an AttentionFusion module is also defined, but it is not wired into the model)
- Saves best model to saved_models_and_data/best_hybrid_model.pth

How to use:
-----------
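1. Place your images in dataset/<class_name>/. If dataset_split/ does not already contain non-empty train/, val/ and test/ folders, the script creates a 70/15/15 split there (see README for format)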
2. Run: python train_hybrid_cnn_vit.py

Author: [Your Name]
Date: [Date]
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import classification_report, confusion_matrix
from PIL import Image
import time

try:
    import timm
except ImportError:
    import sys
    os.system(f"{sys.executable} -m pip install timm")
    import timm
from timm import create_model

from torch.utils.data import WeightedRandomSampler, Subset
import shutil

# -----------------------------
# Configuration
# -----------------------------
# Filesystem layout: raw images, the on-disk train/val/test split, checkpoints.
DATASET_DIR = 'dataset'
SPLIT_OUTPUT_DIR = 'dataset_split'
SAVE_DIR = 'saved_models_and_data'

# Input geometry and optimisation hyper-parameters.
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-4
EARLY_STOPPING_PATIENCE = 5

# Mixed precision is switched on only when a CUDA device is visible.
USE_MIXED_PRECISION = torch.cuda.is_available()

# Make sure both output folders exist before anything is written to them.
for _output_dir in (SAVE_DIR, SPLIT_OUTPUT_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# -----------------------------
# Transformations
# -----------------------------
# Standard ImageNet per-channel statistics, shared by both pipelines.
# NOTE(review): the timm ViT checkpoint may have been trained with a different
# normalisation than the torchvision ConvNeXt one - confirm against the
# model's pretrained config.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random geometric + colour augmentation, then
# tensor conversion and normalisation.
_train_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(45),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: deterministic - resize, tensor conversion, normalisation.
_eval_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
test_transform = transforms.Compose(_eval_steps)

# -----------------------------
# Custom Dataset
# -----------------------------
class WheatDiseaseDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``root_dir`` per class.

    Attributes:
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to its integer label.
        samples: List of ``(file_path, label)`` tuples, ordered by class and
            then by file name.

    Args:
        root_dir: Directory whose immediate sub-directories are the classes.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Only include subdirectories as classes
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.samples = []
        for target_class in self.classes:
            class_dir = os.path.join(root_dir, target_class)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample order (and any seeded split built on top of it) stable
            # across machines and filesystems.
            for img_file in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, img_file)
                # Nested directories are not images; skip them.
                if not os.path.isfile(path):
                    continue
                self.samples.append((path, self.class_to_idx[target_class]))

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``.

        If the file cannot be loaded, the error is printed and the next
        sample (wrapping around) is tried instead. Each sample is tried at
        most once.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no sample in the dataset can be loaded.
        """
        total = len(self)
        current = idx
        for _ in range(max(total, 1)):
            # Plain indexing so an out-of-range idx still raises IndexError.
            path, target = self.samples[current]
            try:
                # The context manager closes the underlying file promptly;
                # convert() returns an independent, fully loaded image.
                with Image.open(path) as raw_image:
                    image = raw_image.convert("RGB")
                if self.transform:
                    image = self.transform(image)
                return image, target
            except Exception as e:
                print(f"Erreur lors du chargement de {path}: {e}")
                current = (current + 1) % total
        raise RuntimeError(
            f"No loadable image found in {self.root_dir} ({total} files tried)"
        )

# -----------------------------
# Hybrid Model Definition
# -----------------------------
class AttentionFusion(nn.Module):
    """Attention pooling over a set of feature vectors.

    Each of the ``n_sources`` vectors is scored by a small MLP
    (Linear -> Tanh -> Linear), the scores are softmax-normalised across the
    sources, and the weighted sum of the vectors is returned.

    Args:
        in_dim: Size of each feature vector.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.Tanh(),
            nn.Linear(in_dim, 1)
        )

    def forward(self, x):
        """Fuse ``x`` of shape (batch, n_sources, in_dim) into (batch, in_dim).

        A 2-D input (batch, in_dim) is treated as a single source. Without
        that, the softmax would run over a size-1 axis and the final sum
        would collapse the feature axis, returning one scalar per sample.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)  # (batch, 1, in_dim)
        # Scores are (batch, n_sources, 1); normalise across the sources.
        attn_weights = torch.softmax(self.attn(x), dim=1)
        return (x * attn_weights).sum(dim=1)

class HybridCNNViT(nn.Module):
    """Two-branch classifier built from a ConvNeXt-Base and a ViT-Base backbone.

    Both backbones receive the same input batch. Their feature vectors are
    concatenated and passed through an MLP head
    (Linear -> ReLU -> Dropout -> Linear) that produces ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()

        # CNN branch (ConvNeXt): note the width expected by the stock
        # classifier's final layer, then replace the classifier with a no-op.
        convnext = models.convnext_base(pretrained=True)
        cnn_out = convnext.classifier[2].in_features
        convnext.classifier = nn.Identity()
        self.cnn = convnext

        # ViT branch (from timm): same treatment for its classification head.
        vit = create_model('vit_base_patch16_224', pretrained=True)
        vit_out = vit.head.in_features
        vit.head = nn.Identity()
        self.vit = vit

        # Fusion head: MLP over the concatenated CNN + ViT features.
        head_layers = [
            nn.Linear(cnn_out + vit_out, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        self.fusion = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        # ConvNeXt feature map (batch, C, H, W), averaged over the two
        # spatial axes to give one vector per image.
        feature_map = self.cnn.features(x)
        cnn_vec = feature_map.mean(dim=[2, 3])
        # ViT output with its head replaced by Identity: (batch, features).
        vit_vec = self.vit(x)
        combined = torch.cat([cnn_vec, vit_vec], dim=1)
        return self.fusion(combined)

# -----------------------------
# Data Loaders
# -----------------------------
def _save_split_images(source_dataset, indices, split_name):
    """Copy the samples selected by ``indices`` into SPLIT_OUTPUT_DIR/<split_name>/<class>/."""
    print(f"Saving images for split: {split_name}")
    for idx in indices:
        path, label_idx = source_dataset.samples[idx]
        class_name = source_dataset.classes[label_idx]
        dest_dir = os.path.join(SPLIT_OUTPUT_DIR, split_name, class_name)
        os.makedirs(dest_dir, exist_ok=True)
        shutil.copyfile(path, os.path.join(dest_dir, os.path.basename(path)))


def _load_split_datasets():
    """Build the train/val/test datasets from SPLIT_OUTPUT_DIR and check label agreement."""
    train_dataset = WheatDiseaseDataset(os.path.join(SPLIT_OUTPUT_DIR, 'train'), transform=train_transform)
    val_dataset = WheatDiseaseDataset(os.path.join(SPLIT_OUTPUT_DIR, 'val'), transform=test_transform)
    test_dataset = WheatDiseaseDataset(os.path.join(SPLIT_OUTPUT_DIR, 'test'), transform=test_transform)
    # Each dataset derives label indices from its own sub-directory list, so a
    # class folder missing from one split can shift that split's labels
    # relative to the training labels. Fail loudly instead of evaluating
    # against wrong labels.
    for split_name, dataset in (('val', val_dataset), ('test', test_dataset)):
        mismatched = [
            cls for cls in dataset.classes
            if dataset.class_to_idx[cls] != train_dataset.class_to_idx.get(cls)
        ]
        if mismatched:
            raise ValueError(
                f"Class indices of the '{split_name}' split do not match the train split "
                f"for: {mismatched}. Make sure every split contains the same class folders."
            )
    return train_dataset, val_dataset, test_dataset


def get_data_loaders():
    """Create the train/val/test data loaders.

    If SPLIT_OUTPUT_DIR already holds non-empty train/, val/ and test/
    folders they are used as-is. Otherwise the images under DATASET_DIR are
    shuffled with a fixed seed, split 70/15/15 and copied into
    SPLIT_OUTPUT_DIR first.

    The training loader draws samples with replacement, weighted by inverse
    class frequency; the validation and test loaders iterate in order.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, class_names)`` where
        ``class_names`` is the training dataset's class list.

    Raises:
        ValueError: If the training split is empty, or if a val/test class
            maps to a different label index than in the training split.
    """
    split_names = ['train', 'val', 'test']
    split_dirs = [os.path.join(SPLIT_OUTPUT_DIR, split) for split in split_names]
    split_exists = all(os.path.isdir(d) and len(os.listdir(d)) > 0 for d in split_dirs)
    if split_exists:
        print('Found existing split dataset. Loading splits...')
    else:
        print('No split dataset found. Splitting and saving images...')
        # Only file paths and labels are needed here, so no transform is attached.
        full_dataset = WheatDiseaseDataset(DATASET_DIR)
        generator = torch.Generator().manual_seed(42)
        indices = torch.randperm(len(full_dataset), generator=generator).tolist()
        train_size = int(0.7 * len(full_dataset))
        val_size = int(0.15 * len(full_dataset))
        split_indices = {
            'train': indices[:train_size],
            'val': indices[train_size:train_size + val_size],
            'test': indices[train_size + val_size:],  # remainder
        }
        for split_name in split_names:
            # Create every class folder in every split, even if the random
            # split assigned it no images, so label indices agree across splits.
            for class_name in full_dataset.classes:
                os.makedirs(os.path.join(SPLIT_OUTPUT_DIR, split_name, class_name), exist_ok=True)
            _save_split_images(full_dataset, split_indices[split_name], split_name)
        print('Image splits saved.')
    train_dataset, val_dataset, test_dataset = _load_split_datasets()

    print('Calculating class weights for balanced sampling...')
    targets = [label for _, label in train_dataset.samples]
    if not targets:
        raise ValueError(
            f"Training split in {os.path.join(SPLIT_OUTPUT_DIR, 'train')} contains no images."
        )
    class_counts = np.bincount(targets, minlength=len(train_dataset.classes))
    # Inverse-frequency weights; classes with no training images keep weight 0
    # (they are never indexed below) instead of triggering a division by zero.
    class_weights = np.zeros(len(class_counts), dtype=np.float64)
    present = class_counts > 0
    class_weights[present] = 1.0 / class_counts[present]
    sample_weights = [class_weights[t] for t in targets]
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    print('Data loaders are ready.')
    return train_loader, val_loader, test_loader, train_dataset.classes

# -----------------------------
# Training Loop
# -----------------------------
def train_model(model, device, train_loader, val_loader, num_epochs=EPOCHS):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    The learning rate is reduced when the validation loss plateaus. Whenever
    validation accuracy improves, the weights are saved to
    SAVE_DIR/best_hybrid_model.pth. Training stops early after
    EARLY_STOPPING_PATIENCE consecutive epochs without improvement.

    Args:
        model: Network to train; expected to already be on ``device``.
        device: Device that the batches are moved to.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader yielding ``(inputs, labels)`` validation batches.
        num_epochs: Maximum number of epochs.

    Returns:
        The same ``model`` object, holding the weights of the last epoch run
        (not necessarily the best checkpoint).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

    # Resolve the autocast device once, from the device actually used, so it
    # is defined even when the training loader yields no batches. Mixed
    # precision is only enabled when that device is CUDA.
    device_type = torch.device(device).type
    use_amp = bool(USE_MIXED_PRECISION) and device_type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    best_model_path = os.path.join(SAVE_DIR, "best_hybrid_model.pth")
    best_acc = 0.0
    no_improvement_epochs = 0
    for epoch in range(num_epochs):
        # ---- Training pass ----
        model.train()
        running_loss = 0.0
        running_corrects = 0
        train_seen = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type, enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            running_corrects += (preds == labels).sum().item()
            train_seen += batch_size
        # Average over the samples actually seen; max() guards an empty loader.
        epoch_loss = running_loss / max(train_seen, 1)
        epoch_acc = running_corrects / max(train_seen, 1)

        # ---- Validation pass ----
        model.eval()
        val_loss = 0.0
        val_corrects = 0
        val_seen = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                with torch.amp.autocast(device_type, enabled=use_amp):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                batch_size = inputs.size(0)
                val_loss += loss.item() * batch_size
                val_corrects += (preds == labels).sum().item()
                val_seen += batch_size
        val_loss = val_loss / max(val_seen, 1)
        val_acc = val_corrects / max(val_seen, 1)
        scheduler.step(val_loss)
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

        # Early Stopping: checkpoint on improvement, otherwise count the stall.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
            no_improvement_epochs = 0
        else:
            no_improvement_epochs += 1
        if no_improvement_epochs >= EARLY_STOPPING_PATIENCE:
            print("Early stopping triggered.")
            break
    print("Training complete.")
    return model

# -----------------------------
# Main
# -----------------------------
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build (or load) the dataset splits and their loaders.
train_loader, val_loader, test_loader, class_labels = get_data_loaders()

# One output logit per class; move the network to the chosen device, then train.
model = HybridCNNViT(num_classes=len(class_labels))
model = model.to(device)
model = train_model(model, device, train_loader, val_loader)

# Save final model
final_model_path = os.path.join(SAVE_DIR, "final_hybrid_model.pth")
torch.save(model.state_dict(), final_model_path)

  from .autonotebook import tqdm as notebook_tqdm


Found existing split dataset. Loading splits...
Calculating class weights for balanced sampling...
Data loaders are ready.


Error while downloading from https://cdn-lfs.hf.co/repos/fb/cd/fbcdc88e492959e3ee2515f497fb45cb5f217f9455c204bdf4e0b400c90d0c23/32aa17d6e17b43500f531d5f6dc9bc93e56ed8841b8a75682e1bb295d722405b?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&Expires=1753078349&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MzA3ODM0OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mYi9jZC9mYmNkYzg4ZTQ5Mjk1OWUzZWUyNTE1ZjQ5N2ZiNDVjYjVmMjE3Zjk0NTVjMjA0YmRmNGUwYjQwMGM5MGQwYzIzLzMyYWExN2Q2ZTE3YjQzNTAwZjUzMWQ1ZjZkYzliYzkzZTU2ZWQ4ODQxYjhhNzU2ODJlMWJiMjk1ZDcyMjQwNWI%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=vk-yrHOGnMCEOhSDb7DHxb1eLxZIPLky-Uz8d-TzTxPzeUIO%7EwUsN2qQw-2UqoQ4KMHOgjnB0YOzNW2YmPdvSVVya1Lo5T4lNrbi%7EonlVZa3VFFpPDRBYrdyjZLkCh81SJrxLxxT55bhmIMGEHU7CzKvWRrOQlDaQOE8gl2r7u6wrzulBy3F2H7UJPLm8mAaw5rq3bmYnqMxObdhGmzn3Z0rPzkFg8vLIOWfvatnjHi-F61K-B8NdGgaIGh0Kxu5bP98Ypi

Epoch 1/10 | Train Loss: 0.9812 Acc: 0.6644 | Val Loss: 0.5653 Acc: 0.8067


KeyboardInterrupt: 