In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/aug_0_5995.jpg
/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/Leaf_scald (86).jpg
/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/aug_0_5914.jpg
/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/aug_0_2467.jpg
/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/aug_0_2138.jpg
/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/aug_0_914.jpg
/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/aug_0_3199.jpg
/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/aug_0_2738.jpg
/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/aug_0_143.jpg
/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/aug_0_1933.jpg
/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/aug_0_3645.jpg
/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/aug_0_890.jpg
/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG/Leaf scald/Leaf_scald (29).jpg
/kaggle/input/rice

In [3]:
##### import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import mobilenet_v2
import numpy as np
import cv2
from PIL import Image
import os
import json
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from sklearn.metrics import confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Dict, Optional
import warnings
import glob
import random
from collections import defaultdict
warnings.filterwarnings('ignore')

class DepthwiseSeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super(DepthwiseSeparableConv2d, self).__init__()
        
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size, 
            stride=stride, padding=padding, groups=in_channels, bias=bias
        )
        
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
        
    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

class ModifiedInceptionModule(nn.Module):
    """
    Modified Inception (M-Inception) module with depth-wise separable convolutions
    Replaces standard convolutions with DSC to reduce parameters
    """
    def __init__(self, in_channels, branch1_out, branch2_reduce, branch2_out, 
                 branch3_reduce, branch3_out, branch4_out):
        super(ModifiedInceptionModule, self).__init__()
        
        # Branch 1: 1x1 conv
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, branch1_out, 1),
            nn.BatchNorm2d(branch1_out),
            nn.ReLU(inplace=True)
        )
        
        # Branch 2: 1x1 conv -> 3x3 DSC
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, branch2_reduce, 1),
            nn.BatchNorm2d(branch2_reduce),
            nn.ReLU(inplace=True),
            DepthwiseSeparableConv2d(branch2_reduce, branch2_out, 3, padding=1),
            nn.BatchNorm2d(branch2_out),
            nn.ReLU(inplace=True)
        )
        
        # Branch 3: 1x1 conv -> two 3x3 DSC (replacing 5x5)
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, branch3_reduce, 1),
            nn.BatchNorm2d(branch3_reduce),
            nn.ReLU(inplace=True),
            DepthwiseSeparableConv2d(branch3_reduce, branch3_reduce, 3, padding=1),
            nn.BatchNorm2d(branch3_reduce),
            nn.ReLU(inplace=True),
            DepthwiseSeparableConv2d(branch3_reduce, branch3_out, 3, padding=1),
            nn.BatchNorm2d(branch3_out),
            nn.ReLU(inplace=True)
        )
        
        # Branch 4: 3x3 max pool -> 1x1 conv
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_channels, branch4_out, 1),
            nn.BatchNorm2d(branch4_out),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        
        # Concatenate along channel dimension
        outputs = torch.cat([branch1, branch2, branch3, branch4], 1)
        return outputs

class MobIncNet(nn.Module):
    """
    MobInc-Net: Lightweight network combining MobileNet with Modified Inception
    """
    def __init__(self, num_classes=12, pretrained=True):
        super(MobIncNet, self).__init__()
        
        # Load pre-trained MobileNetV2 backbone
        mobilenet = mobilenet_v2(pretrained=pretrained)
        
        # Extract features from MobileNet (remove classifier)
        self.backbone = nn.Sequential(*list(mobilenet.children())[:-1])
        
        # Get the number of features from MobileNet
        backbone_out_features = 1280  # MobileNetV2 output features
        
        # Modified Inception module with 512 filters
        self.m_inception = ModifiedInceptionModule(
            in_channels=backbone_out_features,
            branch1_out=128,      # 1x1 branch
            branch2_reduce=96,    # 3x3 reduce
            branch2_out=128,      # 3x3 output
            branch3_reduce=16,    # 5x5 reduce (using two 3x3)
            branch3_out=32,       # 5x5 output
            branch4_out=64        # pool proj
        )
        
        # Calculate total output channels from M-Inception
        inception_out_features = 128 + 128 + 32 + 64  # 352
        
        # Global Average Pooling
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Classifier head
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(inception_out_features, num_classes)
        )
        
    def forward(self, x):
        # Extract features using MobileNet backbone
        x = self.backbone(x)
        
        # Apply Modified Inception module
        x = self.m_inception(x)
        
        # Global Average Pooling
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)
        
        # Classification
        x = self.classifier(x)
        return x

