# In [None]:
import os
import sys
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import shutil
from PIL import Image
import time
import warnings
warnings.filterwarnings('ignore')

# Configuration - Matching other training scripts exactly
# (the "matching other scripts" claim is the original author's note and
# cannot be verified from this file alone).

# Filesystem layout, relative to the current working directory.
DATASET_DIR = '../dataset'
SAVE_DIR = '../saved_models_and_data'
SPLIT_OUTPUT_DIR = '../dataset_split'

# Data handling.
IMAGE_SIZE = (224, 224)  # (height, width); only IMAGE_SIZE[0] is handed to YOLO below
BATCH_SIZE = 32
TEST_SIZE = 0.15         # fraction of images held out for the test split
VAL_SIZE = 0.15          # fraction of images held out for the validation split

# Optimisation.
EPOCHS = 10
LEARNING_RATE = 1e-4
EARLY_STOPPING_PATIENCE = 5  # epochs without improvement before stopping

# Mixed precision is only flagged on when a CUDA device is visible.
USE_MIXED_PRECISION = bool(torch.cuda.is_available())

# Both output locations must exist before anything is written to them.
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(SPLIT_OUTPUT_DIR, exist_ok=True)

def set_seed(seed=42):
    """Seed every random number generator this script can reach.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, plus all CUDA devices when CUDA is available), so that code drawing
    from any of them is repeatable for a given ``seed``.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    # Local import: ``random`` is not part of the file's top-level imports.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

# Install the ultralytics package (provides the YOLO API used below) if it is
# not importable yet.
try:
    import ultralytics
except ImportError:
    import importlib
    import subprocess

    print("Installing ultralytics (YOLOv9)...")
    # An argument list (no shell) keeps this working when the interpreter path
    # contains spaces; check_call raises CalledProcessError if pip fails, so
    # an installation problem is reported instead of being silently ignored.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ultralytics"])
    # Let the import system notice the freshly installed package.
    importlib.invalidate_caches()
    import ultralytics

from ultralytics import YOLO

print('Defining custom dataset class for wheat disease images...')
class WheatDiseaseDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``root_dir`` per class.

    Attributes:
        root_dir: Directory that was scanned.
        transform: Optional callable applied to each loaded RGB image.
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to its index in ``classes``.
        samples: List of ``(image_path, class_index)`` tuples, ordered by
            class and then by file name.
    """

    def __init__(self, root_dir, transform=None):
        """Scan ``root_dir`` and index every .png/.jpg/.jpeg file per class.

        Args:
            root_dir: Directory containing one sub-directory per class.
            transform: Optional callable applied to each image in
                ``__getitem__``.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Kept for backward compatibility; nothing in this class reads them.
        self.num_workers = 4
        self.pin_memory = True
        self.classes = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.samples = []
        image_exts = ('.png', '.jpg', '.jpeg')
        for target_class in self.classes:
            class_dir = os.path.join(root_dir, target_class)
            label = self.class_to_idx[target_class]
            # os.listdir order is arbitrary; sorting keeps sample indices
            # stable so that seeded, index-based splits are repeatable.
            for img_file in sorted(os.listdir(class_dir)):
                if img_file.lower().endswith(image_exts):
                    self.samples.append((os.path.join(class_dir, img_file), label))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for ``idx``.

        If the requested image cannot be loaded, the following samples are
        tried in order (wrapping around) so one corrupt file does not stop an
        epoch; in that case the returned pair belongs to the substitute sample.

        Raises:
            IndexError: If ``idx`` is out of range (including an empty dataset).
            RuntimeError: If no sample in the dataset could be loaded.
        """
        # Validate the index up front so out-of-range access still raises
        # IndexError instead of being wrapped by the modulo below.
        self.samples[idx]
        total = len(self.samples)
        # Bounded retry: each sample is attempted at most once, so a dataset
        # made only of unreadable files ends in an error, not endless recursion.
        for offset in range(total):
            path, target = self.samples[(idx + offset) % total]
            try:
                # The context manager closes the file handle; convert() returns
                # an independent, fully loaded image.
                with Image.open(path) as handle:
                    image = handle.convert("RGB")
                if self.transform:
                    image = self.transform(image)
                return image, target
            except Exception as e:
                print(f"Error loading {path}: {e}")
        raise RuntimeError(f"None of the {total} images under {self.root_dir} could be loaded")
