In [None]:
import torch
from torchvision import models, transforms
from torchvision.transforms.functional import to_pil_image
from PIL import Image
from gradcam import GradCAM, GradCAMpp
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import cv2

from attribute_predictor import AttributePredictor
from attribute_predictor_SEBlocks import AttributePredictor_SEB
from gradcam.utils import visualize_cam


# Load and prepare the model
def get_image_encoder(pretrained=True):
    """Build a ResNet-50 backbone with its classification head removed.

    Args:
        pretrained: If True, load torchvision's ImageNet weights; if False,
            return a randomly initialised ResNet-50.

    Returns:
        ``(model, output_dim)``: ``model`` maps an ``(N, 3, 224, 224)`` batch
        to ``(N, output_dim)`` features, and ``output_dim`` is inferred from a
        dummy forward pass rather than hard-coded.
    """
    model = models.resnet50(pretrained=pretrained)
    # Drop the ImageNet classifier so the network emits pooled features.
    model.fc = torch.nn.Identity()

    # Infer the output size of the image encoder. Probe in eval mode: in
    # train mode BatchNorm would fold the random batch's statistics into its
    # running mean/var (inference_mode only disables autograd, not that).
    was_training = model.training
    model.eval()
    with torch.inference_mode():
        out = model(torch.randn(5, 3, 224, 224))
    model.train(was_training)

    assert out.dim() == 2
    assert out.size(0) == 5
    image_encoder_output_dim = out.size(1)

    return model, image_encoder_output_dim

class GradCAMWrapper(torch.nn.Module):
    """Adapter that exposes one element of a multi-output model's result.

    The wrapped ``model`` returns an indexable collection of outputs; calling
    this wrapper returns only the element at ``output_index``. This lets the
    model be handed to GradCAM, which presumably expects a single logits
    tensor from ``arch(x)``.
    """

    def __init__(self, model, output_index=0):
        super().__init__()
        self.model = model
        self.output_index = output_index

    def forward(self, x):
        all_outputs = self.model(x)
        chosen = all_outputs[self.output_index]
        return chosen

def denormalize(tensor, mean, std):
    """Invert ``transforms.Normalize``: return ``tensor * std + mean`` per channel.

    ``mean`` and ``std`` are 1-D per-channel sequences. They are converted to
    ``tensor``'s dtype/device and shaped ``(C, 1, 1)``, so they broadcast over
    the last three dimensions (``(C, H, W)`` or ``(N, C, H, W)``).
    """
    like = {"dtype": tensor.dtype, "device": tensor.device}
    shift = torch.as_tensor(mean, **like)
    scale = torch.as_tensor(std, **like)
    channel_scale = scale[:, None, None]
    channel_shift = shift[:, None, None]
    return tensor * channel_scale + channel_shift

def load_model(checkpoint_path):
    """Build the attribute predictor for *checkpoint_path* and load its weights.

    The Squeeze-and-Excitation variant (AttributePredictor_SEB) is chosen only
    when the path is exactly "./log/best_model_SEB1.pth"; every other path
    uses AttributePredictor. Weights are mapped to CPU on load and the model
    is returned in eval mode; device placement is left to the caller.
    """
    encoder, encoder_dim = get_image_encoder(pretrained=True)
    head_sizes = [6]  # presumably one 6-way head: the nucleus-shape classes

    predictor_cls = (
        AttributePredictor_SEB
        if checkpoint_path == "./log/best_model_SEB1.pth"
        else AttributePredictor
    )
    model = predictor_cls(head_sizes, encoder_dim, encoder)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    saved = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(saved['model'])
    return model.eval()

