In [39]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from efficientnet_pytorch import EfficientNet
import os
from tqdm import tqdm
import json
from PIL import Image

### CBAM (Attention Layer)

In [40]:
# Define CBAM Layer
class CBAMLayer(nn.Module):
    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super(CBAMLayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
            nn.Sigmoid()
        )
        self.spatial_conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
        self.spatial_bn = nn.BatchNorm2d(1)
        self.spatial_sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_out = self.avg_pool(x).view(b, c)
        max_out = self.max_pool(x).view(b, c)
        channel_att = self.fc(avg_out) + self.fc(max_out)
        channel_att = channel_att.view(b, c, 1, 1)
        x = x * channel_att.expand_as(x)
        
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        spatial_att = self.spatial_conv(torch.cat([avg_out, max_out], dim=1))
        spatial_att = self.spatial_bn(spatial_att)
        spatial_att = self.spatial_sigmoid(spatial_att)
        return x * spatial_att


### Efficient Net 

In [41]:
# Define EfficientNet with CBAM
class EfficientNetCBAM(nn.Module):
    def __init__(self, version='b0', num_classes=10):
        super(EfficientNetCBAM, self).__init__()
        self.efficientnet = EfficientNet.from_pretrained(f'efficientnet-{version}')
        
        # Adding CBAM to specific layers
        self.cbam1 = CBAMLayer(in_channels=24)  # Example early layer
        self.cbam2 = CBAMLayer(in_channels=112) # Deeper layer
        
        num_ftrs = self.efficientnet._fc.in_features
        self.efficientnet._fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        x = self.efficientnet.extract_features(x)
        
        if x.shape[1] == 24:
            x = self.cbam1(x)
        if x.shape[1] == 112:
            x = self.cbam2(x)
            
        x = self.efficientnet._avg_pooling(x)
        x = x.flatten(start_dim=1)
        x = self.efficientnet._dropout(x)
        x = self.efficientnet._fc(x)
        return x

### Validate Model