print('Custom dataset class defined.')

def _claim_unique_stem(stem, class_name, used_stems):
    """Return a file stem not yet recorded in ``used_stems`` and record it.

    The original stem is kept when it is free; otherwise the class name is
    prefixed, then a counter is appended until the name is unique. Comparison
    is case-insensitive so names stay distinct on case-insensitive filesystems.
    """
    candidate = stem
    attempt = 0
    while candidate.lower() in used_stems:
        attempt += 1
        candidate = f"{class_name}_{stem}" if attempt == 1 else f"{class_name}_{stem}_{attempt}"
    used_stems.add(candidate.lower())
    return candidate


def create_yolo_labels(dataset_dir, output_dir, class_labels):
    """Convert a class-per-folder dataset into a YOLO detection layout.

    For each split in train/val/test, every image found in
    ``dataset_dir/<split>/<class>/`` is copied to ``output_dir/<split>/images``
    and given a label file in ``output_dir/<split>/labels`` holding one box
    covering the whole image (``<class_idx> 0.5 0.5 1.0 1.0``).

    All classes of a split share one images directory, so files whose stem is
    already taken there are renamed (class name prefix) instead of overwriting
    an earlier image and its label.

    Args:
        dataset_dir: Root directory containing the split sub-directories.
        output_dir: Root directory for the YOLO layout; created as needed.
        class_labels: Class names; a name's position is its YOLO class id.
    """
    print("Converting classification dataset to YOLO format...")
    image_exts = ('.png', '.jpg', '.jpeg')

    for split in ['train', 'val', 'test']:
        # Output directories are created even when the source split is absent.
        split_dir = os.path.join(output_dir, split)
        images_dir = os.path.join(split_dir, 'images')
        labels_dir = os.path.join(split_dir, 'labels')
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(labels_dir, exist_ok=True)

        split_source_dir = os.path.join(dataset_dir, split)
        if not os.path.exists(split_source_dir):
            print(f"Split directory {split_source_dir} does not exist, skipping...")
            continue

        # Stems already written for this split, shared across all classes.
        used_stems = set()

        for class_idx, class_name in enumerate(class_labels):
            class_dir = os.path.join(split_source_dir, class_name)

            if not os.path.exists(class_dir):
                print(f"Class directory {class_dir} does not exist, skipping...")
                continue

            if not os.path.isdir(class_dir):
                print(f"Skipping {class_dir} - not a directory")
                continue

            print(f"Processing class {class_name} (index {class_idx}) in {split} split...")

            # Best effort per class: a failure is reported and the remaining
            # classes are still processed.
            try:
                image_count = 0
                # Sorted so collision renaming is the same on every run.
                for img_file in sorted(os.listdir(class_dir)):
                    if not img_file.lower().endswith(image_exts):
                        continue
                    stem, ext = os.path.splitext(img_file)
                    out_stem = _claim_unique_stem(stem, class_name, used_stems)

                    shutil.copy2(
                        os.path.join(class_dir, img_file),
                        os.path.join(images_dir, out_stem + ext),
                    )

                    # Format: class_id x_center y_center width height (normalized);
                    # the box spans the entire image.
                    with open(os.path.join(labels_dir, out_stem + '.txt'), 'w') as f:
                        f.write(f"{class_idx} 0.5 0.5 1.0 1.0\n")
                    image_count += 1

                print(f"Processed {image_count} images for class {class_name} in {split} split")

            except NotADirectoryError:
                print(f"Skipping {class_dir} - not a directory")
                continue
            except Exception as e:
                print(f"Error processing {class_dir}: {e}")
                continue

    print("YOLO format conversion completed.")

