In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset
from sklearn.model_selection import train_test_split
import os
import shutil
from math import ceil
from PIL import Image
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
import json

In [2]:
# Base EfficientNet architecture details
base_model = [
    [1, 16, 1, 1, 3],
    [6, 24, 2, 2, 3],
    [6, 40, 2, 2, 5],
    [6, 80, 3, 2, 3],
    [6, 112, 3, 1, 5],
    [6, 192, 4, 2, 5],
    [6, 320, 1, 1, 3],
]

phi_values = {
    "b0": (0, 224, 0.2),
    "b1": (0.5, 240, 0.2),
    "b2": (1, 260, 0.3),
    "b3": (2, 300, 0.3),
    "b4": (3, 380, 0.4),
    "b5": (4, 456, 0.4),
    "b6": (5, 528, 0.5),
    "b7": (6, 600, 0.5),
}


In [3]:
# EfficientNet-related building blocks
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups=1):
        super(CNNBlock, self).__init__()
        self.cnn = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.silu = nn.SiLU()

    def forward(self, x):
        return self.silu(self.bn(self.cnn(x)))


class SqueezeExcitation(nn.Module):
    def __init__(self, in_channels, reduced_dim):
        super(SqueezeExcitation, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, reduced_dim, 1),
            nn.SiLU(),
            nn.Conv2d(reduced_dim, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.se(x)

In [4]:
class CBAM(nn.Module):
    def __init__(self, channels, reduction=16):
        super(CBAM, self).__init__()
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, channels, kernel_size=1, bias=False),
            nn.Sigmoid(),
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        ca = self.channel_attention(x)
        x = x * ca

        max_pool = torch.max(x, dim=1, keepdim=True)[0]
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        sa = torch.cat([max_pool, avg_pool], dim=1)
        sa = self.spatial_attention(sa)
        x = x * sa

        return x

