In [None]:
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from torchvision.models import vgg16, resnet50

import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()

# Deterministic evaluation preprocessing. There is deliberately no random
# augmentation here: a RandomHorizontalFlip would make the reported test
# accuracy change from run to run. Resize size and the ImageNet mean/std are
# presumably the same as in the training pipeline -- TODO confirm.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_dataset = ImageFolder(root='Data/Testing', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
# Take the class count from ImageFolder itself: it only counts class
# sub-directories, whereas os.listdir would also count stray files
# (e.g. .DS_Store) and produce a head that does not match the checkpoint.
num_classes = len(test_dataset.classes)

# Rebuild each architecture with a num_classes-way head, then restore the
# fine-tuned weights. map_location lets checkpoints saved on a GPU load on a
# CPU-only machine (and vice versa).
vgg16_model = vgg16()
vgg16_model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=num_classes)
vgg16_model.load_state_dict(torch.load('Models/best_VGG16_model.pth', map_location=device))
vgg16_model.to(device)

# No pretrained=True: the ImageNet weights would be downloaded only to be
# fully overwritten by the strict load_state_dict below, and the `pretrained`
# argument is deprecated in torchvision.
resnet50_model = resnet50()
resnet50_model.fc = torch.nn.Linear(in_features=2048, out_features=num_classes)
resnet50_model.load_state_dict(torch.load('Models/best_ResNet50_model.pth', map_location=device))
resnet50_model.to(device)


def calculate_ensemble_accuracy(loader, model1, model2, device):
    """Return the top-1 accuracy, in percent, of a two-model averaging ensemble.

    Both models are put into evaluation mode and run on every batch yielded by
    ``loader``. Their raw outputs are averaged element-wise, and the arg-max
    over dimension 1 of that average is the ensemble's prediction.

    Args:
        loader: Iterable of ``(inputs, targets)`` batches (e.g. a
            ``DataLoader``); ``targets`` hold integer class indices.
        model1: First model, called as ``model1(inputs)``.
        model2: Second model; its output must have the same shape as
            ``model1``'s so the two can be averaged.
        device: Device that each batch is moved to before inference.

    Returns:
        Percentage (0-100) of samples whose ensemble prediction equals the
        target.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    correct = 0
    total = 0
    # Put layers such as dropout / batch norm into inference behaviour.
    model1.eval()
    model2.eval()

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Average the raw outputs of the two models. This assumes both
            # were trained with the same class-index mapping -- TODO confirm.
            ensemble_outputs = (model1(inputs) + model2(inputs)) / 2
            predicted = ensemble_outputs.argmax(dim=1)

            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute accuracy")
    return 100 * correct / total

# Evaluate the VGG16 + ResNet50 ensemble on test_loader and print the
# resulting accuracy as a percentage with two decimals.
test_accuracy = calculate_ensemble_accuracy(
    loader=test_loader,
    model1=vgg16_model,
    model2=resnet50_model,
    device=device,
)
print(f"Ensemble Test Accuracy: {test_accuracy:.2f}%")