print('Preparing data loaders and splitting dataset if needed...')
def get_data_loaders():
    """Return the train/val/test datasets, creating the on-disk split if needed.

    If ``SPLIT_OUTPUT_DIR`` already holds non-empty ``train``, ``val`` and
    ``test`` directories they are loaded as-is. Otherwise the images under
    ``DATASET_DIR`` are shuffled with a fixed seed, partitioned according to
    ``TEST_SIZE`` / ``VAL_SIZE`` and copied into
    ``SPLIT_OUTPUT_DIR/<split>/<class>/`` before being loaded.

    Despite the name, this returns datasets, not ``DataLoader`` objects.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)`` of
        ``WheatDiseaseDataset`` instances created with ``transform=None``.

    Raises:
        ValueError: If a new split is required but ``DATASET_DIR`` contains
            no images.
    """
    split_names = ('train', 'val', 'test')
    split_roots = {name: os.path.join(SPLIT_OUTPUT_DIR, name) for name in split_names}

    def has_class_dirs(root):
        # A split counts as present when it contains at least one sub-directory.
        return os.path.isdir(root) and any(
            os.path.isdir(os.path.join(root, entry)) for entry in os.listdir(root)
        )

    if all(has_class_dirs(root) for root in split_roots.values()):
        print('Found existing split dataset. Loading splits...')
    else:
        print('No split dataset found. Splitting and saving images...')
        full_dataset = WheatDiseaseDataset(DATASET_DIR, transform=None)
        total = len(full_dataset)
        if total == 0:
            # Fail early with a clear message instead of a FileNotFoundError
            # when the (never created) split directories are loaded below.
            raise ValueError(f"No images found under {DATASET_DIR}; cannot create dataset splits.")

        generator = torch.Generator().manual_seed(42)
        indices = torch.randperm(total, generator=generator).tolist()

        train_size = int((1 - TEST_SIZE - VAL_SIZE) * total)
        val_size = int(VAL_SIZE * total)
        # The test split takes whatever remains after train and val.
        split_indices = {
            'train': indices[:train_size],
            'val': indices[train_size:train_size + val_size],
            'test': indices[train_size + val_size:],
        }

        for split_name in split_names:
            print(f"Saving images for split: {split_name}")
            # Create the split root even if the split receives no images, so
            # loading it below yields an empty dataset rather than an error.
            os.makedirs(split_roots[split_name], exist_ok=True)
            for sample_idx in split_indices[split_name]:
                path, label_idx = full_dataset.samples[sample_idx]
                dest_dir = os.path.join(split_roots[split_name], full_dataset.classes[label_idx])
                os.makedirs(dest_dir, exist_ok=True)
                shutil.copyfile(path, os.path.join(dest_dir, os.path.basename(path)))
        print('Image splits saved.')

    # Both branches load from the split directories on disk.
    train_dataset = WheatDiseaseDataset(split_roots['train'], transform=None)
    val_dataset = WheatDiseaseDataset(split_roots['val'], transform=None)
    test_dataset = WheatDiseaseDataset(split_roots['test'], transform=None)

    print('Data loaders are ready.')
    return train_dataset, val_dataset, test_dataset

