# Inference for Chest X-Ray Classification

This notebook covers the inference stage of the project: it loads the trained ResNet-18 model and uses it to classify a given chest X-ray image into one of five classes (COVID-19, Lung-Opacity, Normal, Viral Pneumonia, Tuberculosis).

In [1]:
import torch
from torchvision import transforms
from torchvision.models.resnet import resnet18, ResNet18_Weights
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
from pathlib import Path
import numpy as np
from PIL import Image

In [2]:
def load_model(model_path, num_classes=5, device=None):
    """
    Load the trained model from the given path.

    Rebuilds the ResNet-18 architecture used at training time (with a
    dropout + linear classification head in place of ``fc``), loads the
    saved ``state_dict`` into it, moves it to ``device`` and switches it to
    evaluation mode so it is ready for inference.

    Args:
        model_path (str or Path): Path to the saved ``state_dict``.
        num_classes (int): Number of outputs of the classification head.
            Must match the checkpoint. Defaults to 5.
        device (torch.device or str, optional): Device to place the model on.
            Defaults to CUDA when available, otherwise CPU.

    Returns:
        model (nn.Module): Loaded model, in eval mode, on ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    model = resnet18(weights=None)
    num_ftrs = model.fc.in_features
    # The head must mirror the one used in training, otherwise the state_dict
    # keys (e.g. ``fc.1.weight``) and shapes will not line up.
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_ftrs, num_classes),
    )

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine (and vice versa).
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints, or pass weights_only=True on torch versions that support it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # eval() disables dropout and makes BatchNorm use its running statistics;
    # without it, predictions are stochastic and unreliable for single images.
    model.eval()
    return model

In [3]:
class Predictor:
    """Run single-image inference with a trained chest X-ray classifier."""

    # Labels indexed by the model's output positions.
    # NOTE(review): this order is taken from the original implementation --
    # confirm it matches the label encoding used when the model was trained.
    DEFAULT_CLASSES = (
        "COVID-19",
        "Lung-Opacity",
        "Normal",
        "Viral Pneumonia",
        "Tuberculosis",
    )

    def __init__(self, model, classes=None):
        """
        Initialize the Predictor with the trained model.

        Args:
            model (nn.Module): Trained model.
            classes (sequence of str, optional): Class labels indexed by the
                model's output positions. Defaults to ``DEFAULT_CLASSES``.
        """
        self.model = model
        self.classes = tuple(classes) if classes is not None else self.DEFAULT_CLASSES
        # Presumably mirrors the training-time preprocessing -- TODO confirm.
        # The mean/std values are the standard ImageNet channel statistics.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def preprocess_image(self, img_path):
        """
        Preprocess the image for model input.

        Args:
            img_path (str or Path): Path to the image file.

        Returns:
            image (torch.Tensor): Preprocessed image tensor with a leading
                batch dimension of size 1.
        """
        # The context manager closes the underlying file handle.
        with Image.open(img_path) as img:
            # Images may be single-channel ("L") or carry an alpha channel
            # ("RGBA"); the 3-value Normalize and the model expect 3 channels.
            image = self.transform(img.convert("RGB"))
        return image.unsqueeze(0)

    def _model_device(self):
        """Return the device of the model's parameters (CUDA/CPU fallback)."""
        param = next(self.model.parameters(), None)
        if param is not None:
            return param.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def predict(self, img_path):
        """
        Predict the class of the given image.

        Puts ``self.model`` into eval mode (this mutates the passed-in model)
        so dropout/BatchNorm behave deterministically.

        Args:
            img_path (str or Path): Path to the image file.

        Returns:
            str: Predicted class label.
        """
        # Send the input to wherever the model actually lives, rather than
        # guessing, so a CPU model still works on a CUDA-capable host.
        processed_img = self.preprocess_image(img_path).to(self._model_device())
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(processed_img)
        predicted = outputs.argmax(dim=1)
        return self.classes[predicted.item()]