In [224]:
from utils.load_dataset import PlantVillageDataset
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import torch.optim as optim 
import torch.nn as nn
from utils.preprocessing import preprocessing_img
import torchvision.models  as models
from torchvision import transforms
from test import LinearHeadModel

In [225]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [226]:
from create_datasets import train_validation_dataloader as train_dataloader, test_validation_dataloader as test_dataloader

## Import pre-trained models

In [240]:
# ResNet50
resnet50 = models.resnet50(pretrained=True)  # Load the pretrained ResNet50
for param in resnet50.parameters():
    param.requires_grad = False

# DenseNet201
densenet = models.densenet201(weights=models.DenseNet201_Weights.DEFAULT)
for param in densenet.parameters():
    param.requires_grad = False

# EfficientNetB0
efficientnet_b0 = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
for param in efficientnet_b0.parameters():
    param.requires_grad = False

# EfficientNetB3
efficientnet_b3 = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1)
for param in efficientnet_b3.parameters():
    param.requires_grad = False



## Import heads for each pre-trained models

In [241]:
# Load model from file
def load_model_from_file(model_path, input_dim, output_dim):
    model = LinearHeadModel(input_dim, output_dim)
    model.load_state_dict(torch.load(model_path))
    return model

In [242]:
# Constants
from dim_constants import input_dim_resnet,  input_dim_densenet, input_dim_b0, input_dim_b3
num_classes = 38
output_dim = num_classes + 1

In [243]:
resnet50_head = load_model_from_file('models/resnet50.pth', input_dim_resnet, output_dim)
densenet201_head = load_model_from_file('models/densenet201.pth', input_dim_densenet, output_dim)
efficientnet_b0_head = load_model_from_file('models/efficientnet_b0.pth', input_dim_b0, output_dim)
efficientnet_b3_head = load_model_from_file('models/efficientnet_b3.pth', input_dim_b3, output_dim)

  model.load_state_dict(torch.load(model_path))


In [244]:
# Modify head of pre-trained models
resnet50.fc = resnet50_head.to(device)
densenet.classifier = densenet201_head.to(device)
efficientnet_b0.classifier = efficientnet_b0_head.to(device)
efficientnet_b3.classifier = efficientnet_b3_head.to(device)

In [245]:
resnet50 = resnet50.to(device)
densenet = densenet.to(device)
efficientnet_b0 = efficientnet_b0.to(device)
efficientnet_b3 = efficientnet_b3.to(device)

In [233]:
# List of pretrained models
models_list = [resnet50, densenet, efficientnet_b0, efficientnet_b3]

In [234]:
class WeightedEnsemble(nn.Module):
    def __init__(self, models, num_classes):
        super(WeightedEnsemble, self).__init__()
        self.models = models
        self.num_models = len(models) 
        self.weights = nn.Parameter(torch.ones(self.num_models) / self.num_models)
        self.num_classes = num_classes

    def forward(self, inputs):
        all_preds = []
        
        # Forward pass for each model in the ensemble
        for i, model in enumerate(self.models):
            model.eval()
            with torch.no_grad():
                inputs = inputs.to(device)
                outputs = model(inputs)

                # Now pass through the head (batch_norm and linear layers)
                prob = nn.functional.softmax(outputs, dim=1)  # Apply softmax to get probabilities
                all_preds.append(prob)

        # Stack all predictions from each model
        all_preds = torch.stack(all_preds)

        # Normalize the weights (softmax ensures weights sum to 1)
        normalized_weights = nn.functional.softmax(self.weights, dim=0)

        # Aggregate the weighted predictions
        weighted_preds = torch.sum(all_preds * normalized_weights.view(-1, 1, 1), dim=0)
        return weighted_preds


In [235]:
def train_ensemble_weights(ensemble, criterion, optimizer, epochs=5):
    for epoch in range(epochs):
        ensemble.train()  # Ensemble training mode (weights can be trained)
        running_loss = 0.0

        for i, (inputs, labels) in enumerate(train_dataloader):
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()  # Zero gradients for optimizer
            outputs = ensemble(inputs).to(device)  # Forward pass through ensemble
            loss = criterion(outputs, labels)  # Compute loss
            loss.backward()  # Backpropagate
            optimizer.step()  # Update weights
            
            running_loss += loss.item()

            if i%10==0 : 
                print(f'input {i}')
                print('loss', running_loss)


            
        # Evaluate on test set
        ensemble.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_dataloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = ensemble(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracy = correct / total

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_dataloader):.4f}, Test Accuracy: {test_accuracy:.4f}")



