In [25]:
from utils.load_dataset import PlantVillageDataset
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.preprocessing import preprocessing_img
import torchvision.models  as models
from torchvision import transforms
import torch.optim as optim
from test import LinearHeadModel

In [35]:
# Load the dataset
PATH = './Plant_leave_diseases_dataset_without_augmentation'
transform = transforms.Compose([
    preprocessing_img,
    transforms.ToTensor(),
])
training_data = PlantVillageDataset(PATH, img_mode="LAB", train=True, transform=transform)
test_data = PlantVillageDataset(PATH, img_mode="LAB", train=False, transform=transform)

In [36]:
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

In [30]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

## Import pre-trained models

In [4]:
# ResNet50
resnet50 = models.resnet50(weights='ResNet50_Weights.DEFAULT')
for param in resnet50.parameters():
    param.requires_grad = False

# DenseNet201
densenet = models.densenet201(weights=models.DenseNet201_Weights.DEFAULT)
for param in densenet.parameters():
    param.requires_grad = False

# EfficientNetB0
efficientnet_b0 = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
for param in efficientnet_b0.parameters():
    param.requires_grad = False

# EfficientNetB3
efficientnet_b3 = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1)
for param in efficientnet_b3.parameters():
    param.requires_grad = False

## Import heads for each pre-trained models

In [5]:
# Load model from file
def load_model_from_file(model_path, input_dim, output_dim):
    model = LinearHeadModel(input_dim, output_dim)
    model.load_state_dict(torch.load(model_path))
    return model

In [6]:
input_dim = 224 * 224 * 3
num_classes = 38
output_dim = num_classes + 1

In [11]:
resnet50_head = load_model_from_file('models/resnet50.pth', input_dim, output_dim)
densenet201_head = load_model_from_file('models/densenet201.pth', input_dim, output_dim)
efficientnet_b0_head = load_model_from_file('models/efficientnet_b0.pth', input_dim, output_dim)
efficientnet_b3_head = load_model_from_file('models/efficientnet_b3.pth', input_dim, output_dim)

  model.load_state_dict(torch.load(model_path))


In [58]:
# Modify head of pre-trained models
resnet50.fc = resnet50_head
densenet.fc = densenet201_head
efficientnet_b0.fc = efficientnet_b0_head
efficientnet_b3.fc = efficientnet_b3_head

In [59]:
resnet50.to(device)
densenet.to(device)
efficientnet_b0.to(device)
efficientnet_b3.to(device)

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False)
            (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(40, 10, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(10, 40, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActiv

In [14]:
models_list = [resnet50, densenet, efficientnet_b0, efficientnet_b3]

In [52]:
class WeightedEnsemble(nn.Module):
    def __init__(self, models, num_classes):
        super(WeightedEnsemble, self).__init__()
        self.models = models  # List of models
        self.num_models = len(models)  # Number of models in the ensemble
        self.weights = nn.Parameter(torch.ones(self.num_models) / self.num_models)  # Trainable weights
        self.num_classes = num_classes  # Number of classes

    def forward(self, input_tensor):
        all_preds = []
        for model in self.models:
            model.eval()  # Ensure the model is in evaluation mode
            with torch.no_grad():
                output = model(input_tensor)  # Forward pass through the model
                prob = nn.functional.softmax(output, dim=1)  # Convert logits to probabilities
                all_preds.append(prob)

        # Stack predictions along a new dimension
        all_preds = torch.stack(all_preds)  # Shape: (num_models, batch_size, num_classes)

        # Apply softmax to normalize weights and combine predictions
        normalized_weights = nn.functional.softmax(self.weights, dim=0)  # Shape: (num_models,)
        weighted_preds = torch.sum(all_preds * normalized_weights.view(-1, 1, 1), dim=0)  # Weighted sum
        print('weighted_preds', weighted_preds.shape)
        return weighted_preds  # Shape: (batch_size, num_classes)


In [54]:
def train_ensemble_weights(ensemble, train_loader, criterion, optimizer, epochs=10, device='mps'):
    ensemble.to(device)
    
    for epoch in range(epochs):
        ensemble.train()  # Set ensemble to training mode
        running_loss = 0.0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()

            # Forward pass through the ensemble
            print('inputs',inputs.shape)
            outputs = ensemble(inputs)
            print("Ensemble output shape:", outputs.shape)  # Debug output shape

            # Compute the loss
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

    print("Ensemble weight training complete!")


In [61]:
ensemble = WeightedEnsemble(models_list, output_dim)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam([ensemble.weights], lr=0.01)

train_ensemble_weights(ensemble, train_dataloader, criterion, optimizer, epochs=10)

inputs torch.Size([64, 3, 224, 224])


ValueError: Input dimension should be at least 3