In [1]:
import gradio as gr
import torch
from torchvision import models, transforms
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np

# Custom imports
from attribute_predictor import AttributePredictor
from gradcam import GradCAM, GradCAMpp
from gradcam.utils import visualize_cam

def get_image_encoder(pretrained=True):
    model = models.resnet50(pretrained=pretrained)
    in_features = model.fc.in_features
    model.fc = torch.nn.Identity()
    return model, in_features

image_encoder, image_encoder_output_dim = get_image_encoder(pretrained=True)
attribute_sizes = [6]
model = AttributePredictor(attribute_sizes, image_encoder_output_dim, image_encoder)
checkpoint = torch.load('./log/best_model_nucleus.pth')
model.load_state_dict(checkpoint['model'])
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

attribute_names = ["nucleus_shape"]
attribute_values = [
    ["irregular", "segmented-bilobed", "segmented-multilobed", "unsegmented-band", "unsegmented-indented", "unsegmented-round"]
]

class GradCAMWrapper(torch.nn.Module):
    def __init__(self, model, output_index=0):
        super().__init__()
        self.model = model
        self.output_index = output_index

    def forward(self, x):
        return self.model(x)[self.output_index]

def predict(image):
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    img_tensor = transform(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        predictions = model(img_tensor)
    
    probabilities = [F.softmax(logits, dim=1) for logits in predictions]
    results = {}
    
    for i, (probs, name) in enumerate(zip(probabilities, attribute_names)):
        class_probs = probs.squeeze().tolist()
        predicted_index = torch.argmax(probs)
        results[name] = f"{attribute_values[i][predicted_index.item()]} ({class_probs[predicted_index.item()]*100:.2f}%)"
        
        fig, ax = plt.subplots()
        ax.bar(attribute_values[i], class_probs)
        ax.set_title(f"Probabilities for {name}")
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(f"temp_plot_{i}.png")
        plt.close()
        
        results[f"chart_{i}"] = f"temp_plot_{i}.png"
    
    # Grad-CAM integration
    target_layer = model.image_encoder.layer4[-1]
    gradcam_model_wrapper = GradCAMWrapper(model, output_index=0)
    gradcam = GradCAM(gradcam_model_wrapper, target_layer)
    mask, _ = gradcam(img_tensor)
    heatmap, result = visualize_cam(mask, img_tensor)

    heatmap = np.clip(heatmap.squeeze().numpy(), 0, 1)
    heatmap_image = Image.fromarray((heatmap * 255).astype('uint8'))

    return results, image, heatmap_image

iface = gr.Interface(fn=predict, inputs="image", outputs=["json", "image", "image"],
                     description="Predict Nucleus Shape and visualize using Grad-CAM.")
iface.launch()


  from .autonotebook import tqdm as notebook_tqdm


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




Traceback (most recent call last):
  File "c:\Users\tanqi\anaconda3\envs\torch\Lib\site-packages\PIL\Image.py", line 3070, in fromarray
    mode, rawmode = _fromarray_typemap[typekey]
                    ~~~~~~~~~~~~~~~~~~^^^^^^^^^
KeyError: ((1, 1, 224), '|u1')

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "c:\Users\tanqi\anaconda3\envs\torch\Lib\site-packages\gradio\queueing.py", line 527, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\tanqi\anaconda3\envs\torch\Lib\site-packages\gradio\route_utils.py", line 261, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\tanqi\anaconda3\envs\torch\Lib\site-packages\gradio\blocks.py", line 1786, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\ta