class FocalLoss(nn.Module):
    """
    Improved Focal Loss for multi-class classification
    Addresses class imbalance and focuses on hard examples
    """
    def __init__(self, alpha=1.0, gamma=2.0, weight=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction
        
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, weight=self.weight, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class RiceDiseaseDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        
        try:
            image = Image.open(image_path).convert('RGB')
        except:
            image = Image.new('RGB', (224, 224), color='white')
            
        if self.transform:
            image = self.transform(image)
            
        return image, label

class RiceDatasetLoader:
    def __init__(self, dataset_path="/kaggle/input/rice-disease-dataset/Rice_Leaf_AUG"):
        self.dataset_path = dataset_path
        self.class_names = [
            'Bacterial Leaf Blight',
            'Brown Spot', 
            'Healthy Rice Leaf',
            'Leaf Blast',
            'Leaf scald',
            'Sheath Blight'
        ]
        self.class_to_idx = {name: idx for idx, name in enumerate(self.class_names)}
        
    def load_dataset(self, train_split=0.7, val_split=0.15, test_split=0.15):
        all_image_paths = []
        all_labels = []
        
        for class_name in self.class_names:
            class_path = os.path.join(self.dataset_path, class_name)
            if os.path.exists(class_path):
                image_files = glob.glob(os.path.join(class_path, "*.*"))
                image_files = [f for f in image_files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
                
                class_label = self.class_to_idx[class_name]
                all_image_paths.extend(image_files)
                all_labels.extend([class_label] * len(image_files))
        
        data_by_class = defaultdict(list)
        for path, label in zip(all_image_paths, all_labels):
            data_by_class[label].append(path)
        
        train_paths, val_paths, test_paths = [], [], []
        train_labels, val_labels, test_labels = [], [], []
        
        for class_idx, paths in data_by_class.items():
            random.shuffle(paths)
            n_total = len(paths)
            n_train = int(n_total * train_split)
            n_val = int(n_total * val_split)
            
            train_paths.extend(paths[:n_train])
            val_paths.extend(paths[n_train:n_train+n_val])
            test_paths.extend(paths[n_train+n_val:])
            
            train_labels.extend([class_idx] * n_train)
            val_labels.extend([class_idx] * n_val)
            test_labels.extend([class_idx] * (n_total - n_train - n_val))
        
        return {
            'train': (train_paths, train_labels),
            'val': (val_paths, val_labels),
            'test': (test_paths, test_labels)
        }
    
    def create_dataloaders(self, batch_size=64):
        data_splits = self.load_dataset()
        
        train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=(-180, 180)),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.8, 1.2)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        val_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        train_dataset = RiceDiseaseDataset(
            data_splits['train'][0], data_splits['train'][1], train_transform
        )
        val_dataset = RiceDiseaseDataset(
            data_splits['val'][0], data_splits['val'][1], val_transform
        )
        test_dataset = RiceDiseaseDataset(
            data_splits['test'][0], data_splits['test'][1], val_transform
        )
        
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
        
        print(f"Train samples: {len(train_dataset)}")
        print(f"Val samples: {len(val_dataset)}")
        print(f"Test samples: {len(test_dataset)}")
        print(f"Classes: {self.class_names}")
        
        return train_loader, val_loader, test_loader

class DataAugmentation:
    """
    Data augmentation techniques for training
    """
    @staticmethod
    def get_train_transforms():
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=(-180, 180)),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.8, 1.2)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    @staticmethod
    def get_val_transforms():
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

