In [None]:

import os
from pathlib import Path

import albumentations as A
import numpy as np
import torch
from omegaconf import OmegaConf

import transforms as tfm
from data.nih_dataset import NIHDataset
from inference.models.efficient_net_v2_module import EfficientNetV2Module


# Checkpoint to explain.
model_path = Path('../lightning_logs/epoch=9_val_auroc=0.861_top.pt')

# Dataset root. Override with the NIH_DATASET_PATH environment variable so the
# notebook is not tied to one machine's home directory; the default keeps the
# original location.
dataset_path = Path(os.environ.get('NIH_DATASET_PATH', '/home/szymswiat/datasets/nih_dataset'))
# Square side length (pixels) images are resized to before entering the model.
img_size = 384

model = EfficientNetV2Module.load_from_file(model_path)

# Class names stored in the checkpoint's hyper-parameters. Later cells index model
# outputs with classes.index(...), i.e. this presumably matches the model's
# output-column order.
classes = OmegaConf.to_object(model.hparams.dynamic.classes)

metadata = NIHDataset.parse_dataset_meta(
    dataset_path=dataset_path,
    split_type=NIHDataset.SPLIT_OFFICIAL_VAL_FROM_TEST,
    classes=classes
)
# Preprocessing applied manually in later cells (the dataset itself is built
# without it): resize, then normalize with the dataset statistics replicated
# over three channels.
transforms = A.Compose([
    A.Resize(img_size, img_size),
    tfm.NormalizeAlb(NIHDataset.MIN_MAX_VALUE,
                     mean=[NIHDataset.MEAN] * 3,
                     std=[NIHDataset.STD] * 3)
])

# Test split, restricted to samples positive for Emphysema.
test_set = NIHDataset(
    dataset_path=dataset_path,
    input_df=metadata['test_df'],
    filter_by_positive_class=['Emphysema'],
    # Optional alternative mode (semantics defined in NIHDataset — verify before use):
    # mode=NIHDataset.BBOX_ONLY_MODE
)

# Inference only: put the model in evaluation mode.
model.eval()

# Index of the next test sample; the inference cell reads test_set[i] and then
# increments it, so re-running that cell steps through the test set.
i = 0

In [None]:
# Run the model on test sample `i`, then advance `i` so re-running this cell
# walks through the test set one sample at a time.
assert i < len(test_set), 'reached the end of test_set; reset i in the setup cell'
img_raw, y_true = test_set[i]
img = transforms(image=img_raw)['image']
# Channels-last image -> 1 x C x H x W float tensor (a batch of one).
img_tensor = torch.unsqueeze(torch.tensor(img).permute(2, 0, 1), dim=0).float()
# Pure inference: don't build an autograd graph.
with torch.no_grad():
    y_pred = model(img_tensor)
y_pred_sum = y_pred.sum()
i += 1

# Look the Emphysema output column up by name instead of a hard-coded index, so
# the check stays tied to the class it is named after regardless of class order.
EMPHYSEMA_THRESHOLD = 0.2575  # decision threshold carried over from the original notebook
emphysema_out = y_pred[0][classes.index('Emphysema')] >= EMPHYSEMA_THRESHOLD

In [None]:
from lime import lime_image
from skimage.segmentation import mark_boundaries
from imageio.plugins.pillow import ndarray_to_pil
import cv2


def predict(batch: np.ndarray, net=None) -> np.ndarray:
    """Classifier function for LIME: evaluate the model on a batch of images.

    Args:
        batch: stack of channels-last images, (N, H, W, C) — the layout LIME
            passes to its classifier function.
        net: callable to evaluate; defaults to the module-level ``model``.
            Exposed so the function can be reused or checked with another model.

    Returns:
        The model outputs as a NumPy array (presumably (N, num_classes)).
    """
    if net is None:
        net = model
    # NHWC -> NCHW float32 tensor.
    tensor = torch.tensor(batch, dtype=torch.float32).permute((0, 3, 1, 2))
    # LIME calls this for every perturbation batch; skip autograd bookkeeping.
    with torch.no_grad():
        outputs = net(tensor)
    return outputs.detach().cpu().numpy()


# Explain the preprocessed image `img` from the inference cell.
# random_state seeds LIME's segmentation and perturbation sampling, so re-runs
# give the same explanation.
explainer = lime_image.LimeImageExplainer(random_state=0)
explanation = explainer.explain_instance(
    img,
    predict,  # classification function
    # LIME labels are output-column indices, not class names. LIME also ignores
    # `labels` whenever `top_labels` is set. Explaining every class (top_labels =
    # number of classes) guarantees that any class picked later via
    # classes.index(...) is present, not only the 6 highest-scoring ones.
    labels=list(range(len(classes))),
    hide_color=0,
    batch_size=4,
    num_samples=1000,
    top_labels=len(classes)
)

In [None]:
# Outline, on the raw image, the superpixel LIME found most supportive of one class.
explained_class = 'Cardiomegaly'
_, mask = explanation.get_image_and_mask(classes.index(explained_class),
                                         positive_only=True, negative_only=False,
                                         num_features=1, hide_rest=False)
# The mask has the resized model-input resolution; map it back to the raw image.
# cv2.resize takes dsize as (width, height) — the reverse of numpy's shape order.
# Use INTER_NEAREST so mask labels are not blended. Cast to int32 (CV_32S) so
# OpenCV accepts the integer mask while keeping it signed.
raw_h, raw_w = img_raw.shape[:2]
mask_raw = cv2.resize(mask.astype(np.int32), (raw_w, raw_h), interpolation=cv2.INTER_NEAREST)
img_boundry2 = mark_boundaries(img_raw, mask_raw)
ndarray_to_pil(img_boundry2)