In [None]:
import os
from pathlib import Path

import torch
from omegaconf import OmegaConf

from modules.nih_efficient_net_v2_module import NIHEfficientNetV2Module
from modules.nih_resnet_module import NIHResNetModule

model_path = Path('../lightning_logs/resnet34_epoch=17_val_auroc=0.843_top_updated.pt')
# Override with the NIH_DATASET_PATH environment variable on other machines; the default is the
# original author's local path.
dataset_path = Path(os.environ.get('NIH_DATASET_PATH', '/home/szymswiat/datasets/nih_dataset'))
img_size = 384

# The checkpoint dict carries the training hparams and the decision thresholds used below.
# map_location='cpu' allows a checkpoint saved from GPU tensors to load on a CPU-only machine
# (Grad-CAM in this notebook is run with use_cuda=False).
state = torch.load(model_path.as_posix(), map_location='cpu')
hparams = OmegaConf.create(state['hparams'])
thresholds = torch.tensor(state['thresholds'])
classes = OmegaConf.to_object(hparams.dynamic.classes)

if hparams.architecture == 'eff_net_v2':
    model_class = NIHEfficientNetV2Module
elif hparams.architecture == 'resnet':
    model_class = NIHResNetModule
else:
    raise ValueError(f'Unsupported architecture in checkpoint hparams: {hparams.architecture!r}')

# This notebook only runs inference: switch to eval mode (e.g. BatchNorm uses running stats).
# Chained so the cell does not echo the whole module repr as its output.
model = model_class.load_from_file(model_path).eval()

In [None]:
from albumentations.pytorch import ToTensorV2
from data.nih_dataset import NIHDataset
import albumentations as A
import transforms as tfm

# Split metadata: the official split with the validation set carved out of the test set.
metadata = NIHDataset.parse_dataset_meta(
    dataset_path=dataset_path,
    split_type=NIHDataset.SPLIT_OFFICIAL_VAL_FROM_TEST,
    classes=classes,
)

# Inference-time preprocessing: resize to the model input size, normalise with the dataset
# statistics (replicated over 3 channels), then convert to a torch tensor.
# Note: this pipeline is NOT handed to the dataset; it is applied manually to each raw sample.
transforms = A.Compose(
    [
        A.Resize(img_size, img_size),
        tfm.NormalizeAlb(
            NIHDataset.MIN_MAX_VALUE,
            mean=[NIHDataset.MEAN] * 3,
            std=[NIHDataset.STD] * 3,
        ),
        ToTensorV2(),
    ]
)

# Test subset restricted via `filter_by_positive_class` (semantics defined in NIHDataset).
# Optionally also pass `mode=NIHDataset.BBOX_ONLY_MODE` (presumably limits to bbox-annotated
# images — confirm in NIHDataset).
test_set = NIHDataset(
    dataset_path=dataset_path,
    input_df=metadata['test_df'],
    filter_by_positive_class=['Cardiomegaly'],
)

In [None]:
# Choose the layer whose activations/gradients Grad-CAM++ will visualise.
if hparams.architecture == 'eff_net_v2':
    raise NotImplementedError('Grad-CAM target layer is not defined for eff_net_v2 yet')
elif hparams.architecture == 'resnet':
    # First block of `layer4` (assumes model.model is a torchvision-style ResNet — confirm).
    # Other candidates that were tried: model.model.layer4[-3], model.model.layer3[-1].
    gradcam_target_layer = model.model.layer4[0]
else:
    raise ValueError(f'Unsupported architecture in checkpoint hparams: {hparams.architecture!r}')

In [None]:
from pytorch_grad_cam import GradCAMPlusPlus

# Grad-CAM++ explainer over the selected target layer, run on CPU (use_cuda=False).
# NOTE(review): `target_layer=` / `use_cuda=` is the older pytorch_grad_cam API; newer releases
# take `target_layers=[...]` instead. Keep this in sync with the `cam(..., target_category=...)`
# call in the inference loop if the library is upgraded.
cam = GradCAMPlusPlus(model=model, target_layer=gradcam_target_layer, use_cuda=False)

In [None]:
from imageio.plugins.pillow import ndarray_to_pil
from pytorch_grad_cam.utils.image import show_cam_on_image
import progressbar as pb
import cv2
import numpy as np
import json
from math import floor

img_path = Path('images')
img_path.mkdir(exist_ok=True)


def render_cam_heatmap(cam_2d, height, width):
    """Resize a single Grad-CAM map and colour it with the JET colormap.

    :param cam_2d: 2-D float CAM for one sample, as sliced from the `cam(...)` output
        (assumed to be scaled to [0, 1] — TODO confirm for the installed pytorch_grad_cam).
    :param height: target height in pixels (numpy axis 0 of the raw image).
    :param width: target width in pixels (numpy axis 1 of the raw image).
    :return: uint8 RGB heatmap with spatial size (height, width).
    """
    # cv2.resize expects dsize as (width, height) — the reverse of numpy's (rows, cols) order.
    resized = cv2.resize(cam_2d, (width, height), interpolation=cv2.INTER_NEAREST)
    heatmap_bgr = cv2.applyColorMap(np.uint8(255 * resized), cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR; PIL expects RGB.
    return cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)


for i, (img_raw, y_true) in pb.progressbar(enumerate(test_set)):
    # Every class starts undetected; positives are filled in below.
    classification_report = {cls: {'detected': False} for cls in classes}
    inference_result = {'classification_report': classification_report}

    img = transforms(image=img_raw)['image'].float()
    # Predictions only need values; detach right away so the autograd graph of this plain
    # forward pass is released (Grad-CAM runs its own forward/backward below).
    y_pred = model(torch.unsqueeze(img, dim=0)).squeeze(dim=0).detach()

    # Indices of classes whose score reaches the per-class threshold from the checkpoint.
    y_pred_positive = (y_pred >= thresholds).nonzero().squeeze(dim=1).numpy()
    pred_positive_count = len(y_pred_positive)

    img_raw_path = test_set.get_img_path(i)
    img_root = img_path / img_raw_path.stem
    img_root.mkdir(exist_ok=True, parents=True)

    if pred_positive_count > 0:
        # One copy of the image per positive class, so each batch row gets its own CAM target.
        images = torch.unsqueeze(img, dim=0).repeat((pred_positive_count, 1, 1, 1))
        grayscale_cams = cam(images, target_category=y_pred_positive)

        for j, class_idx in enumerate(y_pred_positive):
            class_name = classes[class_idx]
            heatmap_rgb_path = img_root / f'{class_name}_heatmap_rgb.png'

            report = classification_report[class_name]
            report['detected'] = True
            report['heatmap_url'] = heatmap_rgb_path.name

            heatmap = render_cam_heatmap(grayscale_cams[j, :], img_raw.shape[0], img_raw.shape[1])
            ndarray_to_pil(heatmap).save(heatmap_rgb_path)

        # The source image only needs writing once per sample, not once per positive class.
        ndarray_to_pil(img_raw).save(img_root / img_raw_path.name)

    with open(img_root / 'inference_result.json', 'w') as f:
        json.dump(inference_result, f, indent=4, sort_keys=True)