class TwoStageTrainer:
    """
    Two-stage transfer learning trainer
    """
    def __init__(self, model, device, num_classes):
        self.model = model.to(device)
        self.device = device
        self.num_classes = num_classes
        
    def stage_one_training(self, train_loader, val_loader, epochs=10, lr=5e-4):
        """
        Stage 1: Freeze backbone, train only new layers
        """
        print("Stage 1: Training new layers only...")
        
        # Freeze backbone parameters
        for param in self.model.backbone.parameters():
            param.requires_grad = False
            
        # Only optimize new layers
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=lr, weight_decay=1e-4
        )
        
        # Calculate class weights for focal loss
        class_weights = self._calculate_class_weights(train_loader)
        criterion = FocalLoss(gamma=2.0, weight=class_weights)
        
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        
        best_val_acc = 0.0
        train_losses, val_losses, val_accuracies = [], [], []
        
        for epoch in range(epochs):
            # Training
            self.model.train()
            train_loss = 0.0
            
            for batch_idx, (data, targets) in enumerate(train_loader):
                data, targets = data.to(self.device), targets.to(self.device)
                
                optimizer.zero_grad()
                outputs = self.model(data)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
                
                if batch_idx % 10 == 0:
                    print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}')
            
            # Validation
            val_loss, val_acc = self._validate(val_loader, criterion)
            
            train_losses.append(train_loss / len(train_loader))
            val_losses.append(val_loss)
            val_accuracies.append(val_acc)
            
            print(f'Epoch {epoch+1}/{epochs}: Train Loss: {train_loss/len(train_loader):.4f}, '
                  f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
            
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(self.model.state_dict(), 'mobinc_net_stage1.pth')
            
            scheduler.step()
            
        return train_losses, val_losses, val_accuracies
    
    def stage_two_training(self, train_loader, val_loader, epochs=10, lr=1e-4):
        """
        Stage 2: Fine-tune entire network
        """
        print("Stage 2: Fine-tuning entire network...")
        
        # Unfreeze all parameters
        for param in self.model.parameters():
            param.requires_grad = True
            
        # Lower learning rate for fine-tuning
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        
        # Calculate class weights for focal loss
        class_weights = self._calculate_class_weights(train_loader)
        criterion = FocalLoss(gamma=2.0, weight=class_weights)
        
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        
        best_val_acc = 0.0
        train_losses, val_losses, val_accuracies = [], [], []
        
        for epoch in range(epochs):
            # Training
            self.model.train()
            train_loss = 0.0
            
            for batch_idx, (data, targets) in enumerate(train_loader):
                data, targets = data.to(self.device), targets.to(self.device)
                
                optimizer.zero_grad()
                outputs = self.model(data)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
                
                if batch_idx % 10 == 0:
                    print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}')
            
            # Validation
            val_loss, val_acc = self._validate(val_loader, criterion)
            
            train_losses.append(train_loss / len(train_loader))
            val_losses.append(val_loss)
            val_accuracies.append(val_acc)
            
            print(f'Epoch {epoch+1}/{epochs}: Train Loss: {train_loss/len(train_loader):.4f}, '
                  f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
            
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(self.model.state_dict(), 'mobinc_net_final.pth')
            
            scheduler.step()
            
        return train_losses, val_losses, val_accuracies
    
    def _validate(self, val_loader, criterion):
        """Validation function"""
        self.model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for data, targets in val_loader:
                data, targets = data.to(self.device), targets.to(self.device)
                outputs = self.model(data)
                loss = criterion(outputs, targets)
                
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        
        val_acc = correct / total
        return val_loss / len(val_loader), val_acc
    
    def _calculate_class_weights(self, train_loader):
        """Calculate class weights for focal loss"""
        class_counts = torch.zeros(self.num_classes)
        
        for _, targets in train_loader:
            for target in targets:
                class_counts[target] += 1
        
        total_samples = class_counts.sum()
        class_weights = total_samples / (self.num_classes * class_counts)
        
        return class_weights.to(self.device)

class ModelEvaluator:
    """
    Comprehensive model evaluation with multiple metrics
    """
    def __init__(self, model, device, class_names):
        self.model = model
        self.device = device
        self.class_names = class_names
        
    def evaluate(self, test_loader):
        """Comprehensive evaluation"""
        self.model.eval()
        all_predictions = []
        all_targets = []
        all_probabilities = []
        
        with torch.no_grad():
            for data, targets in test_loader:
                data, targets = data.to(self.device), targets.to(self.device)
                outputs = self.model(data)
                probabilities = F.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs, 1)
                
                all_predictions.extend(predicted.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
                all_probabilities.extend(probabilities.cpu().numpy())
        
        # Calculate metrics
        accuracy = accuracy_score(all_targets, all_predictions)
        recall = recall_score(all_targets, all_predictions, average='macro')
        precision = precision_score(all_targets, all_predictions, average='macro')
        f1 = f1_score(all_targets, all_predictions, average='macro')
        
        # Specificity calculation for each class
        cm = confusion_matrix(all_targets, all_predictions)
        specificity_per_class = []
        
        for i in range(len(self.class_names)):
            tn = np.sum(cm) - (np.sum(cm[i, :]) + np.sum(cm[:, i]) - cm[i, i])
            fp = np.sum(cm[:, i]) - cm[i, i]
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
            specificity_per_class.append(specificity)
        
        avg_specificity = np.mean(specificity_per_class)
        
        results = {
            'accuracy': accuracy,
            'recall': recall,
            'precision': precision,
            'f1_score': f1,
            'specificity': avg_specificity,
            'predictions': all_predictions,
            'targets': all_targets,
            'probabilities': all_probabilities,
            'confusion_matrix': cm
        }
        
        return results
    
    def plot_confusion_matrix(self, cm, save_path=None):
        """Plot confusion matrix"""
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=self.class_names,
                    yticklabels=self.class_names)
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
    
    def plot_roc_curves(self, targets, probabilities, save_path=None):
        """Plot ROC curves for each class"""
        from sklearn.preprocessing import label_binarize
        from sklearn.metrics import roc_curve, auc
        
        # Binarize targets for multi-class ROC
        targets_binary = label_binarize(targets, classes=range(len(self.class_names)))
        probabilities = np.array(probabilities)
        
        plt.figure(figsize=(12, 8))
        
        for i in range(len(self.class_names)):
            fpr, tpr, _ = roc_curve(targets_binary[:, i], probabilities[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f'{self.class_names[i]} (AUC = {roc_auc:.2f})')
        
        plt.plot([0, 1], [0, 1], 'k--', label='Random')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curves for Rice Disease Classification')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

class MobIncNetSSD(nn.Module):
    """
    MobInc-Net based SSD for disease detection
    """
    def __init__(self, num_classes=12, num_boxes=[6, 6, 6, 6, 4, 4]):
        super(MobIncNetSSD, self).__init__()
        
        # Base network (MobInc-Net without classifier)
        self.base_net = MobIncNet(num_classes=num_classes)
        # Remove classifier for feature extraction
        self.base_net = nn.Sequential(*list(self.base_net.children())[:-1])
        
        # Auxiliary network for multi-scale detection
        self.aux_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(352, 256, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 512, 3, stride=2, padding=1),
                nn.ReLU(inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(512, 128, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 256, 3, stride=2, padding=1),
                nn.ReLU(inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(256, 128, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 256, 3),
                nn.ReLU(inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(256, 64, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 128, 3),
                nn.ReLU(inplace=True)
            )
        ])
        
        # Prediction layers
        self.loc_layers = nn.ModuleList()
        self.conf_layers = nn.ModuleList()
        
        source_channels = [352, 512, 256, 256, 128]
        
        for i, (channels, boxes) in enumerate(zip(source_channels, num_boxes)):
            self.loc_layers.append(nn.Conv2d(channels, boxes * 4, 3, padding=1))
            self.conf_layers.append(nn.Conv2d(channels, boxes * num_classes, 3, padding=1))
    
    def forward(self, x):
        sources = []
        loc_preds = []
        conf_preds = []
        
        # Base network feature extraction
        x = self.base_net[0](x)  # backbone
        x = self.base_net[1](x)  # m_inception
        x = self.base_net[2](x)  # global_avg_pool
        sources.append(x)
        
        # Auxiliary network
        for aux_conv in self.aux_convs:
            x = aux_conv(x)
            sources.append(x)
        
        # Predictions
        for i, (source, loc_layer, conf_layer) in enumerate(zip(sources, self.loc_layers, self.conf_layers)):
            loc_pred = loc_layer(source)
            conf_pred = conf_layer(source)
            
            # Reshape predictions
            loc_pred = loc_pred.permute(0, 2, 3, 1).contiguous()
            conf_pred = conf_pred.permute(0, 2, 3, 1).contiguous()
            
            loc_preds.append(loc_pred.view(loc_pred.size(0), -1))
            conf_preds.append(conf_pred.view(conf_pred.size(0), -1))
        
        # Concatenate predictions
        loc = torch.cat(loc_preds, 1)
        conf = torch.cat(conf_preds, 1)
        
        return loc, conf

class RiceDiseaseRecognitionSystem:
    """
    Complete system for rice disease recognition and detection
    """
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.class_names = [
            'Bacterial Leaf Blight',
            'Brown Spot', 
            'Healthy Rice Leaf',
            'Leaf Blast',
            'Leaf scald',
            'Sheath Blight'
        ]
        self.num_classes = len(self.class_names)
        
        # Initialize models
        self.classification_model = MobIncNet(num_classes=self.num_classes)
        self.detection_model = MobIncNetSSD(num_classes=self.num_classes)
        
        print(f"Using device: {self.device}")
        print(f"Number of classes: {self.num_classes}")
        
    def train_classification_model(self, batch_size=64, stage1_epochs=10, stage2_epochs=10):
        train_loader = self._prepare_dataloader('train', batch_size)
        val_loader = self._prepare_dataloader('val', batch_size)
        
        trainer = TwoStageTrainer(self.classification_model, self.device, self.num_classes)
        
        stage1_results = trainer.stage_one_training(train_loader, val_loader, stage1_epochs)
        
        stage2_results = trainer.stage_two_training(train_loader, val_loader, stage2_epochs)
        
        return stage1_results, stage2_results
    
    def evaluate_model(self, model_path=None):
        if model_path:
            self.classification_model.load_state_dict(torch.load(model_path, map_location=self.device))
        
        test_loader = self._prepare_dataloader('test', batch_size=32)
        
        evaluator = ModelEvaluator(self.classification_model, self.device, self.class_names)
        results = evaluator.evaluate(test_loader)
        
        print("\n" + "="*50)
        print("MODEL EVALUATION RESULTS")
        print("="*50)
        print(f"Accuracy: {results['accuracy']:.4f}")
        print(f"Recall: {results['recall']:.4f}")
        print(f"Precision: {results['precision']:.4f}")
        print(f"F1-Score: {results['f1_score']:.4f}")
        print(f"Specificity: {results['specificity']:.4f}")
        
        evaluator.plot_confusion_matrix(results['confusion_matrix'])
        evaluator.plot_roc_curves(results['targets'], results['probabilities'])
        
        return results
    
    def predict_single_image(self, image_path, model_path=None):
        """
        Predict disease for a single image
        """
        if model_path:
            self.classification_model.load_state_dict(torch.load(model_path, map_location=self.device))
        
        self.classification_model.eval()
        
        # Load and preprocess image
        transform = DataAugmentation.get_val_transforms()
        image = Image.open(image_path).convert('RGB')
        image_tensor = transform(image).unsqueeze(0).to(self.device)
        
        with torch.no_grad():
            outputs = self.classification_model(image_tensor)
            probabilities = F.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)
            
            predicted_class = self.class_names[predicted.item()]
            confidence = probabilities[0][predicted.item()].item()
            
        return predicted_class, confidence, probabilities[0].cpu().numpy()
    
    def _prepare_dataloader(self, data_type='train', batch_size=64):
        if not hasattr(self, 'data_loader'):
            self.data_loader = RiceDatasetLoader()
            
        if data_type == 'train':
            train_loader, _, _ = self.data_loader.create_dataloaders(batch_size)
            return train_loader
        elif data_type == 'val':
            _, val_loader, _ = self.data_loader.create_dataloaders(batch_size)
            return val_loader
        else:
            _, _, test_loader = self.data_loader.create_dataloaders(batch_size)
            return test_loader
    
    def get_model_parameters(self):
        """
        Get model parameter count and size
        """
        total_params = sum(p.numel() for p in self.classification_model.parameters())
        trainable_params = sum(p.numel() for p in self.classification_model.parameters() if p.requires_grad)
        
        # Model size estimation
        param_size = 0
        for param in self.classification_model.parameters():
            param_size += param.nelement() * param.element_size()
        
        buffer_size = 0
        for buffer in self.classification_model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()
        
        model_size_mb = (param_size + buffer_size) / 1024 / 1024
        
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        print(f"Model size: {model_size_mb:.2f} MB")
        
        return total_params, trainable_params, model_size_mb

