In [None]:
import os
from pathlib import Path

import torch
from omegaconf import OmegaConf

from modules.classification.efficient_net_v2_module import EfficientNetV2Module
from modules.classification.resnet_module import ResNetModule

# Example images worth inspecting:
# 00006160_001.png
# 00014753_000.png
# 00013613_007.png good example of a low-quality image
# 00028730_000.png no finding

# Checkpoint and data locations. The dataset root can be overridden with the
# NIH_DATASET_PATH environment variable instead of editing this cell.
model_path = Path('../lightning_logs/resnet34_epoch=17_val_auroc=0.843_top_updated.pt')
dataset_path = Path(os.environ.get('NIH_DATASET_PATH', '/home/szymswiat/datasets/nih_dataset'))
img_size = 384  # images are resized to img_size x img_size before inference

# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
# map_location='cpu' lets checkpoints saved on a GPU be loaded on CPU-only machines.
state = torch.load(model_path.as_posix(), map_location='cpu')
hparams = OmegaConf.create(state['hparams'])
thresholds = torch.tensor(state['thresholds'])
classes = OmegaConf.to_object(hparams.dynamic.classes)

# Pick the module class matching the architecture recorded in the checkpoint.
if hparams.architecture == 'eff_net_v2':
    model_class = EfficientNetV2Module
elif hparams.architecture == 'resnet':
    model_class = ResNetModule
else:
    raise ValueError(f'Unsupported architecture in checkpoint: {hparams.architecture!r}')

model = model_class.create_from_state(state)
model.eval();  # switch to evaluation mode; trailing ';' suppresses the notebook repr

In [None]:
from albumentations.pytorch import ToTensorV2
from data.nih_dataset import NIHDataset
import albumentations as A
import transforms as tfm

# Split metadata for the requested classes; the test split is taken from `test_df`.
metadata = NIHDataset.parse_dataset_meta(
    dataset_path=dataset_path,
    split_type=NIHDataset.SPLIT_OFFICIAL_VAL_FROM_TEST,
    classes=classes,
)

# Preprocessing applied manually in the inference loop:
# square resize -> normalization with the dataset statistics -> tensor conversion.
normalize_tf = tfm.NormalizeAlb(
    NIHDataset.MIN_MAX_VALUE,
    mean=[NIHDataset.MEAN] * 3,
    std=[NIHDataset.STD] * 3,
)
transforms = A.Compose([A.Resize(img_size, img_size), normalize_tf, ToTensorV2()])

# No transform is attached to the dataset, so the loop receives the raw image
# (kept for visualization) and applies `transforms` itself.
# `filter_by_positive_class` presumably keeps only Cardiomegaly-positive samples — verify.
test_set = NIHDataset(
    dataset_path=dataset_path,
    input_df=metadata['test_df'],
    filter_by_positive_class=['Cardiomegaly'],
    # mode=NIHDataset.BBOX_ONLY_MODE,  # alternative: presumably bbox-annotated images only
)

In [None]:
# Backbone layer whose activations/gradients Grad-CAM uses.
# For ResNet the first block of `layer4` is used; alternative targets that
# can be swapped in are `model.model.layer4[-3]` and `model.model.layer3[-1]`.
if hparams.architecture == 'eff_net_v2':
    raise NotImplementedError('Grad-CAM target layer is not defined for eff_net_v2 yet')
elif hparams.architecture == 'resnet':
    gradcam_target_layer = model.model.layer4[0]
else:
    raise ValueError(f'Unsupported architecture in checkpoint: {hparams.architecture!r}')

In [None]:
from inference_utils.cam_predictor import CamPredictorMultiLabel
from concurrent.futures import ThreadPoolExecutor

from imageio.plugins.pillow import ndarray_to_pil
import progressbar as pb
import cv2
import numpy as np
import json

# Root output directory (relative to the notebook's working directory);
# the loop creates one sub-directory per test image beneath it.
img_path = Path('images')
img_path.mkdir(exist_ok=True)

# Background pool used to write PNG files while inference continues.
executor = ThreadPoolExecutor(2)

# Grad-CAM predictor built from the model, its per-class thresholds and the
# target layer chosen above.
cam_pred = CamPredictorMultiLabel(model, thresholds, gradcam_target_layer)

In [None]:
# For every test image, write `images/<image stem>/` containing:
#   * the raw input image (saved once per image),
#   * `<class>_heatmap_rgb.png` for each (class_idx, cam) pair returned by `cam_pred`,
#   * `inference_result.json` with a per-class `detected` flag and heatmap file name.
# PNG writes run on a thread pool created here: the `with` block drains and shuts it
# down even if the loop is interrupted, and re-running this cell gets a fresh pool.
max_images = None  # set to e.g. 10 for a quick debug run; None processes all images

with ThreadPoolExecutor(max_workers=2) as executor:
    save_futures = []
    for i, (img_raw, _) in pb.progressbar(enumerate(test_set)):
        if max_images is not None and i >= max_images:
            break

        classification_report = {cls: {'detected': False} for cls in classes}
        inference_result = {'classification_report': classification_report}

        img = transforms(image=img_raw)['image'].float()

        img_raw_path = test_set.get_img_path(i)
        img_root = img_path / img_raw_path.stem
        img_root.mkdir(exist_ok=True, parents=True)

        # Save the untouched input once per image so each result directory is self-contained.
        save_futures.append(
            executor.submit(ndarray_to_pil(img_raw).save, img_root / img_raw_path.name))

        _, cam_list = cam_pred(img)
        for class_idx, class_cam in cam_list:
            cls_name = classes[class_idx]
            img_heatmap_rgb_path = img_root / f'{cls_name}_heatmap_rgb.png'

            report = classification_report[cls_name]
            report['detected'] = True
            report['heatmap_url'] = img_heatmap_rgb_path.name

            # cv2.resize expects dsize as (width, height) while ndarray.shape is
            # (height, width); swap them so non-square images are not transposed.
            height, width = img_raw.shape[:2]
            single_cam = cv2.resize(class_cam.detach().cpu().numpy(),
                                    (width, height),
                                    interpolation=cv2.INTER_NEAREST)

            # Clip before scaling: values outside [0, 1] would otherwise wrap
            # around when cast to uint8 instead of saturating.
            cam_u8 = np.uint8(255 * np.clip(single_cam, 0, 1))
            heatmap = cv2.applyColorMap(cam_u8, cv2.COLORMAP_JET)
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

            save_futures.append(
                executor.submit(ndarray_to_pil(heatmap).save, img_heatmap_rgb_path))

        with open(img_root / 'inference_result.json', 'w') as f:
            json.dump(inference_result, f, indent=4, sort_keys=True)

    # Re-raise the first failure from a background save instead of dropping it silently.
    for future in save_futures:
        future.result()

