# In [1]:
import torch
import torch.nn as nn
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
import timm  # For Xception model

# Device configuration: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define the FNN ensemble classifier
class EnsembleFNN(nn.Module):
    """Small feed-forward head: Linear -> ReLU -> Linear.

    In this file it is applied to the concatenated outputs of the base
    models to produce the ensemble's class scores.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of the single hidden layer.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # The layers stay inside an nn.Sequential named ``classifier`` so the
        # parameter names in the state dict remain ``classifier.0.*`` and
        # ``classifier.2.*``.
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for ``x``."""
        scores = self.classifier(x)
        return scores

# 1. Define and modify individual models
# The backbones are built without pretrained weights because the trained
# checkpoints are loaded from disk in the next step. `weights=None` is the
# current torchvision spelling of the deprecated `pretrained=False`.
model1 = models.resnet18(weights=None)
# Replace the final layer with a 2-output head.
model1.fc = nn.Linear(model1.fc.in_features, 2)

model4 = models.densenet201(weights=None)
# Replace the final layer with a 2-output head.
model4.classifier = nn.Linear(model4.classifier.in_features, 2)

# timm builds the 2-output head directly via num_classes.
model5 = timm.create_model('xception', pretrained=False, num_classes=2)

# 2. Load trained weights
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs; it avoids executing arbitrary pickled code from the
# checkpoint files.
model1.load_state_dict(
    torch.load('resnet_model154.pth', map_location=device, weights_only=True)
)
model4.load_state_dict(
    torch.load('densenet_model154.pth', map_location=device, weights_only=True)
)

# strict=False tolerates key mismatches between checkpoint and model. Inspect the
# result so that a checkpoint which matches few (or no) parameters does not pass
# unnoticed and leave parts of the network randomly initialised.
xception_load = model5.load_state_dict(
    torch.load('xception_model.pth', map_location=device, weights_only=True),
    strict=False,
)
if xception_load.missing_keys or xception_load.unexpected_keys:
    print(
        "WARNING: 'xception_model.pth' does not fully match the Xception model: "
        f"{len(xception_load.missing_keys)} missing key(s), "
        f"{len(xception_load.unexpected_keys)} unexpected key(s)."
    )

# 3. Send models to device and set to eval mode
# 4. Freeze the base models (they are only used for inference here)
for base_model in (model1, model4, model5):
    base_model.to(device).eval()
    for param in base_model.parameters():
        param.requires_grad = False

# 5. Load test dataset
# NOTE(review): despite the "_224" suffix in the names below, images are resized
# to (260, 280) - confirm this matches the preprocessing used during training.
test_preprocessing = [
    transforms.Resize((260, 280)),
    transforms.ToTensor(),
    # Per-channel normalisation constants (the commonly used ImageNet statistics).
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform_224 = transforms.Compose(test_preprocessing)

# Labels come from the sub-directory names under ./test (ImageFolder layout).
test_dataset_224 = datasets.ImageFolder(root="./test", transform=transform_224)
# No shuffling: order is irrelevant for computing accuracy.
test_dl_224 = DataLoader(dataset=test_dataset_224, batch_size=32, shuffle=False)

# 6. Define FNN ensemble classifier
fnn_input_size = 2 * 3  # 3 models × 2 outputs each
fnn_hidden_size = 16
fnn_classifier = EnsembleFNN(input_size=fnn_input_size, hidden_size=fnn_hidden_size, num_classes=2).to(device)

# 7. Load FNN weights when a trained checkpoint is available.
# Without trained weights the head keeps its random initialisation, so the
# "FNN Ensemble Accuracy" reported later would not reflect a real ensemble.
# Say so explicitly instead of silently evaluating an untrained network.
try:
    fnn_state = torch.load('fnn_ensemble.pth', map_location=device, weights_only=True)
except FileNotFoundError:
    print(
        "WARNING: 'fnn_ensemble.pth' not found - the FNN ensemble head is untrained "
        "(random weights), so its reported accuracy is not meaningful."
    )
else:
    fnn_classifier.load_state_dict(fnn_state)

# Inference only: put the head in eval mode like the base models.
fnn_classifier.eval()

# 8. Evaluate all models + FNN ensemble
# Base models in the order their outputs are concatenated for the FNN head.
base_models = (model1, model4, model5)
base_hits = [0] * len(base_models)  # correct predictions per base model
correct_fnn = 0
total = 0

with torch.no_grad():
    for xb, yb in test_dl_224:
        xb = xb.to(device)
        yb = yb.to(device)

        # One forward pass per base model; reused for accuracy and for the ensemble input.
        batch_outputs = [net(xb) for net in base_models]

        # Accuracy of each model
        for idx, scores in enumerate(batch_outputs):
            base_hits[idx] += (scores.argmax(1) == yb).sum().item()

        # Concatenate model outputs: shape = [batch_size, 6]
        combined_features = torch.cat(batch_outputs, dim=1)

        # FNN prediction
        pred_fnn = fnn_classifier(combined_features).argmax(1)
        correct_fnn += (pred_fnn == yb).sum().item()

        total += yb.size(0)

# Keep the per-model counters available under their original names.
correct_model1, correct_model4, correct_model5 = base_hits

# 9. Print accuracies
print(f'Model 1 (ResNet18) Accuracy: {correct_model1 / total * 100:.2f}%')
print(f'Model 4 (DenseNet201) Accuracy: {correct_model4 / total * 100:.2f}%')
print(f'Model 5 (Xception) Accuracy: {correct_model5 / total * 100:.2f}%')
print(f'FNN Ensemble Accuracy: {correct_fnn / total * 100:.2f}%')


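# --- Notebook cell output (stderr) ---
# These appear to be the source lines echoed by library warnings; the warning
# messages themselves were not captured.
#   model = create_fn(
#   model1.load_state_dict(torch.load('resnet_model154.pth', map_location=device))
#   model4.load_state_dict(torch.load('densenet_model154.pth', map_location=device))
#   model5.load_state_dict(torch.load('xception_model.pth', map_location=device), strict=False)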


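# --- Notebook cell output (stdout) ---
# NOTE: the last label below ("Ensemble Test Accuracy") does not match the label
# printed by the code in this cell ("FNN Ensemble Accuracy"), so this output
# appears to come from a different version of the cell.
# Model 1 (ResNet18) Accuracy: 94.81%
# Model 4 (DenseNet201) Accuracy: 96.54%
# Model 5 (Xception) Accuracy: 97.40%
# Ensemble Test Accuracy: 98.27%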