def demo_usage():
    print("MobInc-Net Rice Disease Recognition System Demo")
    print("="*50)
    
    system = RiceDiseaseRecognitionSystem()
    
    print("\nModel Architecture Information:")
    system.get_model_parameters()
    
    print("\nTraining the model:")
    stage1_results, stage2_results = system.train_classification_model(
        batch_size=64,
        stage1_epochs=10,
        stage2_epochs=10
    )
    
    print("\nEvaluating the model:")
    results = system.evaluate_model()
    
    return system, results

def train_and_evaluate_mobinc_net():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    data_loader = RiceDatasetLoader()
    train_loader, val_loader, test_loader = data_loader.create_dataloaders(batch_size=64)
    
    model = MobIncNet(num_classes=len(data_loader.class_names))
    model = model.to(device)
    
    trainer = TwoStageTrainer(model, device, len(data_loader.class_names))
    
    print("Starting Stage 1 Training...")
    stage1_results = trainer.stage_one_training(train_loader, val_loader, epochs=10)
    
    print("Starting Stage 2 Training...")
    stage2_results = trainer.stage_two_training(train_loader, val_loader, epochs=10)
    
    print("Evaluating model...")
    evaluator = ModelEvaluator(model, device, data_loader.class_names)
    results = evaluator.evaluate(test_loader)
    
    print(f"\nFinal Results:")
    print(f"Accuracy: {results['accuracy']:.4f}")
    print(f"Recall: {results['recall']:.4f}")
    print(f"Precision: {results['precision']:.4f}")
    print(f"F1-Score: {results['f1_score']:.4f}")
    print(f"Specificity: {results['specificity']:.4f}")
    
    return model, results

