In [58]:
import torch
from torchvision import models, transforms
from torchvision.transforms.functional import to_pil_image
from PIL import Image
from gradcam import GradCAM, GradCAMpp
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import cv2
import gc
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader


from attribute_predictor import AttributePredictor
from attribute_predictor_SEBlocks import AttributePredictor_SEB
from gradcam.utils import visualize_cam

# Helpers for loading the test dataset and evaluating a model on it
class TestDataset(Dataset):
    """Dataset of labelled test images listed in a CSV file.

    For every row, the last CSV column is joined to ``base_path`` to locate the
    image and column index 4 holds the label name.  Each image is passed
    through ``image_preprocessing`` (using ``checkpoint_name``) and then through
    the module-level ``transform`` pipeline.

    Samples are returned on the CPU; moving batches to the compute device is
    left to the consumer (``evaluate_accuracy`` already does this).  Keeping
    device transfers out of ``__getitem__`` keeps the dataset usable with
    multi-worker ``DataLoader`` instances.
    """

    def __init__(self, csv_file, base_path, checkpoint_name):
        """Read the CSV index and remember how samples must be preprocessed.

        Args:
            csv_file: Path of the CSV file, read with ``pandas.read_csv``.
            base_path: Directory the image paths from the CSV are joined to.
            checkpoint_name: Forwarded to ``image_preprocessing`` for each sample.
        """
        self.data_frame = pd.read_csv(csv_file)
        self.base_path = base_path
        self.checkpoint_name = checkpoint_name

        # Mapping of label names (as stored in the CSV) to class indices
        self.label_mapping = {
            "irregular": 0,
            "segmented-bilobed": 1,
            "segmented-multilobed": 2,
            "unsegmented-band": 3,
            "unsegmented-indented": 4,
            "unsegmented-round": 5,
        }

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for row ``idx``.

        Raises:
            KeyError: If the label name of the row is not in ``label_mapping``.
        """
        img_path = os.path.join(self.base_path, self.data_frame.iloc[idx, -1])
        label_name = self.data_frame.iloc[idx, 4]
        label = self.label_mapping[label_name]  # Convert label names to indices

        # Load eagerly inside a context manager so the file handle is released,
        # and force 3-channel RGB because the preprocessing expects RGB input.
        with Image.open(img_path) as handle:
            image = handle.convert("RGB")

        image = image_preprocessing(image, checkpoint_name=self.checkpoint_name)
        image = transform(image)

        return image, torch.tensor(label, dtype=torch.long)

def evaluate_accuracy(model, data_loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``data_loader``.

    Args:
        model: Module called as ``model(images)``.  When it returns a list or
            tuple, the first element is scored (adjust the index if another
            attribute head should be evaluated); otherwise the returned tensor
            itself is treated as the logits.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader yields
        no samples (instead of failing with a division by zero).
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            # Same list / non-list handling as the single-image prediction path.
            if isinstance(outputs, (list, tuple)):
                logits = outputs[0]
            else:
                logits = outputs

            predicted = logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0
    return 100 * correct / total

############################################################################################################

# Load and prepare the model
def get_image_encoder(pretrained=True):
    """Build a ResNet-50 with its ``fc`` head removed and report its output width.

    Args:
        pretrained: Forwarded to ``models.resnet50`` (previously this argument
            was accepted but ignored and ``True`` was always used).

    Returns:
        Tuple ``(model, image_encoder_output_dim)`` where the second item is the
        size of dimension 1 of the model output for a ``(N, 3, 224, 224)`` input.

    Raises:
        RuntimeError: If the probe output is not of shape ``(5, features)``.
    """
    model = models.resnet50(pretrained=pretrained)
    model.fc = torch.nn.Identity()

    # Infer the output size of the image encoder with a dummy batch.  The probe
    # runs in eval mode so the random input cannot leak into the running
    # statistics of the normalisation layers; the previous mode is restored.
    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            out = model(torch.randn(5, 3, 224, 224))
    finally:
        model.train(was_training)

    if out.dim() != 2 or out.size(0) != 5:
        raise RuntimeError(
            f"Unexpected image encoder output shape: {tuple(out.shape)}"
        )
    image_encoder_output_dim = out.size(1)

    return model, image_encoder_output_dim

class GradCAMWrapper(torch.nn.Module):
    """Adapter exposing a single element of a wrapped model's output.

    The wrapped model is called as ``model(x)`` and the result is indexed with
    ``output_index``; that selected element is what ``forward`` returns.
    """

    def __init__(self, model, output_index=0):
        """Store the wrapped ``model`` and the index of the output to expose."""
        super().__init__()
        self.output_index = output_index
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return the selected output."""
        all_outputs = self.model(x)
        selected = all_outputs[self.output_index]
        return selected