In [42]:
# Model Validation Function
def validate_model(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    
    val_loss /= len(val_loader)
    val_acc = 100 * correct / total
    return val_loss, val_acc

### Train the Model

In [43]:
# Model Training Function
def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10):
    scaler = torch.amp.GradScaler('cuda')
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(images)
                loss = criterion(outputs, labels)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
        
        train_acc = 100 * correct / total
        val_loss, val_acc = validate_model(model, val_loader, criterion, device)
        
        print(f'Epoch [{epoch+1}/{num_epochs}] | Train Loss: {running_loss/len(train_loader):.4f} | Train Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'efficientnet_cbam_model.pth')
            print("Model saved!")

In [44]:
# Dataset and DataLoader Setup
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

dataset_path = 'Processed_Data/train'
train_dataset = datasets.ImageFolder(os.path.join(dataset_path), transform=data_transforms)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

val_dataset = datasets.ImageFolder(os.path.join('Processed_Data/test'), transform=data_transforms)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Model Initialization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(train_dataset.classes)
model = EfficientNetCBAM(version='b3', num_classes=num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Train the Model
#train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10)

# Load Best Model & Perform Prediction
model.load_state_dict(torch.load('efficientnet_cbam_model.pth'))
model.eval()

Loaded pretrained weights for efficientnet-b3


  model.load_state_dict(torch.load('efficientnet_cbam_model.pth'))


EfficientNetCBAM(
  (efficientnet): EfficientNet(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 40, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d((0, 1, 0, 1))
    )
    (_bn0): BatchNorm2d(40, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          40, 40, kernel_size=(3, 3), stride=[1, 1], groups=40, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(40, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          40, 10, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          10, 40, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStati

### Predict Function

In [49]:
# Load class labels
with open("class_names.json", "r") as f:
    class_labels = json.load(f)

def predict(image_path, crop_name, model, device):
    """
    Predicts the disease of the given image, considering only diseases from the specified crop.
    
    :param image_path: Path to the image file
    :param crop_name: The crop to filter diseases from (e.g., "Cashew")
    :param model: The trained model
    :param device: The device (CPU/GPU) for computation
    :return: Predicted crop name, disease name, and confidence score
    """

    # Define image transformations (same as training)
    transform = transforms.Compose([
        transforms.Resize((300, 300)), # Resize to EfficientNet-B3 input size
        transforms.CenterCrop(300),
        transforms.ToTensor(),          # Convert image to tensor
        transforms.Normalize(           # Normalize using ImageNet stats
            mean=[0.485, 0.456, 0.406], 
            std=[0.229, 0.224, 0.225]
        )
    ])

    # Load and preprocess the image
    try:
        image = Image.open(image_path).convert("RGB")
    except Exception as e:
        raise ValueError(f"Error loading image: {e}")

    image = transform(image).unsqueeze(0)  # Add batch dimension
    image = image.to(device)

    # Move model to device
    model.load_state_dict(torch.load("efficientnet_cbam_model.pth", map_location=device))
    model.to(device)
    model.eval() 

    # Load class names
    with open("class_names.json", "r") as f:
        class_names = json.load(f)

    # Filter disease classes for the given crop
    crop_classes = [cls for cls in class_names if cls.startswith(crop_name)]
    if not crop_classes:
        raise ValueError(f"No diseases found for crop: {crop_name}")

    # Create index mapping for this crop
    crop_indices = [class_names.index(cls) for cls in crop_classes]

    # Run inference
    with torch.no_grad():
        output = model(image)

    # Extract relevant disease classes
    crop_output = output[:, crop_indices]  # Select only disease indices related to crop
    predicted_idx = torch.argmax(crop_output, dim=1).item()

    # Map back to disease name
    predicted_disease = crop_classes[predicted_idx]
    all_probs = torch.nn.functional.softmax(output, dim=1)  # Apply softmax over all classes
    filtered_probs = all_probs[:, crop_indices]  # Select only disease classes of the crop
    confidence = filtered_probs[0][predicted_idx].item()

    # Extract only the disease name (removing crop prefix)
    disease_name = predicted_disease.split("_", 1)[1]

    print("Available Classes:", class_names)
    print("Filtered Classes for Crop:", crop_classes)

    print(f"✅ Predicted Crop: {crop_name}")
    print(f"✅ Predicted Disease: {disease_name} (Confidence: {confidence:.2f})")

    return crop_name, disease_name, confidence

In [54]:
# ==========================
# SAMPLE PREDICTION
# ==========================

# Define crop and image path
crop_name = "Cashew"  # Example crop input
image_path = "Sample Predict/cahew_anthracnose.jpg"

# Run prediction
predict(image_path, crop_name, model, device)

  model.load_state_dict(torch.load("efficientnet_cbam_model.pth", map_location=device))


Available Classes: ['Cashew_anthracnose', 'Cashew_gumosis', 'Cashew_healthy', 'Cashew_leaf miner', 'Cashew_red rust', 'Cassava_bacterial blight', 'Cassava_brown spot', 'Cassava_green mite', 'Cassava_healthy', 'Cassava_mosaic', 'Maize_fall armyworm', 'Maize_grasshoper', 'Maize_healthy', 'Maize_leaf beetle', 'Maize_leaf blight', 'Maize_leaf spot', 'Maize_streak virus', 'Tomato_healthy', 'Tomato_leaf blight', 'Tomato_leaf curl', 'Tomato_septoria leaf spot', 'Tomato_verticulium wilt']
Filtered Classes for Crop: ['Cashew_anthracnose', 'Cashew_gumosis', 'Cashew_healthy', 'Cashew_leaf miner', 'Cashew_red rust']
✅ Predicted Crop: Cashew
✅ Predicted Disease: anthracnose (Confidence: 0.19)


('Cashew', 'anthracnose', 0.1937492936849594)