<a href="https://colab.research.google.com/github/sarveshrastogi1/BroYOS/blob/main/Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Diabetic Retinopathy (DR) Inference

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2
from timm import create_model
from torchvision import models

# --- DEVICE SETUP ---
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# IMG_SIZE: side length (pixels) of the square input fed to the models.
# NUM_CLASSES: number of output classes of the ensemble's classifier heads.
IMG_SIZE, NUM_CLASSES = 224, 5

# --- LOAD CLASS NAMES FROM TRAINING DATA ---
def get_class_names(root_folder):
    """Return the class names implied by the sub-folders of ``root_folder``.

    Every immediate sub-directory counts as one class and the names are
    returned sorted, which is the ordering ``torchvision.datasets.ImageFolder``
    uses for its ``classes`` attribute.  Listing the directory directly avoids
    indexing every image file just to recover the label names.

    Args:
        root_folder: Path to the dataset root (one sub-folder per class).

    Returns:
        list[str]: Sorted class names; position ``i`` is the name for index ``i``.

    Raises:
        FileNotFoundError: If ``root_folder`` does not exist or has no
            sub-directories.
    """
    import os  # local import keeps this definition self-contained

    with os.scandir(root_folder) as entries:
        class_names = sorted(entry.name for entry in entries if entry.is_dir())
    if not class_names:
        raise FileNotFoundError(f"Couldn't find any class folder in {root_folder}.")
    return class_names

# --- TRANSFORMS FOR INFERENCE ---
def get_inference_transforms():
    """Build the albumentations pipeline applied to an image before inference.

    Steps, in order: resize to ``IMG_SIZE`` x ``IMG_SIZE``, normalize each
    channel with the mean/std below (the values commonly used with
    ImageNet-pretrained models), then convert to a tensor with ``ToTensorV2``.

    Returns:
        The composed transform; call it as ``transform(image=array)['image']``.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    steps = [
        Resize(height=IMG_SIZE, width=IMG_SIZE),
        Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return Compose(steps)

# --- ENSEMBLE MODEL CLASS ---
class EnsembleModel(nn.Module):
    """Two-branch classifier that averages EfficientNet-B0 and ViT-B/16 outputs.

    Both branches end in a ``NUM_CLASSES``-way head, and ``forward`` returns
    the element-wise mean of their outputs.

    Args:
        pretrained: When True (the default, matching the original behaviour)
            both branches are initialised from their published pretrained
            weights, which may require a download.  Pass False when a full
            checkpoint is loaded right afterwards, so no pretrained weights
            are fetched only to be overwritten.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.efficientnet = models.efficientnet_b0(
            weights='IMAGENET1K_V1' if pretrained else None
        )
        # Replace the final linear layer so the branch emits NUM_CLASSES outputs.
        original_head = self.efficientnet.classifier[1]
        self.efficientnet.classifier[1] = nn.Linear(original_head.in_features, NUM_CLASSES)
        self.vit = create_model(
            'vit_base_patch16_224', pretrained=pretrained, num_classes=NUM_CLASSES
        )

    def forward(self, x):
        """Return the mean of the two branches' outputs for the batch ``x``."""
        out1 = self.efficientnet(x)
        out2 = self.vit(x)
        return (out1 + out2) / 2

# --- LOAD MODEL AND WEIGHTS ---
def load_model(weights_path='best_ensemble_model.pth'):
    """Build the ensemble, load its checkpoint and switch it to eval mode.

    Args:
        weights_path: Path to a state dict saved from an ``EnsembleModel``.

    Returns:
        EnsembleModel: The model on ``DEVICE`` in evaluation mode.
    """
    # The checkpoint supplies every parameter (load_state_dict is strict by
    # default), so skip fetching pretrained weights that would be overwritten.
    model = EnsembleModel(pretrained=False).to(DEVICE)
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    model.eval()
    return model

# --- INFERENCE FUNCTION ---
def infer(model, image_path, class_names):
    """Classify a single image file with ``model``.

    Args:
        model: Network called on a one-image batch; its output is passed
            through a softmax over dimension 1.
        image_path: Path of the image to classify.
        class_names: Sequence mapping a class index to its display name.

    Returns:
        tuple: ``(predicted_class, class_name, highest_prob, probabilities)``
        where ``predicted_class`` is the arg-max index, ``class_name`` its
        entry in ``class_names``, ``highest_prob`` the softmax probability of
        that class, and ``probabilities`` the full softmax output as a NumPy
        array.
    """
    preprocess = get_inference_transforms()

    # torchvision's default loader reads the file into an image object; the
    # albumentations pipeline expects a NumPy array.
    loaded_image = datasets.folder.default_loader(image_path)
    pixels = np.array(loaded_image)
    batch = preprocess(image=pixels)['image']
    batch = batch.unsqueeze(0).to(DEVICE)  # add the batch dimension

    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)
        predicted_class = probabilities.argmax(dim=1).item()

    class_name = class_names[predicted_class]
    highest_prob = probabilities[0, predicted_class].item()
    return predicted_class, class_name, highest_prob, probabilities.cpu().numpy()

