In [1]:
# Cell 1: Import Libraries

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
from IPython.display import display
import ipywidgets as widgets
from ipywidgets import FileUpload


In [9]:
# Cell 2: Define the SkinDiseaseCNN Model with 11 Classes

import torch.nn as nn

class SkinDiseaseCNN(nn.Module):
    def __init__(self, num_classes=11):  # Changed to 11
        super(SkinDiseaseCNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),  # Output: 32 x 224 x 224
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                 # Output: 32 x 112 x 112

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), # Output: 64 x 112 x 112
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                 # Output: 64 x 56 x 56

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),# Output: 128 x 56 x 56
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)                  # Output: 128 x 28 x 28
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)  # Now outputs 11 classes
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x


In [10]:
# Cell 4: Define Class Names with 'Unknown' for .DS_Store

CLASS_NAMES = [
    'Unknown',  # Placeholder for unintended class like .DS_Store
    'Eczema',
    'Warts Molluscum and other Viral Infections',
    'Melanoma',
    'Atopic Dermatitis',
    'Basal Cell Carcinoma (BCC)',
    'Melanocytic Nevi (NV)',
    'Benign Keratosis-like Lesions (BKL)',
    'Psoriasis pictures Lichen Planus and related diseases',
    'Seborrheic Keratoses and other Benign Tumors',
    'Tinea Ringworm Candidiasis and other Fungal Infections'
]


In [11]:
# Cell 3: Load the Trained Model with 11 Classes

def load_model(model_path='skin_disease_model.pth', num_classes=11, device='cpu'):
    """
    Load the SkinDiseaseCNN model from a .pth file.
    
    Args:
        model_path (str): Path to the .pth file containing the state_dict.
        num_classes (int): Number of output classes (11).
        device (str): Device to load the model on ('cpu' or 'cuda').
        
    Returns:
        model (nn.Module): Loaded SkinDiseaseCNN model in evaluation mode.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"The model file {model_path} does not exist.")
    
    # Initialize the SkinDiseaseCNN model
    model = SkinDiseaseCNN(num_classes=num_classes)
    
    # Load the state_dict
    state_dict = torch.load(model_path, map_location=torch.device(device))
    
    # Handle cases where the state_dict was saved using DataParallel
    if isinstance(state_dict, dict) and 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    
    if isinstance(state_dict, dict):
        # Check if keys are prefixed with "module." (common when using DataParallel)
        if list(state_dict.keys())[0].startswith("module."):
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                new_key = k.replace("module.", "")
                new_state_dict[new_key] = v
            state_dict = new_state_dict
    
    # Load state_dict into the model
    model.load_state_dict(state_dict)
    
    # Move the model to the specified device and set to evaluation mode
    model.to(device)
    model.eval()
    
    return model

# Example: Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model(model_path='skin_disease_model.pth', num_classes=11, device=device)
print("Model loaded successfully.")


Model loaded successfully.


  state_dict = torch.load(model_path, map_location=torch.device(device))


In [12]:
# Cell 6: Updated Prediction Function

def predict_disease(image_path, model, device='cpu'):
    """
    Predict the disease from the input image using the loaded model.
    
    Args:
        image_path (str): Path to the input image.
        model (nn.Module): Loaded machine learning model.
        device (str): Device to perform computation on ('cpu' or 'cuda').
        
    Returns:
        tuple: (Predicted disease, Confidence score)
    """
    preprocessed_image = preprocess_image(image_path, target_size=(224, 224))
    preprocessed_image = preprocessed_image.to(device)
    
    with torch.no_grad():
        outputs = model(preprocessed_image)
        # Apply softmax to get probabilities
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        confidence, predicted_idx = torch.max(probabilities, 1)
        predicted_class = predicted_idx.item()
        confidence_score = confidence.item()
        
        if predicted_class == 0:
            disease = "Unknown Disease"
        elif predicted_class < 0 or predicted_class >= len(CLASS_NAMES):
            disease = "Invalid Prediction"
        else:
            disease = CLASS_NAMES[predicted_class]
        
        return disease, confidence_score
