In [1]:
import json
import torch
import numpy as np
from torchvision import models, transforms
from PIL import Image as PilImage

from omnixai.preprocessing.image import Resize
from omnixai.data.image import Image
from omnixai.explainers.vision import VisionExplainer
from omnixai.visualization.dashboard import Dashboard

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torchray.attribution.grad_cam import grad_cam
from torchray.benchmark import get_example_data, plot_example

# Obtain example data.
model, x, category_id, _ = get_example_data()

# Grad-CAM backprop.
#saliency = grad_cam(model, x, category_id, saliency_layer='features.29')

# Plots.
#plot_example(x, saliency, 'grad-cam backprop', category_id)

UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x00000182B5C3F6F0>

In [13]:
# Load images for testing
"../dataset/test/images/sub-r040s086_ses/0110.png",
img_1 = Resize((256, 256)).transform(Image(PilImage.open('../dataset/test/images/sub-r040s086_ses/0110.png').convert('L')))
img_2 = Resize((256, 256)).transform(Image(PilImage.open('../dataset/test/images/sub-r040s086_ses/0111.png').convert('L')))
img_3 = Resize((256, 256)).transform(Image(PilImage.open('../dataset/test/images/sub-r040s086_ses/0112.png').convert('L')))
img = Image(
    data=np.concatenate([
        img_1.to_numpy(), img_2.to_numpy(), img_3.to_numpy()]),
    batched=True
)
print(img.shape)

(3, 256, 256, 1)


In [15]:
from monai.networks.nets.unet import UNet
device = "cuda" if torch.cuda.is_available() else "cpu"
#device = "cpu"
# The preprocessing function
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transforms.Normalize(mean=[0], std=[1])
])
preprocess = lambda ims: torch.stack([transform(im.to_pil()) for im in ims]).to(device)
# A ResNet model to explain
#model = models.resnet50(pretrained=True).to(device)
model = UNet(
            spatial_dims=2,
            in_channels=1,
            out_channels=1,
            channels=(64, 128, 256),
            strides=(2, 2),
            num_res_units=2
        )
# The postprocessing function
postprocess = lambda logits: torch.nn.functional.softmax(logits, dim=1)

In [25]:
import torch.nn as nn
class one_step_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(one_step_conv, self).__init__()
        self.conv = nn.Sequential(
            # Level 1
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1),            
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            # Level 2            
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1),            
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2)   
        )

    def forward(self, input):
        return self.conv(input)


class UnetDeep3(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(UnetDeep3, self).__init__()
        base_filter_num = 64
        self.conv_down_1 = one_step_conv(in_ch, base_filter_num)
        self.pool1 = nn.MaxPool2d(2)
        self.conv_down_2 = one_step_conv(base_filter_num, base_filter_num*2)
        self.pool2 = nn.MaxPool2d(2)
        self.conv_down_3 = one_step_conv(base_filter_num*2, base_filter_num*4)
        self.pool3 = nn.MaxPool2d(2)
        self.conv_bottom = one_step_conv(base_filter_num*4, base_filter_num*8)
        
        
        self.upsample_1 = nn.ConvTranspose2d(base_filter_num*8, base_filter_num*4, kernel_size=2, stride=2)
        self.conv_up_1 = one_step_conv(base_filter_num*8, base_filter_num*4)
        self.upsample_2 = nn.ConvTranspose2d(base_filter_num*4, base_filter_num*2, kernel_size=2, stride=2)
        self.conv_up_2 = one_step_conv(base_filter_num*4, base_filter_num*2)
        self.upsample_3 = nn.ConvTranspose2d(base_filter_num*2, base_filter_num, kernel_size=2, stride=2)
        self.conv_up_3 = one_step_conv(base_filter_num*2, base_filter_num)
        self.conv_out = nn.Conv2d(base_filter_num, out_ch, 1)

    def forward(self, x):
        ###down########        
        down_1 = self.conv_down_1(x)
        pool_1 = self.pool1(down_1)
        down_2 = self.conv_down_2(pool_1)
        pool_2 = self.pool2(down_2)
        down_3 = self.conv_down_3(pool_2)
        pool_3 = self.pool3(down_3)
        bottom = self.conv_bottom(pool_3)

        up_1 = self.upsample_1(bottom)
        merge1 = torch.cat([up_1, down_3], dim=1)
        #print(merge1.shape)
        up_1_out = self.conv_up_1(merge1)
        up_2 = self.upsample_2(up_1_out)
        merge2 = torch.cat([up_2, down_2], dim=1)
        #print(merge2.shape)
        up_2_out = self.conv_up_2(merge2)
        up_3 = self.upsample_3(up_2_out)
        merge3 = torch.cat([up_3, down_1], dim=1)
        #print(merge3.shape)
        up_3_out = self.conv_up_3(merge3)
        end_out = self.conv_out(up_3_out)
        # Applying classification in the output layer        
        out = nn.Sigmoid()(end_out)
        return out


In [26]:
model = UnetDeep3(in_ch=1,out_ch=1)
model.conv_out

Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))

In [16]:
model

UNet(
  (model): Sequential(
    (0): ResidualUnit(
      (conv): Sequential(
        (unit0): Convolution(
          (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (adn): ADN(
            (N): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (unit1): Convolution(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (adn): ADN(
            (N): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
      )
      (residual): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): ResidualUnit(
          (conv): Sequential(
            (unit

In [17]:
# Initialize a VisionExplainer
with torch.no_grad():
    explainer = VisionExplainer(
        explainers=["lime"],
        mode="classification",
        model=model,
        preprocess=preprocess,
        postprocess=postprocess
        # params={
        #     #"gradcam": {"target_layer": model.conv_out},
        #     #"gradcam": {"target_layer": model.layer4[-1]}
        #     #"ce": {"binary_search_steps": 2, "num_iterations": 5}
        # }
    )
    # Generate explanations
    local_explanations = explainer.explain(img)

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

In [28]:
local_explanations["lime"]

<omnixai.explanations.image.mask.MaskExplanation at 0x1a8a06876d0>

In [11]:
index=0
idx2label = "Stroke Detection"
print("Grad-CAM results:")
local_explanations["gradcam"].ipython_plot(index, class_names=idx2label)
#print("LIME results:")
#local_explanations["lime"].ipython_plot(index, class_names=idx2label)
#print("Integrated-gradient results:")
#local_explanations["ig"].ipython_plot(index, class_names=idx2label)
#print("Counterfactual results:")
#local_explanations["ce"].ipython_plot(index, class_names=idx2label)

Grad-CAM results:
