<a href="https://colab.research.google.com/github/varshini0317/Knee-X-Ray-Analysis/blob/main/Model_Knee_X_ray_Analysis_with_ResNet_for_Osteoarthritis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [50]:
import pandas as pd
import os
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import models
from PIL import Image

In [51]:
num_classes = 5  # Change this if you have more or fewer classes

In [52]:
# Step 2: Load your trained model
model = models.resnet18(weights='IMAGENET1K_V1')  # Load a pretrained ResNet model
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)  # Adjust the output layer
model.load_state_dict(torch.load('/content/model.pth'))  # Load your trained model weights
model.eval()

  model.load_state_dict(torch.load('/content/model.pth'))  # Load your trained model weights


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [53]:
# Step 3: Define transformations for the input images
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match ResNet input size
    transforms.ToTensor(),  # Convert image to tensor
])

In [54]:
# Step 4: Load your CSV file
csv_file_path = '/content/multilabel_missing-filled.csv'  # Replace with your actual CSV path
data = pd.read_csv(csv_file_path)
def load_images(image_paths):
    images = []
    for img_path in image_paths:
        image = Image.open(img_path).convert('RGB')  # Open the image
        image = data_transforms(image)  # Apply transformations
        images.append(image)
    return torch.stack(images)  # Return a batch of images

In [55]:
data.head()  # Check the paths in the CSV

Unnamed: 0,id,side,subset,filename,kl_grade,osteophytes,jsn,osfl,scfl,cyfl,...,cytm,attm,osfm,scfm,cyfm,ostl,sctl,cytl,attl,actual_path
0,9000099,L,train,9000099L.png,3,def,severe,2.0,2.0,0.0,...,1.0,0.0,0.0,0.0,0.0,1.0,2.0,0.0,0.0,/kaggle/input/knee-osteoarthritis-dataset-with...
1,9000099,R,train,9000099R.png,2,def,mild/mod,2.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,/kaggle/input/knee-osteoarthritis-dataset-with...
2,9000296,L,train,9000296L.png,3,poss,def,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,/kaggle/input/knee-osteoarthritis-dataset-with...
3,9000296,R,train,9000296R.png,2,def,none,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,/kaggle/input/knee-osteoarthritis-dataset-with...
4,9000622,L,train,9000622L.png,1,none,none,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,/kaggle/input/knee-osteoarthritis-dataset-with...


In [56]:
# Step 5: Prepare DataLoader for predictions
def load_images(image_paths):
    images = []
    for img_path in image_paths:
        # Construct full path if necessary
        full_path = img_path  # If paths in CSV are absolute
        # Uncomment the line below if your paths are relative and you need to construct them
        # full_path = os.path.join('/kaggle/input/knee-osteoarthritis-dataset-with-severity', img_path)  # Adjust the base path if necessary

        # Check if the path exists
        if not os.path.exists(full_path):
            print(f"Image not found: {full_path}")  # Print an error if the image is not found
            continue  # Skip to the next image

        image = Image.open(full_path).convert('RGB')  # Open the image
        image = data_transforms(image)  # Apply transformations
        images.append(image)
    return torch.stack(images)  # Return a batch of images


In [61]:
# Class mapping for output labels
class_mapping = {
    0: "None",
    1: "Doubtful",
    2: "Minimal",
    3: "Moderate",
    4: "Severe"
}

def predict_image_class(image_path):
    # Load and preprocess the image
    image = Image.open(image_path).convert('RGB')  # Open the image
    image = data_transforms(image).unsqueeze(0)  # Add batch dimension

    # Make predictions
    with torch.no_grad():  # Disable gradient calculations for inference
        outputs = model(image)  # Get model predictions

    predicted_class = torch.argmax(outputs, dim=1).item()  # Get the predicted class index
    predicted_label = class_mapping.get(predicted_class, "Unknown")  # Map to human-readable label
    return predicted_class, predicted_label  # Return both class index and label

# Example usage
image_path = '/content/9057150L.png'  # Replace with your image path
predicted_class, predicted_label = predict_image_class(image_path)
print(f"Predicted class index: {predicted_class}, Predicted label: {predicted_label}")


Predicted class index: 2, Predicted label: Minimal
