In [1]:
import torch
import torch.nn as nn
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
import timm  # For Xception model

# Compute device: use CUDA whenever PyTorch reports it as available,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Stacking meta-classifier applied on top of the base models' outputs.
class EnsembleFNN(nn.Module):
    """One-hidden-layer feed-forward classifier: Linear -> ReLU -> Linear.

    Maps inputs whose last dimension is ``input_size`` to ``num_classes``
    raw scores. In this script it is fed the concatenated logits of the
    base CNNs.

    Args:
        input_size: Width of each input feature vector.
        hidden_size: Number of hidden units.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_size: int, hidden_size: int, num_classes: int) -> None:
        super().__init__()
        # The attribute name and layer order determine the state_dict keys
        # (classifier.0.*, classifier.2.*); the saved FNN checkpoint is loaded
        # with the default strict key matching, so they must not change.
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores (logits) for *x*."""
        scores = self.classifier(x)
        return scores

# 1. Define model architectures: each backbone gets a fresh 2-class head.
#    All weights come from the fine-tuned checkpoints loaded in step 2, so no
#    ImageNet weights are downloaded. (`pretrained=` is deprecated in
#    torchvision >= 0.13 in favour of `weights=None`; it is kept so older
#    torchvision releases still work.)
model1 = models.resnet18(pretrained=False)
model1.fc = nn.Linear(model1.fc.in_features, 2)

model4 = models.densenet201(pretrained=False)
model4.classifier = nn.Linear(model4.classifier.in_features, 2)

# NOTE(review): newer timm releases treat 'xception' as a deprecated alias
# (likely the source of the `model = create_fn(` warning in the cell output);
# the old name is kept for compatibility with older timm.
model5 = timm.create_model('xception', pretrained=False)
model5.fc = nn.Linear(model5.fc.in_features, 2)


def _load_weights(net, path, *, strict=True):
    """Load the state_dict stored at *path* into *net* in place.

    The file is read with ``weights_only=True`` (PyTorch >= 1.13), so only
    tensors and plain containers are unpickled; this also avoids the
    ``torch.load`` FutureWarning on recent PyTorch.

    Args:
        net: Module whose parameters/buffers are overwritten.
        path: Checkpoint file containing a state_dict.
        strict: Forwarded to ``load_state_dict``. When True, any key mismatch
            raises; when False, mismatches are tolerated but reported.

    Returns:
        The value returned by ``net.load_state_dict`` (its ``missing_keys``
        and ``unexpected_keys`` lists).
    """
    state_dict = torch.load(path, map_location=device, weights_only=True)
    result = net.load_state_dict(state_dict, strict=strict)
    # With strict=False a mismatch would otherwise be dropped silently,
    # leaving missing parameters at their random initialization.
    if result.missing_keys:
        print(f'WARNING: {path}: {len(result.missing_keys)} missing key(s) '
              f'left at random initialization: {result.missing_keys}')
    if result.unexpected_keys:
        print(f'WARNING: {path}: ignored {len(result.unexpected_keys)} '
              f'unexpected key(s): {result.unexpected_keys}')
    return result


# 2. Load trained weights for base models.
_load_weights(model1, 'resnet_best_model_combined.pth')
_load_weights(model4, 'Densenet__weight.pth')
# strict=False is preserved from the original for the Xception checkpoint.
# NOTE(review): why its keys do not match exactly is not visible here; check
# the printed report — any missing key means an untrained layer.
_load_weights(model5, 'xception_best_model_combined.pth', strict=False)

# Move each base model to the device, switch it to evaluation mode and freeze
# its parameters: the base models are only used for inference below.
for model in [model1, model4, model5]:
    model.to(device).eval()
    for param in model.parameters():
        param.requires_grad = False

# 3. Test-time preprocessing: resize to 224x224, convert to a tensor, then
#    normalize each channel with the ImageNet mean/std constants below.
#    NOTE(review): this one pipeline feeds all three models, including
#    Xception, which is commonly configured for 299x299 input with different
#    normalization — confirm it matches how each model was trained.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform_224 = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Labels come from ImageFolder over ./test3; NOTE(review): they must use the
# same class indexing as training. shuffle=False keeps a fixed sample order.
test_dataset_224 = datasets.ImageFolder(root="./test3", transform=transform_224)
test_dl_224 = DataLoader(dataset=test_dataset_224, batch_size=32, shuffle=False)

# 4. Load the trained FNN meta-classifier. It consumes the concatenated
#    2-class logits of the three base models (ResNet18, DenseNet201,
#    Xception, in that order in the evaluation loop).
#    NOTE(review): the sizes and concatenation order must match how
#    fnn_ensemble_3models.pth was trained — confirm against the training code.
fnn_input_size = 3 * 2  # 3 models × 2 logits
fnn_hidden_size = 16
fnn_classifier = EnsembleFNN(fnn_input_size, fnn_hidden_size, num_classes=2).to(device)

# weights_only=True (PyTorch >= 1.13) restricts unpickling to tensors and
# plain containers and avoids the torch.load FutureWarning on recent PyTorch.
fnn_classifier.load_state_dict(
    torch.load('fnn_ensemble_3models.pth', map_location=device, weights_only=True)
)
fnn_classifier.eval()

# 5. Running tallies: correct predictions per base model and for the FNN
#    ensemble, plus the number of test samples seen.
correct_model1 = correct_model4 = correct_model5 = correct_fnn = 0
total = 0

# 6. Single pass over the test set scoring the three base models and the FNN
#    stacked on their logits; no_grad() skips autograd bookkeeping.
with torch.no_grad():
    for xb_224, yb in test_dl_224:
        xb_224 = xb_224.to(device)
        yb = yb.to(device)

        # Raw logits from each base model on the same batch.
        out1, out4, out5 = model1(xb_224), model4(xb_224), model5(xb_224)
        preds1, preds4, preds5 = (o.argmax(dim=1) for o in (out1, out4, out5))

        # Stacked ensemble: logits side by side in ResNet18, DenseNet201,
        # Xception order are the FNN's input features.
        combined_logits = torch.cat((out1, out4, out5), dim=1)
        fnn_output = fnn_classifier(combined_logits)
        fnn_preds = fnn_output.argmax(dim=1)

        correct_model1 += int((preds1 == yb).sum())
        correct_model4 += int((preds4 == yb).sum())
        correct_model5 += int((preds5 == yb).sum())
        correct_fnn += int((fnn_preds == yb).sum())
        total += len(yb)

# 7. Convert tallies to percentages and report them.
accuracy_model1, accuracy_model4, accuracy_model5, ensemble_accuracy = (
    hits / total * 100
    for hits in (correct_model1, correct_model4, correct_model5, correct_fnn)
)

print(f'Model 1 (ResNet18) Accuracy: {accuracy_model1:.2f}%')
print(f'Model 4 (DenseNet201) Accuracy: {accuracy_model4:.2f}%')
print(f'Model 5 (Xception) Accuracy: {accuracy_model5:.2f}%')
print(f'FNN Ensemble Accuracy: {ensemble_accuracy:.2f}%')


  model = create_fn(
  model1.load_state_dict(torch.load('resnet_best_model_combined.pth', map_location=device))
  model4.load_state_dict(torch.load('Densenet__weight.pth', map_location=device))
  model5.load_state_dict(torch.load('xception_best_model_combined.pth', map_location=device), strict=False)


Model 1 (ResNet18) Accuracy: 98.50%
Model 4 (DenseNet201) Accuracy: 98.17%
Model 5 (Xception) Accuracy: 99.17%
Ensemble Test Accuracy: 99.00%