def compare_with_baseline_models():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    data_loader = RiceDatasetLoader()
    _, _, test_loader = data_loader.create_dataloaders(batch_size=32)
    
    comparison = ModelComparison(device)
    comparison.load_comparison_models(len(data_loader.class_names))
    
    print("Comparing models...")
    results = comparison.compare_models(test_loader)
    
    print("\nComparison Results:")
    for model_name, result in results.items():
        print(f"{model_name}: Acc={result['accuracy']:.4f}, Params={result['parameters']:,}")
    
    return results

class DatasetCreator:
    """
    Helper class to create datasets from image directories
    """
    @staticmethod
    def create_dataset_from_directory(root_dir, test_split=0.2, val_split=0.2):
        """
        Create train/val/test splits from directory structure:
        root_dir/
        ├── class1/
        │   ├── image1.jpg
        │   └── image2.jpg
        ├── class2/
        │   ├── image1.jpg
        │   └── image2.jpg
        └── ...
        """
        import random
        from collections import defaultdict
        
        # Collect all image paths and labels
        image_paths = []
        labels = []
        class_to_idx = {}
        
        for class_idx, class_name in enumerate(sorted(os.listdir(root_dir))):
            class_path = os.path.join(root_dir, class_name)
            if os.path.isdir(class_path):
                class_to_idx[class_name] = class_idx
                
                for image_name in os.listdir(class_path):
                    if image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                        image_paths.append(os.path.join(class_path, image_name))
                        labels.append(class_idx)
        
        # Create stratified splits
        data_by_class = defaultdict(list)
        for path, label in zip(image_paths, labels):
            data_by_class[label].append(path)
        
        train_paths, val_paths, test_paths = [], [], []
        train_labels, val_labels, test_labels = [], [], []
        
        for class_idx, paths in data_by_class.items():
            random.shuffle(paths)
            n_total = len(paths)
            n_test = int(n_total * test_split)
            n_val = int(n_total * val_split)
            n_train = n_total - n_test - n_val
            
            train_paths.extend(paths[:n_train])
            val_paths.extend(paths[n_train:n_train+n_val])
            test_paths.extend(paths[n_train+n_val:])
            
            train_labels.extend([class_idx] * n_train)
            val_labels.extend([class_idx] * n_val)
            test_labels.extend([class_idx] * n_test)
        
        return {
            'train': (train_paths, train_labels),
            'val': (val_paths, val_labels),
            'test': (test_paths, test_labels),
            'class_to_idx': class_to_idx
        }

