In [57]:
from utils.load_dataset import PlantVillageDataset
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import torch.optim as optim 
import torch.nn as nn
from utils.preprocessing import preprocessing_img
import torchvision.models  as models
from torchvision import transforms
from test import LinearHeadModel

In [58]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [60]:
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

## Import pre-trained models

In [61]:
# ResNet50
resnet50 = models.resnet50(weights='ResNet50_Weights.DEFAULT')
for param in resnet50.parameters():
    param.requires_grad = False

# DenseNet201
densenet = models.densenet201(weights=models.DenseNet201_Weights.DEFAULT)
for param in densenet.parameters():
    param.requires_grad = False

# EfficientNetB0
efficientnet_b0 = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
for param in efficientnet_b0.parameters():
    param.requires_grad = False

# EfficientNetB3
efficientnet_b3 = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1)
for param in efficientnet_b3.parameters():
    param.requires_grad = False

## Import heads for each pre-trained models

In [62]:
# Load model from file
def load_model_from_file(model_path, input_dim, output_dim):
    model = LinearHeadModel(input_dim, output_dim)
    model.load_state_dict(torch.load(model_path))
    return model

In [63]:
input_dim = 224 * 224 * 3
num_classes = 38
output_dim = num_classes + 1

In [94]:
resnet50_head = load_model_from_file('models/resnet50.pth', input_dim, output_dim)
densenet201_head = load_model_from_file('models/densenet201.pth', input_dim, output_dim)
efficientnet_b0_head = load_model_from_file('models/efficientnet_b0.pth', input_dim, output_dim)
efficientnet_b3_head = load_model_from_file('models/efficientnet_b3.pth', input_dim, output_dim)

  model.load_state_dict(torch.load(model_path))


In [92]:
resnet50_head

LinearHeadModel(
  (batch_norm): BatchNorm1d(150528, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear_layers): Sequential(
    (0): Linear(in_features=150528, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.25, inplace=False)
    (6): Linear(in_features=128, out_features=64, bias=True)
    (7): Dropout(p=0.25, inplace=False)
    (8): Linear(in_features=64, out_features=39, bias=True)
    (9): Softmax(dim=1)
  )
)

In [65]:
# Modify head of pre-trained models
resnet50.fc = resnet50_head.to(device)
densenet.fc = densenet201_head.to(device)
efficientnet_b0.fc = efficientnet_b0_head.to(device)
efficientnet_b3.fc = efficientnet_b3_head.to(device)

In [None]:
resnet50.to(device)
densenet.to(device)
efficientnet_b0.to(device)
efficientnet_b3.to(device)

In [90]:
class WeightedEnsemble(nn.Module):
    def __init__(self, models, num_classes):
        super(WeightedEnsemble, self).__init__()
        self.models = models
        self.num_models = len(models) 
        self.weights = nn.Parameter(torch.ones(self.num_models) / self.num_models)
        self.num_classes = num_classes

    def forward(self, inputs):
        all_preds = []
        
        # Forward pass for each model in the ensemble
        for i, model in enumerate(self.models):
            model.eval()
            with torch.no_grad():
                inputs = inputs.to(device)
                outputs = model(inputs)

                # Now pass through the head (batch_norm and linear layers)
                prob = nn.functional.softmax(outputs, dim=1)  # Apply softmax to get probabilities
                all_preds.append(prob)

        # Stack all predictions from each model
        all_preds = torch.stack(all_preds)

        # Normalize the weights (softmax ensures weights sum to 1)
        normalized_weights = nn.functional.softmax(self.weights, dim=0)

        # Aggregate the weighted predictions
        weighted_preds = torch.sum(all_preds * normalized_weights.view(-1, 1, 1), dim=0)
        return weighted_preds


In [68]:
def train_ensemble_weights(ensemble, train_loader, criterion, optimizer, epochs=10):
    for epoch in range(epochs):
        ensemble.train()  # Ensemble training mode (weights can be trained)
        running_loss = 0.0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()  # Zero gradients for optimizer
            outputs = ensemble(inputs)  # Forward pass through ensemble
            loss = criterion(outputs, labels)  # Compute loss
            loss.backward()  # Backpropagate
            optimizer.step()  # Update weights
            
            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}")


In [91]:
# List of pretrained models
models_list = [resnet50, densenet, efficientnet_b0, efficientnet_b3]

# Define the ensemble
ensemble = WeightedEnsemble(models_list, num_classes)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam([ensemble.weights], lr=0.01) 

# Train the ensemble weights
train_ensemble_weights(ensemble, train_dataloader, criterion, optimizer, epochs=10)


RuntimeError: running_mean should contain 2048 elements not 150528