In [1]:
import torch
import torch.nn as nn
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
import timm  # For Xception model

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define FNN ensemble classifier
class EnsembleFNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(EnsembleFNN, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes)
        )

    def forward(self, x):
        return self.classifier(x)

# 1. Define base models
model1 = models.resnet18(pretrained=False)
model1.fc = nn.Linear(model1.fc.in_features, 2)

model3 = models.inception_v3(pretrained=False, aux_logits=False)
model3.fc = nn.Linear(model3.fc.in_features, 2)

model4 = models.densenet201(pretrained=False)
model4.classifier = nn.Linear(model4.classifier.in_features, 2)

model5 = timm.create_model('xception', pretrained=False)
model5.fc = nn.Linear(model5.fc.in_features, 2)

# 2. Load weights
model1.load_state_dict(torch.load('RESNET_WEIGHT.pth', map_location=device))
model3.load_state_dict(torch.load('best_inception_model.pth', map_location=device), strict=False)
model4.load_state_dict(torch.load('Densenet__weight.pth', map_location=device))
model5.load_state_dict(torch.load('xception_best_model.pth', map_location=device), strict=False)

# 3. Send to device and eval mode
model1.to(device).eval()
model3.to(device).eval()
model4.to(device).eval()
model5.to(device).eval()

# Freeze base models
for m in [model1, model3, model4, model5]:
    for param in m.parameters():
        param.requires_grad = False

# 4. Transforms
transform_224 = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_299 = transforms.Compose([
    transforms.Resize((280, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 5. Datasets and loaders
test_dataset_224 = datasets.ImageFolder(root="./test2", transform=transform_224)
test_dataset_299 = datasets.ImageFolder(root="./test2", transform=transform_299)

test_dl_224 = DataLoader(test_dataset_224, batch_size=32, shuffle=False)
test_dl_299 = DataLoader(test_dataset_299, batch_size=32, shuffle=False)

# 6. Define and load FNN ensemble
fnn_input_size = 2 * 4  # 4 models × 2 outputs each
fnn_hidden_size = 16
fnn_classifier = EnsembleFNN(fnn_input_size, fnn_hidden_size, 2).to(device)

# Optional: Load trained FNN weights
# fnn_classifier.load_state_dict(torch.load('fnn_ensemble_4models.pth', map_location=device))

fnn_classifier.eval()

# 7. Evaluation
correct_model1 = correct_model3 = correct_model4 = correct_model5 = correct_fnn = 0
total = 0

with torch.no_grad():
    for (xb_224, yb), (xb_299, _) in zip(test_dl_224, test_dl_299):
        xb_224, yb = xb_224.to(device), yb.to(device)
        xb_299 = xb_299.to(device)

        out1 = model1(xb_224)
        out3 = model3(xb_299)
        out4 = model4(xb_224)
        out5 = model5(xb_224)

        preds1 = out1.argmax(1)
        preds3 = out3.argmax(1)
        preds4 = out4.argmax(1)
        preds5 = out5.argmax(1)

        correct_model1 += (preds1 == yb).sum().item()
        correct_model3 += (preds3 == yb).sum().item()
        correct_model4 += (preds4 == yb).sum().item()
        correct_model5 += (preds5 == yb).sum().item()

        # FNN ensemble prediction
        combined = torch.cat([out1, out3, out4, out5], dim=1)  # shape: [batch_size, 8]
        fnn_out = fnn_classifier(combined)
        fnn_preds = fnn_out.argmax(1)

        correct_fnn += (fnn_preds == yb).sum().item()
        total += yb.size(0)

# 8. Print accuracy
print(f'Model 1 (ResNet18) Accuracy: {correct_model1 / total * 100:.2f}%')
print(f'Model 3 (InceptionV3) Accuracy: {correct_model3 / total * 100:.2f}%')
print(f'Model 4 (DenseNet201) Accuracy: {correct_model4 / total * 100:.2f}%')
print(f'Model 5 (Xception) Accuracy: {correct_model5 / total * 100:.2f}%')
print(f'FNN Ensemble Accuracy: {correct_fnn / total * 100:.2f}%')


  model = create_fn(
  model1.load_state_dict(torch.load('RESNET_WEIGHT.pth', map_location=device))
  model3.load_state_dict(torch.load('best_inception_model.pth', map_location=device), strict=False)
  model4.load_state_dict(torch.load('Densenet__weight.pth', map_location=device))
  model5.load_state_dict(torch.load('xception_best_model.pth', map_location=device), strict=False)


Model 1 (ResNet18) Accuracy: 97.91%
Model 3 (InceptionV3) Accuracy: 89.56%
Model 4 (DenseNet201) Accuracy: 97.65%
Model 5 (Xception) Accuracy: 98.17%
Ensemble Test Accuracy: 98.69%