def create_yolo_yaml(class_labels):
    """Build the YOLO-format dataset and write its YAML description.

    Calls ``create_yolo_labels`` to populate ``SAVE_DIR/yolo_dataset`` from
    ``SPLIT_OUTPUT_DIR``, then writes ``SAVE_DIR/wheat_yolov9.yaml`` with the
    absolute image directories of each split, the class count and the names.

    Args:
        class_labels: Sequence of class names; position is the class id.

    Returns:
        Path of the written YAML file.
    """
    # Local import: JSON strings and arrays are valid YAML flow syntax, so
    # json.dumps gives correct quoting for paths or names containing
    # characters such as '#', ': ' or quotes.
    import json

    DATA_YAML = os.path.join(SAVE_DIR, 'wheat_yolov9.yaml')

    # Create YOLO format directories and labels
    yolo_dataset_dir = os.path.join(SAVE_DIR, 'yolo_dataset')
    create_yolo_labels(SPLIT_OUTPUT_DIR, yolo_dataset_dir, class_labels)

    names = list(class_labels)
    lines = []
    for split in ('train', 'val', 'test'):
        images_path = os.path.abspath(os.path.join(yolo_dataset_dir, split, 'images'))
        lines.append(f"{split}: {json.dumps(images_path, ensure_ascii=False)}")
    lines.append(f"nc: {len(names)}")
    lines.append(f"names: {json.dumps(names, ensure_ascii=False)}")

    # Explicit UTF-8 so non-ASCII paths or names do not depend on the locale.
    with open(DATA_YAML, 'w', encoding='utf-8') as f:
        f.write("\n".join(lines) + "\n")

    print(f"YAML configuration created: {DATA_YAML}")
    print(f"Classes ({len(class_labels)}): {class_labels}")
    return DATA_YAML