# --- MAIN FUNCTION FOR TESTING ---
if __name__ == '__main__':
    # Display names for class indices 0-4. They are hard-coded here (not read
    # from a training folder), so keep the order in sync with training.
    class_names = [
        'mild',
        'moderate',
        'no_dr',
        'poliferate',
        'severe',
    ]
    test_image_path = '/content/ffd97f8cd5aa.png'  # Replace with your test image path

    model = load_model('best_ensemble_model.pth')

    predicted_class, class_name, highest_prob, probabilities = infer(
        model, test_image_path, class_names
    )

    # Report the winning class, its probability and the full distribution.
    print(f'Predicted Class: {predicted_class} ({class_name})')
    print(f'Highest Probability: {highest_prob:.4f}')
    print(f'Class Probabilities: {probabilities}')


# Clot Inference

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os

# Load the trained model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Specify the local path where the ResNet weights are saved
local_weights_path = 'path/to/resnet18_weights.pth'

# Build the ResNet18 architecture and swap in a single-output head.
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 1)  # Binary classification head


def _report_partial_load(result, source):
    """Print which keys a non-strict ``load_state_dict`` call left unmatched."""
    unmatched = (
        ("missing from", result.missing_keys),
        ("not used from", result.unexpected_keys),
    )
    for label, keys in unmatched:
        if keys:
            print(f"Warning: {len(keys)} key(s) {label} {source}, e.g. {keys[:5]}")


# Load the weights from the local path
try:
    state_dict = torch.load(local_weights_path, map_location=device)
    # strict=False tolerates missing/unexpected keys but still raises when a
    # tensor's shape differs from the model's (e.g. a checkpoint whose fc
    # layer has a different output size than the 1-unit head above), so drop
    # those entries before loading.
    model_shapes = {name: tensor.shape for name, tensor in model.state_dict().items()}
    compatible_state_dict = {
        name: value
        for name, value in state_dict.items()
        if name not in model_shapes
        or not hasattr(value, 'shape')
        or value.shape == model_shapes[name]
    }
    skipped_keys = sorted(set(state_dict) - set(compatible_state_dict))
    if skipped_keys:
        print(f"Skipping shape-mismatched weights from {local_weights_path}: {skipped_keys}")
    _report_partial_load(
        model.load_state_dict(compatible_state_dict, strict=False),
        local_weights_path,
    )
except (RuntimeError, FileNotFoundError) as e:
    print(f"Error loading the model weights: {e}")
    print("Check if the path is correct or the file is corrupted.")
    # SystemExit halts execution here; exit() is an interactive helper that
    # the Python docs advise against using in programs.
    raise SystemExit(1)

# Load custom classifier weights
try:
    custom_state_dict = torch.load('clot.pth', map_location=device)
    # Surface keys the non-strict load skipped so a partially loaded model
    # does not go unnoticed.
    _report_partial_load(
        model.load_state_dict(custom_state_dict, strict=False),
        'clot.pth',
    )
except (RuntimeError, FileNotFoundError) as e:
    print(f"Error loading the custom model state dictionary: {e}")
    raise SystemExit(1)

model = model.to(device)
model.eval()  # Set model to evaluation mode

# Define the same transform used during training
# Preprocessing pipeline: resize, convert to a tensor, then normalize each
# channel with the given mean and standard deviation.
_preprocessing_steps = [
    transforms.Resize(size=(512, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocessing_steps)

def predict(image_path):
    """
    Perform inference on a single image.

    The image is converted to RGB, passed through the module-level
    ``transform`` and ``model``, and the model output is squashed with a
    sigmoid.

    Args:
        image_path (str): Path to the input image.

    Returns:
        int: Predicted label (1 when the sigmoid score is >= 0.5, else 0).
        float: The sigmoid score itself (between 0 and 1). This is the
            score for label 1, so it is below 0.5 whenever the predicted
            label is 0.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Open inside a context manager so the file handle is released right
    # after convert() has produced the RGB copy.
    with Image.open(image_path) as source_image:
        rgb_image = source_image.convert("RGB")
    image = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image).view(-1)
        confidence = torch.sigmoid(output).item()

    prediction = 1 if confidence >= 0.5 else 0
    return prediction, confidence

# Example usage
# Image to classify; replace with the path to your own file.
image_path = "817.jpg"

# Run the classifier once and report the label together with its score.
prediction, confidence = predict(image_path)
print(f"Predicted Label: {prediction}, Confidence: {confidence:.4f}")