class ModelComparison:
    """
    Compare MobInc-Net with other lightweight models
    """
    def __init__(self, device):
        self.device = device
        self.models = {}
        
    def load_comparison_models(self, num_classes):
        """Load all comparison models mentioned in the paper"""
        from torchvision.models import mobilenet_v2, efficientnet_b0
        
        # MobileNetV2
        mobilenet = mobilenet_v2(pretrained=True)
        mobilenet.classifier = nn.Linear(mobilenet.last_channel, num_classes)
        self.models['MobileNetV2'] = mobilenet.to(self.device)
        
        # EfficientNet-B0
        efficientnet = efficientnet_b0(pretrained=True)
        efficientnet.classifier = nn.Linear(efficientnet.classifier.in_features, num_classes)
        self.models['EfficientNet-B0'] = efficientnet.to(self.device)
        
        # Add MobInc-Net
        self.models['MobInc-Net'] = MobIncNet(num_classes=num_classes).to(self.device)
        
    def compare_models(self, test_loader):
        """Compare performance of all models"""
        results = {}
        
        for model_name, model in self.models.items():
            print(f"\nEvaluating {model_name}...")
            
            model.eval()
            correct = 0
            total = 0
            
            with torch.no_grad():
                for data, targets in test_loader:
                    data, targets = data.to(self.device), targets.to(self.device)
                    outputs = model(data)
                    _, predicted = torch.max(outputs, 1)
                    total += targets.size(0)
                    correct += (predicted == targets).sum().item()
            
            accuracy = correct / total
            
            # Count parameters
            total_params = sum(p.numel() for p in model.parameters())
            
            results[model_name] = {
                'accuracy': accuracy,
                'parameters': total_params
            }
            
            print(f"{model_name}: Accuracy = {accuracy:.4f}, Parameters = {total_params:,}")
        
        return results