def train_model(data_yaml, num_classes, class_labels):
    """Fine-tune the pretrained ``yolov9c.pt`` checkpoint on the wheat dataset.

    Args:
        data_yaml: Path to the dataset YAML passed to ``model.train``.
        num_classes: Unused here; kept so existing callers keep working.
        class_labels: Unused here; kept so existing callers keep working.

    Returns:
        Tuple ``(model, results)``: the ``YOLO`` object and whatever
        ``model.train`` returned.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    if use_cuda:
        print(f"CUDA device count: {torch.cuda.device_count()}")
        print(f"Current CUDA device: {torch.cuda.current_device()}")
        print(f"CUDA device name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
        torch.backends.cudnn.benchmark = True  # optimize for speed

    # Initialize YOLOv9 model (loads 'yolov9c.pt' pretrained weights)
    model = YOLO('yolov9c.pt')

    print("Starting YOLOv9 training...")
    start_time = time.time()

    results = model.train(
        data=data_yaml,
        imgsz=IMAGE_SIZE[0],
        epochs=EPOCHS,
        batch=BATCH_SIZE,
        lr0=LEARNING_RATE,
        project=SAVE_DIR,
        name='yolov9_wheat_disease',
        exist_ok=True,
        save=True,
        save_period=5,  # Save checkpoint every 5 epochs
        patience=EARLY_STOPPING_PATIENCE,
        verbose=True,
        workers=0,  # Set to 0 to avoid multiprocessing issues
        # Same device spec as evaluate_model: GPU index 0 or 'cpu', rather
        # than a torch.device object.
        device=0 if use_cuda else 'cpu',
        # Data augmentation parameters
        hsv_h=0.1,      # Hue augmentation
        hsv_s=0.2,      # Saturation augmentation
        hsv_v=0.2,      # Value augmentation
        degrees=45,     # Rotation degrees
        translate=0.1,  # Translation - minimal
        scale=0.1,      # Scale variation - minimal
        shear=0.0,      # No shear
        perspective=0.0,  # No perspective
        flipud=0.5,     # Vertical flip probability
        fliplr=0.5,     # Horizontal flip probability
        mosaic=0.0,     # Disable mosaic for classification-like training
        mixup=0.0,      # Disable mixup for classification-like training
        copy_paste=0.0,  # Disable copy-paste
    )

    total_time = time.time() - start_time
    print(f"Training completed in {total_time/60:.1f} minutes")

    return model, results

def evaluate_model(model, data_yaml):
    """Run ``model.val`` on the test split described by ``data_yaml``.

    Args:
        model: Object exposing ``val(**kwargs)``; in this file it is the
            ``YOLO`` instance returned by ``train_model``.
        data_yaml: Path to the dataset YAML.

    Returns:
        Whatever ``model.val`` returns.
    """
    print("Evaluating model on test set...")

    # NOTE: ``save_hybrid=True`` was removed on purpose. In Ultralytics it
    # appends the ground-truth boxes to the predictions (an auto-labelling
    # aid), which makes the reported mAP incorrect, and newer releases no
    # longer accept the argument.
    results = model.val(
        data=data_yaml,
        split='test',
        imgsz=IMAGE_SIZE[0],
        batch=BATCH_SIZE,
        save_json=True,
        conf=0.001,
        iou=0.6,
        max_det=300,
        half=False,
        device=0 if torch.cuda.is_available() else 'cpu',
        dnn=False,
        plots=True,
        verbose=True
    )

    return results

def main():
    """Run the full pipeline: split data, build YAML, train, evaluate, export.

    Steps, in order: load (or create) the dataset splits, write the YOLO
    dataset and YAML, train, evaluate on the test split, copy ``best.pt`` to
    two fixed names under ``SAVE_DIR`` when it exists, print the mAP metrics
    if the evaluation result exposes them, and dump the run's hyperparameters
    to ``yolov9_training_log.json``.

    Returns:
        The model object returned by ``train_model``.
    """
    import json

    # --- Data -------------------------------------------------------------
    print('Loading data...')
    train_dataset, _val_dataset, _test_dataset = get_data_loaders()
    class_labels = train_dataset.classes
    num_classes = len(class_labels)
    print('Data loaded. Classes:', class_labels)

    print('Creating YAML configuration...')
    data_yaml = create_yolo_yaml(class_labels)

    # --- Train / evaluate ---------------------------------------------------
    print('Initializing YOLOv9 model...')
    print('Model initialized. Starting training...')
    model, _train_results = train_model(data_yaml, num_classes, class_labels)

    print('Training complete. Evaluating on test set...')
    eval_results = evaluate_model(model, data_yaml)

    # --- Export weights -----------------------------------------------------
    results_dir = os.path.join(SAVE_DIR, 'yolov9_wheat_disease')
    best_model_path = os.path.join(results_dir, 'weights', 'best.pt')

    if not os.path.exists(best_model_path):
        print("Warning: Best model weights not found!")
    else:
        final_model_path = os.path.join(SAVE_DIR, 'wheat_disease_yolov9_model.pt')
        shutil.copy(best_model_path, final_model_path)
        print(f'Model saved to {final_model_path}')

        # Second copy under an alternative file name.
        backup_path = os.path.join(SAVE_DIR, 'best_yolov9_model.pt')
        shutil.copy(best_model_path, backup_path)
        print(f'Best model backup saved to {backup_path}')

    print('Test set predictions complete.')

    # --- Reporting ----------------------------------------------------------
    if os.path.exists(results_dir):
        print(f"Training results and plots saved in: {results_dir}")

    # Metrics are printed only when the evaluation result exposes them.
    if hasattr(eval_results, 'results_dict'):
        metrics = eval_results.results_dict
        if 'metrics/mAP50(B)' in metrics:
            print(f"mAP@0.5: {metrics['metrics/mAP50(B)']:.4f}")
        if 'metrics/mAP50-95(B)' in metrics:
            print(f"mAP@0.5:0.95: {metrics['metrics/mAP50-95(B)']:.4f}")

    print("YOLOv9 training completed.")

    # --- Hyperparameter log -------------------------------------------------
    training_log = {
        'model_type': 'YOLOv9',
        'epochs': EPOCHS,
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'image_size': IMAGE_SIZE,
        'num_classes': num_classes,
        'class_labels': class_labels,
        'early_stopping_patience': EARLY_STOPPING_PATIENCE
    }
    log_path = os.path.join(SAVE_DIR, "yolov9_training_log.json")
    with open(log_path, 'w') as f:
        json.dump(training_log, f, indent=2)
    print('Training log saved.')

    return model

if __name__ == "__main__":
    model = main()