def _nucleus_mask(bgr_image):
    """Return the dilated 0/255 foreground mask used to locate the nucleus."""
    saturation = cv2.split(cv2.cvtColor(bgr_image, cv2.COLOR_BGR2HSV))[1]
    green = cv2.split(bgr_image)[1]
    # cv2.subtract saturates, so pixels where green exceeds saturation become 0.
    difference = cv2.subtract(saturation, green)
    _, mask = cv2.threshold(difference, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # Dilate to close small gaps before masking / contour detection.
    kernel = np.ones((5, 5), np.uint8)
    return cv2.dilate(mask, kernel, iterations=1)


def _crop_around_largest_contour(bgr_image, mask, min_size):
    """Crop *bgr_image* around the largest external contour in *mask*.

    The crop window is centred on the contour's bounding box, enlarged to at
    least ``min_size`` (width, height) but never beyond the image size, and
    shifted (not shrunk) when it would extend past an image border.

    Returns:
        The cropped BGR array, or ``None`` when *mask* has no contours.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    # Assume the largest blob is the nucleus.
    x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
    center_x = x + w // 2
    center_y = y + h // 2

    img_h, img_w = bgr_image.shape[:2]
    crop_w = min(max(w, min_size[0]), img_w)
    crop_h = min(max(h, min_size[1]), img_h)

    # Clamp the top-left corner so the whole window lies inside the image.
    left = min(max(center_x - crop_w // 2, 0), img_w - crop_w)
    top = min(max(center_y - crop_h // 2, 0), img_h - crop_h)
    return bgr_image[top:top + crop_h, left:left + crop_w]


def image_preprocessing(original_image, checkpoint_name):
    """Apply the model-specific preprocessing to a PIL image.

    - "Segmented nucleus model": zero out everything outside the nucleus mask.
    - "Nucleus crop model": crop around the largest masked region (at least
      150x150 where the image allows). Falls back to the unmodified image
      when no region is found.
    - any other name: return *original_image* unchanged.

    Returns:
        A PIL image (a new one for the two nucleus models).
    """
    if checkpoint_name not in ("Segmented nucleus model", "Nucleus crop model"):
        return original_image

    # PIL (RGB) -> OpenCV (BGR); convert("RGB") also accepts RGBA/greyscale uploads.
    bgr = cv2.cvtColor(np.array(original_image.convert("RGB")), cv2.COLOR_RGB2BGR)
    mask = _nucleus_mask(bgr)

    if checkpoint_name == "Segmented nucleus model":
        mask_3ch = cv2.merge([mask, mask, mask])
        # mask is 0/255, so scale=1/255 keeps foreground pixels and zeroes the rest.
        segmented = cv2.multiply(bgr, mask_3ch, scale=1 / 255)
        return Image.fromarray(cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB))

    cropped = _crop_around_largest_contour(bgr, mask, min_size=(150, 150))
    if cropped is None:
        # No foreground detected: use the full image rather than failing.
        return original_image
    return Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
    


# Gradio function to handle image input, model prediction, and visualization
def predict_and_visualize(original_image, checkpoint_name):
    """Predict nucleus shape for one image and build a Grad-CAM visualisation.

    Args:
        original_image: Input image as delivered by ``gr.Image(type="pil")``.
        checkpoint_name: A key of ``checkpoint_path_all``; selects both the
            checkpoint and the matching ``image_preprocessing`` branch.

    Returns:
        A 4-tuple for the Gradio outputs:
        - the preprocessed input image;
        - the jet-coloured Grad-CAM heatmap;
        - the heatmap blended onto the 224x224 centre-cropped input;
        - a text summary of the per-class probabilities.
    """
    torch.cuda.empty_cache()

    # Resize/crop to 224x224 and apply the ImageNet mean/std normalisation.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    original_image = image_preprocessing(original_image, checkpoint_name)
    img = transform(original_image).unsqueeze(0).to(device)
    # NOTE(review): the model is rebuilt and reloaded from disk on every request.
    model = load_model(checkpoint_path_all[checkpoint_name]).to(device)

    with torch.no_grad():
        predictions = model(img)  # iterated below as per-attribute logits

    attribute_names = ["nucleus_shape"]
    attribute_values = [
        ["irregular", "segmented-bilobed", "segmented-multilobed", "unsegmented-band", "unsegmented-indented", "unsegmented-round"]
    ]

    # One text section per attribute; these feed the "Predictions" textbox.
    prediction_texts = []
    for i, logits in enumerate(predictions):
        probabilities = F.softmax(logits, dim=1)
        predicted_index = torch.argmax(probabilities, dim=1).item()
        predicted_label = attribute_values[i][predicted_index]
        all_probabilities = probabilities.squeeze().tolist()

        lines = [f"Predictions for {attribute_names[i]}:"]
        lines.extend(
            f"{attribute_values[i][class_index]}: {class_probability*100:.2f}%"
            for class_index, class_probability in enumerate(all_probabilities)
        )
        lines.append(
            f"Most likely: {predicted_label}, "
            f"Probability: {all_probabilities[predicted_index]*100:.2f}%"
        )
        section = "\n".join(lines)
        print(section + "\n")  # also log to the console
        prediction_texts.append(section)

    # Grad-CAM on the last block of the backbone's layer4.
    target_layer = model.image_encoder.layer4[-1]
    gradcam = GradCAM(GradCAMWrapper(model, output_index=0), target_layer)
    mask, _ = gradcam(img)

    # Colourise the raw saliency mask. gradcam.utils.visualize_cam returns an
    # already jet-coloured RGB map; re-applying jet to one of its channels is
    # not monotonic in saliency, so peaks would not show as the hottest colour.
    # nan_to_num guards the case where GradCAM's min-max normalisation divides
    # by zero for a constant map.
    saliency = torch.nan_to_num(mask.detach().squeeze().cpu(), nan=0.0)
    lo, hi = saliency.min(), saliency.max()
    heatmap_norm = (saliency - lo) / (hi - lo) if hi > lo else torch.zeros_like(saliency)
    colored_heatmap = plt.cm.jet(heatmap_norm.numpy())  # RGBA floats in [0, 1]

    # Drop the alpha channel and convert to an 8-bit PIL image.
    heatmap_img = Image.fromarray((colored_heatmap[:, :, :3] * 255).astype(np.uint8))

    # Blend the heatmap onto the de-normalised network input.
    img_denorm = denormalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    img_denorm = torch.clamp(img_denorm, 0, 1)
    img_pil = to_pil_image(img_denorm.squeeze()).convert("RGB")
    heatmap_on_image = Image.blend(img_pil, heatmap_img, alpha=0.4)

    combined_predictions = "\n".join(prediction_texts)
    return original_image, heatmap_img, heatmap_on_image, combined_predictions

# Display name -> checkpoint file. Keys are the dropdown choices; the SE
# variant's path must equal the literal that load_model compares against.
checkpoint_path_all = {
    "Default model" : "./log/best_model_nucleus.pth",
    "Nucleus crop model" : "./log/best_model_nucleus_crop.pth",
    "Segmented nucleus model" : "./log/best_model_segmented.pth",
    "Squeeze-And-Excitation model" : "./log/best_model_SEB1.pth"
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Gradio Interface
iface = gr.Interface(
    fn=predict_and_visualize,
    inputs=[
        gr.Image(type="pil"),
        # Pre-select a model. Otherwise, submitting without touching the
        # dropdown passes checkpoint_name=None and the dict lookup fails.
        gr.Dropdown(
            choices=list(checkpoint_path_all),
            value="Default model",
            label="Select Model",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Heatmap"),
        gr.Image(type="pil", label="Result on Image"),
        gr.Textbox(label="Predictions"),
    ],
    title="Attribute Prediction with Grad-CAM Visualization",
    description="Upload an image to predict attributes and visualize the model's focus areas.",
)
iface.launch()

In [1]:
import torch
from torchvision import models, transforms
from torchvision.transforms.functional import to_pil_image
from PIL import Image
from gradcam import GradCAM, GradCAMpp
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import cv2
import gc

from attribute_predictor import AttributePredictor
from attribute_predictor_SEBlocks import AttributePredictor_SEB
from gradcam.utils import visualize_cam


# Load and prepare the model
def get_image_encoder(pretrained=True):
    """Build a ResNet-50 backbone with its classification head removed.

    Args:
        pretrained: If True, load torchvision's ImageNet weights; if False,
            return a randomly initialised ResNet-50.

    Returns:
        ``(model, output_dim)``: ``model`` maps an ``(N, 3, 224, 224)`` batch
        to ``(N, output_dim)`` features, and ``output_dim`` is inferred from a
        dummy forward pass rather than hard-coded.
    """
    model = models.resnet50(pretrained=pretrained)
    # Drop the ImageNet classifier so the network emits pooled features.
    model.fc = torch.nn.Identity()

    # Infer the output size of the image encoder. Probe in eval mode: in
    # train mode BatchNorm would fold the random batch's statistics into its
    # running mean/var (inference_mode only disables autograd, not that).
    was_training = model.training
    model.eval()
    with torch.inference_mode():
        out = model(torch.randn(5, 3, 224, 224))
    model.train(was_training)

    assert out.dim() == 2
    assert out.size(0) == 5
    image_encoder_output_dim = out.size(1)

    return model, image_encoder_output_dim

class GradCAMWrapper(torch.nn.Module):
    """Wrap a multi-output model so that calling it yields a single output.

    ``forward`` runs the wrapped ``model`` and returns the element at
    ``output_index`` of its indexable result. The wrapper is then passed to
    GradCAM, which presumably expects one logits tensor from ``arch(x)``.
    """

    def __init__(self, model, output_index=0):
        super().__init__()
        self.model = model
        self.output_index = output_index

    def forward(self, x):
        idx = self.output_index
        results = self.model(x)
        return results[idx]

def denormalize(tensor, mean, std):
    """Reverse channel-wise normalisation, computing ``tensor * std + mean``.

    ``mean``/``std`` are 1-D per-channel sequences, cast to ``tensor``'s dtype
    and device and indexed to ``(C, 1, 1)`` so they broadcast across the
    trailing ``(C, H, W)`` dimensions (batched input also works).
    """
    def _per_channel(values):
        as_t = torch.as_tensor(values, dtype=tensor.dtype, device=tensor.device)
        return as_t[:, None, None]

    offset = _per_channel(mean)
    spread = _per_channel(std)
    return tensor * spread + offset

def load_model(checkpoint_path):
    """Instantiate the predictor for *checkpoint_path* and restore its weights.

    AttributePredictor_SEB is used only for the exact path
    "./log/best_model_SEB.pth"; every other path gets AttributePredictor.
    The checkpoint is read onto the CPU and the model is returned in eval
    mode; moving it to a device is the caller's job.
    """
    backbone, feature_dim = get_image_encoder(pretrained=True)
    attribute_sizes = [6]  # presumably one 6-way head: the nucleus-shape classes

    use_se_blocks = checkpoint_path == "./log/best_model_SEB.pth"
    if use_se_blocks:
        model = AttributePredictor_SEB(attribute_sizes, feature_dim, backbone)
    else:
        model = AttributePredictor(attribute_sizes, feature_dim, backbone)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state['model'])
    return model.eval()

def _nucleus_mask(bgr_image):
    """Return the dilated 0/255 foreground mask used to locate the nucleus."""
    saturation = cv2.split(cv2.cvtColor(bgr_image, cv2.COLOR_BGR2HSV))[1]
    green = cv2.split(bgr_image)[1]
    # cv2.subtract saturates, so pixels where green exceeds saturation become 0.
    difference = cv2.subtract(saturation, green)
    _, mask = cv2.threshold(difference, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # Dilate to close small gaps before masking / contour detection.
    kernel = np.ones((5, 5), np.uint8)
    return cv2.dilate(mask, kernel, iterations=1)


def _crop_around_largest_contour(bgr_image, mask, min_size):
    """Crop *bgr_image* around the largest external contour in *mask*.

    The crop window is centred on the contour's bounding box, enlarged to at
    least ``min_size`` (width, height) but never beyond the image size, and
    shifted (not shrunk) when it would extend past an image border.

    Returns:
        The cropped BGR array, or ``None`` when *mask* has no contours.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    # Assume the largest blob is the nucleus.
    x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
    center_x = x + w // 2
    center_y = y + h // 2

    img_h, img_w = bgr_image.shape[:2]
    crop_w = min(max(w, min_size[0]), img_w)
    crop_h = min(max(h, min_size[1]), img_h)

    # Clamp the top-left corner so the whole window lies inside the image.
    left = min(max(center_x - crop_w // 2, 0), img_w - crop_w)
    top = min(max(center_y - crop_h // 2, 0), img_h - crop_h)
    return bgr_image[top:top + crop_h, left:left + crop_w]


def image_preprocessing(original_image, checkpoint_name):
    """Apply the model-specific preprocessing to a PIL image.

    - "Segmented nucleus model": zero out everything outside the nucleus mask.
    - "Nucleus crop model": crop around the largest masked region (at least
      150x150 where the image allows). Falls back to the unmodified image
      when no region is found.
    - any other name: return *original_image* unchanged.

    Returns:
        A PIL image (a new one for the two nucleus models).
    """
    if checkpoint_name not in ("Segmented nucleus model", "Nucleus crop model"):
        return original_image

    # PIL (RGB) -> OpenCV (BGR); convert("RGB") also accepts RGBA/greyscale uploads.
    bgr = cv2.cvtColor(np.array(original_image.convert("RGB")), cv2.COLOR_RGB2BGR)
    mask = _nucleus_mask(bgr)

    if checkpoint_name == "Segmented nucleus model":
        mask_3ch = cv2.merge([mask, mask, mask])
        # mask is 0/255, so scale=1/255 keeps foreground pixels and zeroes the rest.
        segmented = cv2.multiply(bgr, mask_3ch, scale=1 / 255)
        return Image.fromarray(cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB))

    cropped = _crop_around_largest_contour(bgr, mask, min_size=(150, 150))
    if cropped is None:
        # No foreground detected: use the full image rather than failing.
        return original_image
    return Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
    


# Gradio function to handle image input, model prediction, and visualization
def predict_and_visualize(original_image, checkpoint_name):
    """Predict nucleus shape for one image and overlay a Grad-CAM heatmap.

    Args:
        original_image: Input image as delivered by ``gr.Image(type="pil")``.
        checkpoint_name: A key of ``checkpoint_path_all``; selects both the
            checkpoint and the matching ``image_preprocessing`` branch.

    Returns:
        A 3-tuple for the Gradio outputs:
        - the preprocessed input image;
        - the Grad-CAM heatmap blended onto the 224x224 centre-cropped input;
        - a ``{class_label: probability}`` dict for ``gr.Label``.
    """
    # Each call builds a fresh model. First collect unreachable Python objects
    # (e.g. the previous request's model), then ask CUDA to release the cached
    # blocks they held.
    gc.collect()
    torch.cuda.empty_cache()

    # Resize/crop to 224x224 and apply the ImageNet mean/std normalisation.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    original_image = image_preprocessing(original_image, checkpoint_name)
    img = transform(original_image).unsqueeze(0).to(device)
    model = load_model(checkpoint_path_all[checkpoint_name]).to(device)

    with torch.no_grad():
        predictions = model(img)

    # The model may return a sequence of per-attribute logits; the first
    # entry is the nucleus-shape head.
    logits = predictions[0] if isinstance(predictions, (list, tuple)) else predictions
    probabilities = F.softmax(logits.squeeze(), dim=0).cpu().numpy()

    attribute_values = ["irregular", "segmented-bilobed", "segmented-multilobed", "unsegmented-band", "unsegmented-indented", "unsegmented-round"]
    results = {label: float(p) for label, p in zip(attribute_values, probabilities)}

    # Grad-CAM on the last block of the backbone's layer4.
    target_layer = model.image_encoder.layer4[-1]
    gradcam = GradCAM(GradCAMWrapper(model, output_index=0), target_layer)
    mask, _ = gradcam(img)

    # Colourise the raw saliency mask. gradcam.utils.visualize_cam returns an
    # already jet-coloured RGB map; re-applying jet to one of its channels is
    # not monotonic in saliency, so peaks would not show as the hottest colour.
    # nan_to_num guards the case where GradCAM's min-max normalisation divides
    # by zero for a constant map.
    saliency = torch.nan_to_num(mask.detach().squeeze().cpu(), nan=0.0)
    lo, hi = saliency.min(), saliency.max()
    heatmap_norm = (saliency - lo) / (hi - lo) if hi > lo else torch.zeros_like(saliency)
    colored_heatmap = plt.cm.jet(heatmap_norm.numpy())  # RGBA floats in [0, 1]

    # Drop the alpha channel and convert to an 8-bit PIL image.
    heatmap_img = Image.fromarray((colored_heatmap[:, :, :3] * 255).astype(np.uint8))

    # Blend the heatmap onto the de-normalised network input.
    img_denorm = denormalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    img_denorm = torch.clamp(img_denorm, 0, 1)
    img_pil = to_pil_image(img_denorm.squeeze()).convert("RGB")
    heatmap_on_image = Image.blend(img_pil, heatmap_img, alpha=0.4)

    return original_image, heatmap_on_image, results

# Display name -> checkpoint file. Keys are the dropdown choices; the SE
# variant's path must equal the literal that load_model compares against.
checkpoint_path_all = {
    "Default model" : "./log/best_model_nucleus.pth",
    "Nucleus crop model" : "./log/best_model_nucleus_crop.pth",
    "Segmented nucleus model" : "./log/best_model_segmented.pth",
    "Squeeze-And-Excitation model" : "./log/best_model_SEB.pth"
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Gradio Interface
iface = gr.Interface(
    fn=predict_and_visualize,
    inputs=[
        gr.Image(type="pil"),
        # Pre-select a model. Otherwise, submitting without touching the
        # dropdown passes checkpoint_name=None and the dict lookup fails.
        gr.Dropdown(
            choices=list(checkpoint_path_all),
            value="Default model",
            label="Select Model",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Input Image", width=224, height=224),
        gr.Image(type="pil", label="Result on Image", width=224, height=224),
        gr.Label(num_top_classes=6, label="Prediction"),
    ],
    title="Attribute Prediction with Grad-CAM Visualization",
    description="Upload an image to predict attributes and visualize the model's focus areas.",
)
iface.launch()

  from .autonotebook import tqdm as notebook_tqdm


Running on local URL:  http://127.0.0.1:7863

To create a public link, set `share=True` in `launch()`.