In [236]:
batch_size = 64

In [246]:
# Evaluate on test set
resnet50.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet50(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

In [247]:
test_accuracy

0.0964920191180449

In [None]:
# Evaluate on test set
correct = 0
total = 0
with torch.no_grad():
    for i, (inputs, labels) in enumerate(test_dataloader):
        if i%50 == 0:
            print('input', i)
        inputs, labels = inputs.to(device), labels.to(device)
        
        # Initialize a zero tensor for averaging
        final_outputs = torch.zeros(inputs.size(0), output_dim).to(device)
        
        for model in models_list:
            model.eval()  # Ensure the model is in evaluation mode
            outputs = model(inputs).to(device)
            _, predicted = torch.max(final_outputs, 1)
            print((predicted == labels).sum().item())
            final_outputs += outputs  # Accumulate outputs from each model
        
        # Average the accumulated outputs
        final_outputs /= len(models_list)
        
        # Compute predictions and accuracy
        _, predicted = torch.max(final_outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Calculate test accuracy
    test_accuracy = correct / total

    print(f"Test Accuracy: {test_accuracy:.4f}")


input 0
0
11
21
17
final tensor([[6.6127e-11, 3.3037e-11, 9.2770e-12,  ..., 2.3452e-15, 4.4107e-12,
         2.0284e-06],
        [1.1810e-06, 1.3949e-04, 1.4449e-07,  ..., 9.8146e-07, 1.0924e-06,
         2.5600e-01],
        [2.2851e-08, 1.4413e-08, 4.4682e-09,  ..., 1.5381e-14, 1.1702e-09,
         2.3509e-01],
        ...,
        [8.4585e-10, 1.5662e-09, 3.7343e-10,  ..., 2.0139e-10, 4.0198e-10,
         6.5663e-07],
        [8.7781e-03, 4.8363e-02, 6.3813e-05,  ..., 1.7200e-06, 9.0009e-05,
         5.0124e-01],
        [9.3433e-11, 2.0803e-11, 4.7359e-09,  ..., 4.7140e-01, 1.0944e-07,
         4.9795e-01]], device='mps:0')
predicted tensor([ 5, 38, 25, 12, 38, 26, 25, 11, 38, 38, 25, 38, 25,  5, 25, 25, 25, 25,
         5, 38, 38,  5,  5,  5, 25, 38, 25, 25, 25, 33, 25, 38, 25, 38, 25, 38,
        36, 38, 38, 25, 25, 25, 38, 25, 38, 11, 25, 25, 25, 25, 38, 38, 25, 38,
        25, 25, 38, 25, 38, 25, 25,  5, 38, 38], device='mps:0')
labels tensor([16, 31, 14, 13, 29, 26,  8, 11, 2

KeyboardInterrupt: 

In [None]:
# List of pretrained models
models_list = [resnet50, densenet, efficientnet_b0, efficientnet_b3]

# Define the ensemble
ensemble = WeightedEnsemble(models_list, output_dim)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam([ensemble.weights], lr=0.01) 

# Train the ensemble weights
train_ensemble_weights(ensemble, criterion, optimizer, epochs=10)

all_preds torch.Size([4, 64, 39])
weighted_preds torch.Size([4, 64, 4])
outputs torch.Size([4, 64, 4])
labels tensor([25,  9, 10, 28, 12, 35, 30, 31, 20, 38, 38,  3, 36, 19, 26, 31, 25, 36,
        17, 31, 16, 32, 16, 16,  2,  6, 26, 12,  4, 16,  0, 38, 34, 34,  6, 31,
        16, 17, 16, 20, 17, 11, 14, 38, 12, 26, 33, 16, 25, 16, 16, 19, 32, 34,
        36,  5, 17, 35, 24, 14, 16, 20, 31,  6], device='mps:0') torch.Size([64])


ValueError: Expected input batch_size (4) to match target batch_size (64).