In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import requests
from io import BytesIO
import os
from timm import create_model

# **Set Device**
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# **Image Preprocessing**
# Preprocessing applied to every input image, in order:
#   1. resize to 224x224,
#   2. convert the PIL image to a tensor,
#   3. normalize each RGB channel with the mean/std below.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# **Load ResNet50 Model**
def load_resnet():
    """Build the ResNet50 classifier and load its fine-tuned checkpoint.

    The stock ``fc`` layer is replaced by a two-layer head
    (``in_features -> 512 -> 2``) with ReLU and dropout, then the weights
    stored in ``trained_models/resnet_model.pth`` are loaded strictly.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If the checkpoint keys or shapes do not match the model
            (strict ``load_state_dict``).
    """
    # weights=None: the strict load below overwrites every parameter and
    # buffer, so fetching the ImageNet weights first would only add a
    # pointless download (and a network dependency) to start-up.
    model = models.resnet50(weights=None)
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 512),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(512, 2)
    )
    model.load_state_dict(torch.load("trained_models/resnet_model.pth", map_location=device))
    model.to(device)
    model.eval()
    return model

# **Load EfficientNet Model**
def load_efficientnet():
    """Build the EfficientNet-B0 classifier and load its fine-tuned checkpoint.

    The stock classifier is replaced by ``Dropout -> Linear(512) -> ReLU ->
    Linear(2)`` and the weights in ``trained_models/efficientnet_model.pth``
    are loaded non-strictly. Because a non-strict load silently skips keys
    that do not line up, any missing or unexpected keys are reported with a
    ``RuntimeWarning`` so a mismatched checkpoint cannot go unnoticed.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    import warnings

    model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
    num_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(num_features, 512),
        nn.ReLU(),
        nn.Linear(512, 2)
    )
    state_dict = torch.load("trained_models/efficientnet_model.pth", map_location=device)
    # strict=False is kept for compatibility with the existing checkpoint, but
    # the returned report is inspected instead of being thrown away.
    load_report = model.load_state_dict(state_dict, strict=False)
    if load_report.missing_keys or load_report.unexpected_keys:
        warnings.warn(
            "EfficientNet checkpoint does not fully match the model: "
            f"missing keys={list(load_report.missing_keys)}, "
            f"unexpected keys={list(load_report.unexpected_keys)}",
            RuntimeWarning,
            stacklevel=2,
        )
    model.to(device)
    model.eval()
    return model

# **Load ViT Model**
def load_vit():
    """Build the ViT-Base/16 (224px) classifier and load its fine-tuned checkpoint.

    The timm model is created without pretrained weights, its ``head`` is
    swapped for ``Linear(512) -> ReLU -> Dropout -> Linear(2)``, and the
    weights in ``trained_models/vit_model.pth`` are loaded strictly.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    vit_model = create_model("vit_base_patch16_224", pretrained=False)

    head_inputs = vit_model.head.in_features
    vit_model.head = nn.Sequential(
        nn.Linear(head_inputs, 512),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(512, 2),
    )

    checkpoint = torch.load("trained_models/vit_model.pth", map_location=device)
    vit_model.load_state_dict(checkpoint)

    # Module.to() and Module.eval() both return the module itself.
    return vit_model.to(device).eval()

# **Load Models**
# Instantiate the three ensemble members once, at module level, so that
# ensemble_predict() reuses them instead of reloading checkpoints per call.
resnet = load_resnet()
efficientnet = load_efficientnet()
vit = load_vit()

# **Function to Load Image**
def load_image(image_path_or_url, timeout=10.0):
    """Load an image from a URL or local path and preprocess it for the models.

    Args:
        image_path_or_url: An ``http://`` / ``https://`` URL, or a path to a
            local image file.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A ``(image, error)`` pair. On success ``image`` is the transformed
        image with a leading batch dimension, placed on ``device``, and
        ``error`` is ``None``. On failure ``image`` is ``None`` and ``error``
        is a human-readable message. This function does not raise.
    """
    try:
        # Require a real URL scheme: a bare "http" prefix would also match
        # local paths such as "http_images/sample.png".
        is_url = image_path_or_url.lower().startswith(("http://", "https://"))
        if is_url:
            # The timeout prevents a stalled server from hanging the caller;
            # raise_for_status turns 4xx/5xx replies into a clear error rather
            # than feeding an HTML error page to the image decoder.
            response = requests.get(image_path_or_url, timeout=timeout)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        elif os.path.exists(image_path_or_url):
            # The context manager closes the file handle deterministically;
            # convert() returns an independent in-memory copy.
            with Image.open(image_path_or_url) as source:
                image = source.convert("RGB")
        else:
            return None, "Error: Invalid path or URL"

        image = transform(image).unsqueeze(0).to(device)
        return image, None
    except Exception as e:
        # Boundary handler: callers expect an error string, never an exception.
        return None, f"Error loading image: {e}"

# **Ensemble Prediction Function**
def ensemble_predict(image_path_or_url, weights=(0.4, 0.3, 0.3)):
    """Classify an image with a weighted ensemble of ResNet, EfficientNet and ViT.

    Each model's logits are converted to probabilities with softmax, the
    probabilities are combined as a weighted sum, and the class with the
    highest combined score is reported.

    Args:
        image_path_or_url: URL or local path of the image to classify.
        weights: Three ensemble weights applied, in order, to the ResNet,
            EfficientNet and ViT probabilities. Defaults to
            ``(0.4, 0.3, 0.3)``.

    Returns:
        ``"Final Prediction: <class name>"`` on success, where class index 0
        is ``"Cancerous"`` and index 1 is ``"Non-Cancerous"``; or the error
        message produced by ``load_image`` when the image cannot be loaded.

    Raises:
        ValueError: If ``weights`` does not contain exactly three values.
    """
    if len(weights) != 3:
        raise ValueError(
            f"weights must contain exactly 3 values (resnet, efficientnet, vit); got {len(weights)}"
        )
    resnet_weight, efficientnet_weight, vit_weight = weights

    image, error = load_image(image_path_or_url)
    if error:
        return error

    # Inference only: no gradients are needed for the forward passes.
    with torch.no_grad():
        resnet_output = resnet(image)
        efficientnet_output = efficientnet(image)
        vit_output = vit(image)

    # Convert logits to per-class probabilities.
    resnet_probs = torch.softmax(resnet_output, dim=1)
    efficientnet_probs = torch.softmax(efficientnet_output, dim=1)
    vit_probs = torch.softmax(vit_output, dim=1)

    # Weighted sum of the three probability vectors.
    ensemble_probs = (
        resnet_weight * resnet_probs
        + efficientnet_weight * efficientnet_probs
        + vit_weight * vit_probs
    )

    pred_class = torch.argmax(ensemble_probs, dim=1).item()
    class_names = ["Cancerous", "Non-Cancerous"]
    return f"Final Prediction: {class_names[pred_class]}"

# **Example Usage**
# Replace with an actual image URL or file path before running.
image_url = "https://mypenndentist.org/wp-content/uploads/2024/08/shutterstock_2091761254-Large-1024x671.jpeg"
prediction_message = ensemble_predict(image_url)
print(prediction_message)

Final Prediction: Non-Cancerous