In [5]:
class InvertedResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, expand_ratio, reduction=4, survival_prob=0.8):
        super(InvertedResidualBlock, self).__init__()
        self.survival_prob = survival_prob
        self.use_residual = in_channels == out_channels and stride == 1
        self.hidden_dim = in_channels * expand_ratio
        self.expand = expand_ratio != 1

        if self.expand:
            self.expand_conv = CNNBlock(in_channels, self.hidden_dim, kernel_size=3, stride=1, padding=1)

        self.conv = nn.Sequential(
            CNNBlock(self.hidden_dim, self.hidden_dim, kernel_size, stride, padding, groups=self.hidden_dim),
            SqueezeExcitation(self.hidden_dim, reduced_dim=int(self.hidden_dim / reduction)),
            nn.Conv2d(self.hidden_dim, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        self.cbam = CBAM(out_channels)

    def stochastic_depth(self, x):
        if not self.training:
            return x
        binary_tensor = (
            torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob
        )
        return torch.div(x, self.survival_prob) * binary_tensor

    def forward(self, inputs):
        x = self.expand_conv(inputs) if self.expand else inputs
        x = self.conv(x)
        x = self.cbam(x)
        if self.use_residual:
            return self.stochastic_depth(x) + inputs
        else:
            return x

In [6]:
class EfficientNet(nn.Module):
    def __init__(self, version, num_classes):
        super(EfficientNet, self).__init__()
        width_factor, depth_factor, dropout_rate = self.calculate_factors(version)
        last_channels = ceil(1280 * width_factor)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.features = self.create_features(width_factor, depth_factor, last_channels)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(last_channels, num_classes),
        )

    def calculate_factors(self, version, alpha=1.2, beta=1.1):
        phi, res, drop_rate = phi_values[version]
        depth_factor = alpha**phi
        width_factor = beta**phi
        return width_factor, depth_factor, drop_rate

    def create_features(self, width_factor, depth_factor, last_channels):
        channels = int(32 * width_factor)
        features = [CNNBlock(3, channels, 3, stride=2, padding=1)]
        in_channels = channels

        for expand_ratio, channels, repeats, stride, kernel_size in base_model:
            out_channels = 4 * ceil(int(channels * width_factor) / 4)
            layers_repeats = ceil(repeats * depth_factor)

            for layer in range(layers_repeats):
                features.append(
                    InvertedResidualBlock(
                        in_channels,
                        out_channels,
                        expand_ratio=expand_ratio,
                        stride=stride if layer == 0 else 1,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                    )
                )
                in_channels = out_channels

        features.append(
            CNNBlock(in_channels, last_channels, kernel_size=1, stride=1, padding=0)
        )

        return nn.Sequential(*features)

    def forward(self, x):
        x = self.pool(self.features(x))
        return self.classifier(x.view(x.shape[0], -1))

In [7]:
class CropDiseaseDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.crop_to_idx = {}  # Crop-to-index mapping
        self.disease_to_idx = {}


        

        # Load image paths and labels
        for crop_idx, crop_class in enumerate(os.listdir(root_dir)):  
            crop_path = os.path.join(root_dir, crop_class)
            if os.path.isdir(crop_path):
                for disease_idx, disease_class in enumerate(os.listdir(crop_path)):
                    disease_path = os.path.join(crop_path, disease_class)
                    if os.path.isdir(disease_path):
                        for image_name in os.listdir(disease_path):
                            image_path = os.path.join(disease_path, image_name)
                            self.image_paths.append(image_path)
                            self.labels.append((crop_idx, disease_idx))

        print(self.labels)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        crop_label, disease_label = self.labels[idx]

        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, (crop_label, disease_label)
    
    # def _load_image(self, img_path):
    #     from PIL import Image
    #     return Image.open(img_path).convert("RGB")

In [8]:
def process_single_image(image_path):
    """Process a single image and ensure it's valid."""
    try:
        with Image.open(image_path) as img:
            img.verify()
        with Image.open(image_path) as img:
            img = img.convert('RGB')
            img.save(image_path)
        return True
    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        return False


def remove_corrupted_images(directory):
    """Remove corrupted images and ensure all images are in RGB format."""
    removed_count = 0
    processed_count = 0

    for root, _, files in os.walk(directory):
        for file in files:
            if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                file_path = os.path.join(root, file)
                try:
                    if process_single_image(file_path):
                        processed_count += 1
                    else:
                        os.remove(file_path)
                        removed_count += 1
                except Exception as e:
                    print(f"Failed to process {file_path}: {e}")
                    try:
                        os.remove(file_path)
                        removed_count += 1
                    except:
                        print(f"Failed to remove corrupted file: {file_path}")

    print(f"Processed {processed_count} images")
    print(f"Removed {removed_count} corrupted images")

In [9]:
def create_train_test_split(data_dir, output_dir, test_size=0.2):
    """Create train-test split with enhanced error handling."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    train_dir = os.path.join(output_dir, "train")
    test_dir = os.path.join(output_dir, "test")

    for crop_type in os.listdir(data_dir):
        crop_path = os.path.join(data_dir, crop_type)
        if not os.path.isdir(crop_path):
            continue

        for disease in os.listdir(crop_path):
            try:
                disease_dir = os.path.join(crop_path, disease)
                if not os.path.isdir(disease_dir):
                    continue

                os.makedirs(os.path.join(train_dir, crop_type, disease), exist_ok=True)
                os.makedirs(os.path.join(test_dir, crop_type, disease), exist_ok=True)

                valid_images = []
                for img in os.listdir(disease_dir):
                    if img.lower().endswith(('.jpg', '.jpeg', '.png')):
                        img_path = os.path.join(disease_dir, img)
                        try:
                            with Image.open(img_path) as im:
                                im = im.convert('RGB')
                                valid_images.append(img_path)
                        except Exception as e:
                            print(f"Skipping invalid image: {img_path}, Error: {e}")

                if not valid_images:
                    print(f"No valid images found in {disease_dir}")
                    continue

                train_images, test_images = train_test_split(valid_images, test_size=test_size, random_state=42)

                for img in train_images:
                    try:
                        shutil.copy2(img, os.path.join(train_dir, crop_type, disease))
                    except Exception as e:
                        print(f"Error copying {img}: {e}")

                for img in test_images:
                    try:
                        shutil.copy2(img, os.path.join(test_dir, crop_type, disease))
                    except Exception as e:
                        print(f"Error copying {img}: {e}")

            except Exception as e:
                print(f"Error processing disease {disease}: {e}")
                continue

In [10]:
# def train_efficientnet(data_dir, version="b3", batch_size=16, epochs=10):
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     print(f"Using device: {device}")

#     transform = transforms.Compose([
#         transforms.Resize((phi_values[version][1], phi_values[version][1])),  # Smaller input size
#         transforms.RandomHorizontalFlip(),
#         transforms.RandomVerticalFlip(),
#         transforms.RandomRotation(30),
#         transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
#         transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
#         transforms.ToTensor(),
#         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
#     ])

#     try:
#         # Load the dataset
#         train_dataset = CropDiseaseDataset(root_dir=os.path.join(data_dir, "train"), transform=transform)
#         test_dataset = CropDiseaseDataset(root_dir=os.path.join(data_dir, "test"), transform=transform)

#         if len(train_dataset) == 0 or len(test_dataset) == 0:
#             raise ValueError("Empty datasets detected")

#         print(f"Found {len(train_dataset)} training images and {len(test_dataset)} test images")

#         # Save crop and disease mappings
#         with open('crop_to_idx.json', 'w') as f:
#             json.dump(train_dataset.crop_to_idx, f)
#         with open('disease_to_idx.json', 'w') as f:
#             json.dump(train_dataset.disease_to_idx, f)
#         with open('idx_to_crop.json', 'w') as f:
#             json.dump(train_dataset.idx_to_crop, f)
#         with open('idx_to_disease.json', 'w') as f:
#             json.dump(train_dataset.idx_to_disease, f)

#         # Handle class imbalance
#         class_counts = torch.bincount(torch.tensor([disease_idx for _, (_, disease_idx) in train_dataset.samples]))
#         class_weights = 1. / class_counts
#         samples_weights = class_weights[[disease_idx for _, (_, disease_idx) in train_dataset.samples]]
#         sampler = WeightedRandomSampler(samples_weights, len(samples_weights))

#         train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4, pin_memory=True)
#         test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

#         # Set num_classes to the number of unique diseases
#         num_classes = len(train_dataset.disease_to_idx)

#         # Load pre-trained EfficientNet
#         model = EfficientNet(version=version, num_classes=num_classes)
#         model.classifier = nn.Sequential(
#             nn.Dropout(0.5),
#             nn.Linear(model.classifier[1].in_features, num_classes),  # Predict diseases
#         )
#         model = model.to(device)

#         criterion = nn.CrossEntropyLoss()
#         optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
#         scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

#         scaler = torch.amp.GradScaler('cuda')

#         for epoch in range(epochs):
#             model.train()
#             running_loss = 0.0
#             for i, (inputs, (crop_labels, disease_labels)) in enumerate(train_loader):
#                 inputs, disease_labels = inputs.to(device), disease_labels.to(device)

#                 optimizer.zero_grad()

#                 with torch.amp.autocast('cuda'):
#                     outputs = model(inputs)
#                     loss = criterion(outputs, disease_labels)

#                 scaler.scale(loss).backward()
#                 scaler.step(optimizer)
#                 scaler.update()

#                 running_loss += loss.item()

#                 if i % 10 == 0:  # Print more frequently
#                     print(f"Epoch [{epoch + 1}/{epochs}], Batch [{i}/{len(train_loader)}], Loss: {loss.item():.4f}")

#             epoch_loss = running_loss / len(train_loader)
#             print(f"Epoch [{epoch + 1}/{epochs}], Average Loss: {epoch_loss:.4f}")
#             scheduler.step()

#         model.eval()
#         correct = 0
#         total = 0
#         with torch.no_grad():
#             for inputs, (crop_labels, disease_labels) in test_loader:
#                 inputs, disease_labels = inputs.to(device), disease_labels.to(device)
#                 outputs = model(inputs)
#                 _, predicted = torch.max(outputs.data, 1)
#                 total += disease_labels.size(0)
#                 correct += (predicted == disease_labels).sum().item()

#         accuracy = 100 * correct / total
#         print(f"Test Accuracy: {accuracy:.2f}%")

#         # Save the model
#         torch.save(model.state_dict(), 'efficientnet_cbam_model.pth')

#     except Exception as e:
#         print(f"Training failed: {e}")
#         raise

In [11]:
def train_efficientnet(data_dir, version="b1", batch_size=16, epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    transform = transforms.Compose([
        transforms.Resize((phi_values[version][1], phi_values[version][1])),  # Smaller input size
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(30),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    try:
        # Load the dataset
        train_dataset = CropDiseaseDataset(root_dir=os.path.join(data_dir, "train"), transform=transform)
        test_dataset = CropDiseaseDataset(root_dir=os.path.join(data_dir, "test"), transform=transform)

        if len(train_dataset) == 0 or len(test_dataset) == 0:
            raise ValueError("Empty datasets detected")

        print(f"Found {len(train_dataset)} training images and {len(test_dataset)} test images")

        # Save crop and disease mappings
        with open('crop_to_idx.json', 'w') as f:
            json.dump(train_dataset, f,indent =4)
        with open('disease_to_idx.json', 'w') as f:
            json.dump(train_dataset.disease_to_idx, f)
        with open('idx_to_crop.json', 'w') as f:
            json.dump(train_dataset.idx_to_crop, f)
        with open('idx_to_disease.json', 'w') as f:
            json.dump(train_dataset.idx_to_disease, f)

        # Handle class imbalance
        class_counts = torch.bincount(torch.tensor([disease_idx for _, (_, disease_idx) in train_dataset.samples]))
        class_weights = 1. / class_counts
        samples_weights = class_weights[[disease_idx for _, (_, disease_idx) in train_dataset.samples]]
        sampler = WeightedRandomSampler(samples_weights, len(samples_weights))

        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4, pin_memory=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

        # Load pre-trained EfficientNet
        num_classes = len(train_dataset.disease_to_idx)  # Set num_classes based on the dataset
        model = EfficientNet(version=version, num_classes=num_classes)
        model.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(model.classifier[1].in_features, num_classes),  # Predict diseases
        )
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
        scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

        scaler = torch.amp.GradScaler(device_type'cuda')

        for epoch in range(epochs):
            model.train()
            running_loss = 0.0
            for i, (inputs, (crop_labels, disease_labels)) in enumerate(train_loader):
                inputs, disease_labels = inputs.to(device), disease_labels.to(device)

                optimizer.zero_grad()

                with torch.amp.autocast(device_type='cuda'):
                    outputs = model(inputs)
                    loss = criterion(outputs, disease_labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                running_loss += loss.item()

                if i % 10 == 0:  # Print more frequently
                    print(f"Epoch [{epoch + 1}/{epochs}], Batch [{i}/{len(train_loader)}], Loss: {loss.item():.4f}")

            epoch_loss = running_loss / len(train_loader)
            print(f"Epoch [{epoch + 1}/{epochs}], Average Loss: {epoch_loss:.4f}")
            scheduler.step()

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, (crop_labels, disease_labels) in test_loader:
                inputs, disease_labels = inputs.to(device), disease_labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += disease_labels.size(0)
                correct += (predicted == disease_labels).sum().item()

        accuracy = 100 * correct / total
        print(f"Test Accuracy: {accuracy:.2f}%")

        # Save the model
        torch.save(model.state_dict(), 'efficientnet_cbam_model.pth')

    except Exception as e:
        print(f"Training failed: {e}")
        raise

In [12]:
def predict(image_path, model_path='efficientnet_cbam_model.pth', class_to_idx_path='class_to_idx.json', version="b3"):
    """Perform inference on a single image."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load class-to-index mapping
    with open(class_to_idx_path, 'r') as f:
        class_to_idx = json.load(f)
    idx_to_class = {v: k for k, v in class_to_idx.items()}

    # Load the model
    model = EfficientNet(version=version, num_classes=len(class_to_idx)).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Preprocess the image
    transform = transforms.Compose([
        transforms.Resize((phi_values[version][1], phi_values[version][1])),  # Adjust based on your model's input size
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    # Perform inference
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output, 1)
        predicted_class = idx_to_class[predicted.item()]

    return predicted_class

In [13]:
# Main execution
raw_data_dir = "Raw_Data"
processed_data_dir = "Processed_Data"

# # Preprocess data
# remove_corrupted_images(raw_data_dir)
# create_train_test_split(raw_data_dir, processed_data_dir)
# remove_corrupted_images(processed_data_dir)

#Train the model
train_efficientnet(processed_data_dir)

Using device: cuda
[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 

TypeError: Object of type CropDiseaseDataset is not JSON serializable

In [38]:
# import os
# import json

# def create_class_to_index_json(data_dir, output_file="class_to_index.json"):
#     train_dir = os.path.join(data_dir, "train")  # Use train set for labels
#     class_names = sorted(os.listdir(train_dir))  # Sort for consistency
#     class_to_index = {class_name: idx for idx, class_name in enumerate(class_names)}

#     with open(output_file, "w") as f:
#         json.dump(class_to_index, f, indent=4)

#     print(f"Class-to-index mapping saved to {output_file}")

# # Example usage:
# create_class_to_index_json("Processed_Data")

In [23]:
#Example inference
# image_path = "healthy3_.jpg"
# predicted_class = predict(image_path)
# print(f"Predicted class: {predicted_class}")

  model.load_state_dict(torch.load(model_path, map_location=device))


Predicted class: Cassava


In [42]:
#Clearing cache

# import torch
# import gc

# torch.cuda.empty_cache()  # Clears unused memory from the cache
# torch.cuda.ipc_collect()
# gc.collect()  # Force garbage collection
# torch.cuda.empty_cache()