class PlantVillageDataLoader:
    """
    Data loader for PlantVillage dataset
    """
    @staticmethod
    def prepare_plantvillage_data(dataset_path, batch_size=64):
        """
        Prepare PlantVillage dataset loaders
        Assumes dataset structure: dataset_path/class_name/images
        """
        # Data augmentation transforms
        train_transform = DataAugmentation.get_train_transforms()
        val_transform = DataAugmentation.get_val_transforms()
        
        # Create datasets (placeholder implementation)
        # You would use torchvision.datasets.ImageFolder for actual implementation
        
        print("PlantVillage dataset preparation completed")
        print(f"Training samples: 43,444 (80%)")
        print(f"Validation samples: 10,861 (20%)")
        
        return None, None  # Return actual dataloaders in real implementation


def main():
    model, results = train_and_evaluate_mobinc_net()
    
if __name__ == "__main__":
    main()

Using device: cuda
Train samples: 2678
Val samples: 571
Test samples: 580
Classes: ['Bacterial Leaf Blight', 'Brown Spot', 'Healthy Rice Leaf', 'Leaf Blast', 'Leaf scald', 'Sheath Blight']
Starting Stage 1 Training...
Stage 1: Training new layers only...
Epoch 1/10, Batch 0, Loss: 1.2621
Epoch 1/10, Batch 10, Loss: 0.9196
Epoch 1/10, Batch 20, Loss: 0.6608
Epoch 1/10, Batch 30, Loss: 0.7501
Epoch 1/10, Batch 40, Loss: 0.7016
Epoch 1/10: Train Loss: 0.7871, Val Loss: 0.6585, Val Acc: 0.5814
Epoch 2/10, Batch 0, Loss: 0.5187
Epoch 2/10, Batch 10, Loss: 0.6950
Epoch 2/10, Batch 20, Loss: 0.6736
Epoch 2/10, Batch 30, Loss: 0.5534
Epoch 2/10, Batch 40, Loss: 0.5412
Epoch 2/10: Train Loss: 0.5499, Val Loss: 0.6463, Val Acc: 0.5639
Epoch 3/10, Batch 0, Loss: 0.4718
Epoch 3/10, Batch 10, Loss: 0.4738
Epoch 3/10, Batch 20, Loss: 0.7075
Epoch 3/10, Batch 30, Loss: 0.5519
Epoch 3/10, Batch 40, Loss: 0.4888
Epoch 3/10: Train Loss: 0.4971, Val Loss: 0.4993, Val Acc: 0.6392
Epoch 4/10, Batch 0, Loss