def denormalize(tensor, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalisation.

    ``mean`` and ``std`` are converted to tensors with the dtype and device of
    ``tensor`` and given two trailing singleton axes, so they broadcast against
    the last three axes of ``tensor`` (channel, height, width).

    Returns:
        ``tensor * std + mean`` with the per-channel broadcasting above.
    """
    like_input = {"dtype": tensor.dtype, "device": tensor.device}
    channel_scale = torch.as_tensor(std, **like_input)[:, None, None]
    channel_shift = torch.as_tensor(mean, **like_input)[:, None, None]
    rescaled = tensor * channel_scale
    return rescaled + channel_shift

def load_model(checkpoint_path):
    """Build the attribute predictor for ``checkpoint_path`` and load its weights.

    The ``AttributePredictor_SEB`` architecture is used when the checkpoint file
    is named ``best_model_SEB.pth``; every other checkpoint uses
    ``AttributePredictor``.  The file name is compared instead of the full path
    string so that equivalent spellings of the same path (``log/...``,
    absolute paths, ``pathlib.Path`` objects) select the same architecture.

    Args:
        checkpoint_path: Path of a checkpoint whose ``'model'`` entry holds the
            state dict to load.

    Returns:
        The model in eval mode, with weights loaded onto the CPU.
    """
    image_encoder, image_encoder_output_dim = get_image_encoder(pretrained=True)
    attribute_sizes = [6]

    checkpoint_file = os.path.basename(os.path.normpath(checkpoint_path))
    if checkpoint_file == "best_model_SEB.pth":
        model_cls = AttributePredictor_SEB
    else:
        model_cls = AttributePredictor
    model = model_cls(attribute_sizes, image_encoder_output_dim, image_encoder)

    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model

def _pil_to_bgr(pil_image):
    """Convert a PIL image (any mode) to a 3-channel BGR uint8 array for OpenCV."""
    rgb = np.array(pil_image.convert("RGB"))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def _nucleus_mask(image_bgr):
    """Return a dilated binary mask from Otsu-thresholding ``S - G``.

    ``S`` is the saturation channel of the HSV image and ``G`` the green channel
    of the BGR image; the saturating subtraction is thresholded with Otsu's
    method and dilated once with a 5x5 kernel.
    """
    _, saturation, _ = cv2.split(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2HSV))
    _, green, _ = cv2.split(image_bgr)
    difference = cv2.subtract(saturation, green)

    # THRESH_BINARY is the default threshold type (value 0), so this is the
    # same operation both preprocessing modes used before.
    _, binary = cv2.threshold(
        difference, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
    )
    kernel = np.ones((5, 5), np.uint8)
    return cv2.dilate(binary, kernel, iterations=1)


def _crop_largest_region(image_bgr, mask, min_size):
    """Crop ``image_bgr`` around the largest external contour of ``mask``.

    Returns ``None`` when the mask contains no contour.
    """
    contours, _ = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not contours:
        return None

    # Assume the largest contour is the nucleus.
    largest = max(contours, key=cv2.contourArea)
    box_x, box_y, box_w, box_h = cv2.boundingRect(largest)
    center_x = box_x + box_w // 2
    center_y = box_y + box_h // 2

    # Enforce the minimum crop size, centred on the bounding box.
    crop_w = max(box_w, min_size[0])
    crop_h = max(box_h, min_size[1])

    # Clamp the window to the image.  Near a border the crop can therefore be
    # smaller than ``min_size``; this matches the previous effective behaviour
    # (its later "correction" branches could never trigger after clamping).
    img_h, img_w = image_bgr.shape[:2]
    left = max(center_x - crop_w // 2, 0)
    top = max(center_y - crop_h // 2, 0)
    right = min(left + crop_w, img_w)
    bottom = min(top + crop_h, img_h)

    return image_bgr[top:bottom, left:right]


def image_preprocessing(original_image, checkpoint_name, min_size=(150, 150)):
    """Apply the preprocessing that belongs to the selected model.

    Args:
        original_image: Input PIL image.
        checkpoint_name: Name of the selected model.
            ``"Segmented nucleus model"`` zeroes every pixel outside the
            thresholded mask; ``"Nucleus crop model"`` crops around the largest
            mask contour; any other value leaves the image untouched.
        min_size: ``(width, height)`` minimum size of the crop window used by
            the crop mode.

    Returns:
        The processed PIL image (RGB).  ``original_image`` itself is returned
        for unknown model names and when the crop mode finds no contour
        (previously that case raised ``UnboundLocalError``).
    """
    if checkpoint_name == "Segmented nucleus model":
        image = _pil_to_bgr(original_image)
        mask = _nucleus_mask(image)

        # Replicate the mask over 3 channels and use it as a 0/1 multiplier.
        mask_3_channel = cv2.merge([mask, mask, mask])
        segmented = cv2.multiply(image, mask_3_channel, scale=1/255)

        # Back to RGB for PIL.
        return Image.fromarray(cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB))

    if checkpoint_name == "Nucleus crop model":
        image = _pil_to_bgr(original_image)
        cropped = _crop_largest_region(image, _nucleus_mask(image), min_size)
        if cropped is None:
            # Nothing was segmented: fall back to the unprocessed input.
            return original_image

        # Back to RGB for PIL.
        return Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))

    return original_image
    


# Gradio function to handle image input, model prediction, and visualization
def predict_and_visualize(original_image, checkpoint_name):
    """Gradio handler: classify one image and overlay its Grad-CAM heatmap.

    On every call the selected checkpoint is loaded from disk and evaluated on
    the whole test CSV before the single-image prediction is made.

    Args:
        original_image: PIL image supplied by the UI.
        checkpoint_name: Key of ``checkpoint_path_all`` selecting the model.

    Returns:
        Tuple ``(preprocessed_image, heatmap_on_image, results, accuracy)``:
        the preprocessed PIL image, the heatmap blended onto the model input,
        a ``{class name: probability}`` dict and the test accuracy formatted
        with five decimals.

    Raises:
        gr.Error: If no image was provided or no known model was selected.
    """
    if original_image is None:
        raise gr.Error("Please provide an input image.")
    if checkpoint_name not in checkpoint_path_all:
        raise gr.Error("Please select a model from the dropdown.")

    # Drop unreachable objects first so their cached GPU memory can be released.
    gc.collect()
    torch.cuda.empty_cache()

    original_image = image_preprocessing(original_image, checkpoint_name)
    img = transform(original_image).unsqueeze(0).to(device)
    model = load_model(checkpoint_path_all[checkpoint_name]).to(device)

    # Accuracy of the selected model on the test dataset.
    test_dataset = TestDataset(csv_file=csv_file, base_path=base_path, checkpoint_name=checkpoint_name)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    accuracy = '%.5f' % evaluate_accuracy(model, test_loader, device)

    with torch.no_grad():
        predictions = model(img)

    # The model may return one tensor per attribute; the first one is used.
    if isinstance(predictions, (list, tuple)):
        logits = predictions[0]
    else:
        logits = predictions

    probabilities = F.softmax(logits.squeeze(), dim=0).cpu().numpy()

    attribute_values = ["irregular", "segmented-bilobed", "segmented-multilobed", "unsegmented-band", "unsegmented-indented", "unsegmented-round"]
    results = {
        name: float(probability)
        for name, probability in zip(attribute_values, probabilities)
    }

    # Grad-CAM on the last block of the encoder's layer4.
    target_layer = model.image_encoder.layer4[-1]
    gradcam = GradCAM(GradCAMWrapper(model, output_index=0), target_layer)

    mask, _ = gradcam(img)
    heatmap, _ = visualize_cam(mask, img)
    # NOTE(review): only channel 0 of a 3-channel heatmap is kept and then
    # re-coloured below; presumably `mask` could be coloured directly — confirm.
    if heatmap.ndim == 3 and heatmap.shape[0] == 3:
        heatmap = heatmap[0]
    heatmap = heatmap.detach().squeeze().cpu()

    # Min-max normalise; a constant heatmap would otherwise divide by zero.
    low = heatmap.min()
    span = heatmap.max() - low
    if span > 0:
        heatmap_norm = (heatmap - low) / span
    else:
        heatmap_norm = torch.zeros_like(heatmap)

    # Apply the 'jet' colormap and discard the alpha channel.
    colored_heatmap = plt.cm.jet(heatmap_norm.numpy())
    heatmap_img = Image.fromarray((colored_heatmap[:, :, :3] * 255).astype(np.uint8))

    # Recover a displayable image from the normalised model input.
    img_denorm = denormalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    img_denorm = torch.clamp(img_denorm, 0, 1)
    img_pil = to_pil_image(img_denorm.squeeze(0).cpu()).convert("RGB")

    # Combine the heatmap with the model input.
    heatmap_on_image = Image.blend(img_pil, heatmap_img, alpha=0.4)

    return original_image, heatmap_on_image, results, accuracy

# Compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Location of the test split: CSV index plus the directory its paths are relative to.
csv_file = 'pbc_attr_v1_val.csv'
base_path = './data/PBC/'

# Model name shown in the UI -> checkpoint file on disk.
checkpoint_path_all = {
    "Default model": "./log/best_model_nucleus.pth",
    "Nucleus crop model": "./log/best_model_nucleus_crop.pth",
    "Segmented nucleus model": "./log/best_model_segmented.pth",
    "Squeeze-And-Excitation model": "./log/best_model_SEB.pth",
}

# Input pipeline applied to every image before it is fed to a model:
# resize, centre-crop to 224x224, convert to tensor, normalise per channel.
_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)

# Example inputs offered in the UI; each entry is a one-element row.
sample_images = [
    [image_path]
    for image_path in (
        "./data/PBC/PBC_dataset_normal_DIB/neutrophil/BNE_810657.jpg",
        "./data/PBC/PBC_dataset_normal_DIB/neutrophil/BNE_681139.jpg",
        "./data/PBC/PBC_dataset_normal_DIB/neutrophil/BNE_330256.jpg",
        "./data/PBC/PBC_dataset_normal_DIB/monocyte/MO_574699.jpg",
        "./data/PBC/PBC_dataset_normal_DIB/neutrophil/SNE_153895.jpg",
        "./data/PBC/PBC_dataset_normal_DIB/neutrophil/BNE_961533.jpg",
        "./data/PBC/PBC_dataset_normal_DIB/monocyte/MO_139718.jpg",
        "./data/PBC/PBC_dataset_normal_DIB/lymphocyte/LY_654739.jpg",
        "./data/PBC/PBC_dataset_normal_DIB/lymphocyte/LY_623093.jpg",
        "./data/PBC/PBC_dataset_normal_DIB/neutrophil/SNE_799574.jpg",
    )
]

# HTML heading rendered at the top of the page.
_HEADER_ = '''
<h2>Final Year Project: Cell Image Staining and Analysis using Robust AI</h2>
'''

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown(_HEADER_)
    # Usage instructions.  These were previously assigned to an unused local
    # (as a one-element tuple) and therefore never shown in the page.
    gr.Markdown(
        "Choose a model from the dropdown box. Select a test image to visualize "
        "the model's Gradient Class Activation Map. Then click on submit to start "
        "the test phase and Grad-CAM visualization."
    )
    with gr.Row(variant="panel"):
        with gr.Column():
            image_input = gr.Image(type="pil", label="Sample image for Grad-CAM visualization")
            # Pass the model names explicitly as a list and preselect the first
            # one so the handler never receives an empty selection by default.
            model_names = list(checkpoint_path_all)
            model_dropdown = gr.Dropdown(choices=model_names, value=model_names[0], label="Select Model")
            submit_button = gr.Button()
        with gr.Column():
            with gr.Row():
                output_image = gr.Image(type="pil", label="Input Image", width=224, height=224)
                output_image_heatmap = gr.Image(type="pil", label="Result on Image", width=224, height=224)
            with gr.Row():
                output_label = gr.Label(num_top_classes=6, label="Nucleus Shape Prediction")
            with gr.Row():
                output_accuracy = gr.Label(label="Accuracy (%) of the model on the test dataset:")

    # Run prediction, Grad-CAM and test-set evaluation when submit is clicked.
    submit_button.click(
        predict_and_visualize,
        inputs=[image_input, model_dropdown],
        outputs=[output_image, output_image_heatmap, output_label, output_accuracy],
    )

    with gr.Row(variant="panel"):
        gr.Examples(
            examples=sample_images,
            inputs=[image_input],
            label="Examples",
        )

demo.launch()



Running on local URL:  http://127.0.0.1:7878

To create a public link, set `share=True` in `launch()`.




