In [15]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

In [16]:
    # Define model paths
model_paths = {
    'lite': 'checkpoints/lite_unet.pth',
    'resnet': 'checkpoints/resnet_unet.pth',
    'resnet_att': 'checkpoints/best_resnet101_attention_model.pth'
}

In [25]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from models.lite_unet_model import UNetLite
from models.resunet_model import get_resnet_unet
from models.resunet_att_model import get_resnet_unet as get_resnet_att_unet

def load_model(model_type, model_path):
    """Load a model from checkpoint"""
    if model_type == 'lite':
        model = UNetLite(in_channels=1, out_channels=6)
    elif model_type == 'resnet':
        model = get_resnet_unet(in_channels=1, out_channels=6)
    elif model_type == 'resnet_att':
        model = get_resnet_att_unet(in_channels=1, out_channels=6)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    
    # Load state dict
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model

def preprocess_image(image_path):
    """Preprocess the image for model input"""
    # Load image
    img = Image.open(image_path)
    img_array = np.array(img, dtype=np.float32)
    
    # Normalize the image to [-1, 1] range
    img_array = img_array / 127.5 - 1.0
    img_tensor = torch.from_numpy(img_array).unsqueeze(0).float()  # Add channel dimension
    
    return img_tensor.unsqueeze(0)  # Add batch dimension

def predict_and_visualize(image_path, model_paths, output_path, model_type):
    """Predict with three models and visualize results"""
    # Load models
    if model_type == 'UNetLite':
        model = load_model('lite', model_paths['lite'])
    elif model_type == 'ResNetUNet':
        model = load_model('resnet', model_paths['resnet'])
    elif model_type == 'ResNetUNet-Att':
        model = load_model('resnet_att', model_paths['resnet_att'])
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    # Preprocess image
    input_tensor = preprocess_image(image_path)
    
    # Get predictions
    predictions = {}
    with torch.no_grad():
        pred = model(input_tensor)
        pred = torch.softmax(pred, dim=1)
        predictions[model_type] = pred.squeeze().numpy()
    
    # Create visualization
    fig, axes = plt.subplots(1, 7, figsize=(21, 3))
    
    # Plot original image
    original_img = np.array(Image.open(image_path))
    axes[0].imshow(original_img, cmap='gray')
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    # Plot predictions for UNetLite only
    class_names = ['Background', 'EDH', 'SDH', 'SAH', 'IPH', 'IVH']
    pred = predictions[model_type]
    for j in range(6):
        axes[j+1].imshow(pred[j], cmap='hot')
        axes[j+1].set_title(f'UNetLite\n{class_names[j]}')
        axes[j+1].axis('off')
    
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


In [26]:
# Example usage
image_path = 'test_image.png'  # Replace with your test image path
output_path = 'predictions.png'
model_type = 'ResNetUNet-Att' # 'UNetLite', 'ResNetUNet', 'ResNetUNet-Att'
    
predict_and_visualize(image_path, model_paths, output_path, model_type)

  state_dict = torch.load(model_path, map_location='cpu')
