In [None]:
import os
from pathlib import Path

import albumentations as A
import cv2
import torch
from imageio.plugins.pillow import ndarray_to_pil
from omegaconf import OmegaConf
from pytorch_grad_cam.utils.image import show_cam_on_image

import transforms as tfm
from data.nih_dataset import NIHDataset
from inference.models.efficient_net_v2_module import EfficientNetV2Module

# --- Configuration ------------------------------------------------------------
# Trained checkpoint to inspect (relative to the notebook's working directory).
model_path = Path('../lightning_logs/epoch=9_val_auroc=0.861_top.pt')

# Dataset root. Set the NIH_DATASET_PATH environment variable to point elsewhere
# instead of editing this machine-specific default.
dataset_path = Path(os.environ.get('NIH_DATASET_PATH', '/home/szymswiat/datasets/nih_dataset'))
# Model input side length: images are resized to img_size x img_size below.
img_size = 384
# Class whose activation map is visualised; must be one of the checkpoint's classes.
# selected_class = 'Cardiomegaly'
selected_class = 'Effusion'

# --- Model and data -----------------------------------------------------------
model = EfficientNetV2Module.load_from_file(model_path)

# Class names stored in the checkpoint hyper-parameters; `classes.index(...)` is
# used later to pick the matching classifier row.
classes = OmegaConf.to_object(model.hparams.dynamic.classes)

metadata = NIHDataset.parse_dataset_meta(
    dataset_path=dataset_path,
    split_type=NIHDataset.SPLIT_OFFICIAL_VAL_FROM_TEST,
    classes=classes
)
# Model-side preprocessing. It is applied manually in the CAM cell: no transforms
# argument is passed to the dataset below, so the raw image stays available for
# the overlay.
transforms = A.Compose([
    A.Resize(img_size, img_size),
    tfm.NormalizeAlb(NIHDataset.MIN_MAX_VALUE,
                     mean=[NIHDataset.MEAN] * 3,
                     std=[NIHDataset.STD] * 3)
])

# Test split, presumably restricted to images labelled positive for
# `selected_class` (per the filter argument name) — confirm in NIHDataset.
test_set = NIHDataset(
    dataset_path=dataset_path,
    input_df=metadata['test_df'],
    filter_by_positive_class=[selected_class],
    # mode=NIHDataset.BBOX_ONLY_MODE
)

model.eval()

# Index of the test sample shown by the CAM cell. The increment cell runs before
# it, so a top-to-bottom run displays sample 1 first.
i = 0

# Filled by the forward hook registered below: name -> captured activation.
feature_maps = {}


def get_features(name):
    """Create a forward hook that records a module's output.

    The returned callable has the ``(module, inputs, outputs)`` signature that
    ``register_forward_hook`` expects. On each call it stores ``outputs.detach()``
    (same storage, no autograd history) in the module-level ``feature_maps``
    dict under ``name``, replacing any earlier entry. It returns None, so the
    module's output is left unchanged.
    """
    def _record_output(module, module_inputs, module_outputs):
        # Looked up at call time, so whatever `feature_maps` is bound to then wins.
        feature_maps[name] = module_outputs.detach()

    return _record_output


# Cache the output of `model.model.act2` under 'pre_pool_feat' on every forward
# pass (the key name suggests the features right before global pooling — confirm
# against the backbone). Keep the handle so the hook can be detached later with
# `feature_hook.remove()`.
feature_hook = model.model.act2.register_forward_hook(get_features('pre_pool_feat'))

# Fail early with the list of valid names instead of a bare "not in list" error.
if selected_class not in classes:
    raise ValueError(
        f'selected_class {selected_class!r} is not one of the model classes: {classes}'
    )

# Parameters of the final classifier in registration order; element 0 is used as
# the weight matrix, indexed by class position.
linear_weights = list(model.model.classifier.parameters())
# Classifier weights for `selected_class`, one per feature channel. Despite the
# name, this holds the row of whichever class is selected, not only Cardiomegaly.
cardio_weights = linear_weights[0][classes.index(selected_class)]

In [None]:
# Advance to the next test sample; run the cell below afterwards to display it.
i += 1

In [None]:
# Overlay a class activation map (CAM) for `selected_class` on test sample `i`.
assert 0 <= i < len(test_set), f'sample index {i} out of range [0, {len(test_set)})'
img_raw, y_true = test_set[i]
# Preprocess a copy for the model only; `img_raw` is kept untouched for the overlay.
img = transforms(image=img_raw)['image']
# HWC array -> 1 x C x H x W float tensor.
y_pred = model(torch.unsqueeze(torch.tensor(img).permute(2, 0, 1), dim=0).float())
y_pred_sum = y_pred.sum()

# Activation captured by the forward hook, batch dimension dropped -> (C, H, W).
last_conv_feat = torch.squeeze(feature_maps['pre_pool_feat'], dim=0)

# CAM: scale every channel by the selected class's classifier weight, then average
# over channels (the constant 1/C factor cancels in the min-max normalisation).
# detach(): the weights are a slice of a trainable parameter.
cam = (last_conv_feat * cardio_weights[:, None, None]).mean(dim=0).detach().cpu().numpy()

# Min-max normalise to [0, 1]. Subtracting the minimum is correct for any sign
# (adding abs(min) would shift the map up when all values are positive), and the
# guard avoids a division by zero for a constant map.
cam -= cam.min()
cam_max = cam.max()
if cam_max > 0:
    cam /= cam_max
# cam = 1 - cam
# cam[cam < 0.8] = 0

# cv2.resize takes dsize as (width, height), whereas numpy shapes are (height, width).
raw_height, raw_width = img_raw.shape[:2]
resized_cam = cv2.resize(cam,
                         (raw_width, raw_height),
                         interpolation=cv2.INTER_NEAREST)
# cv2.applyColorMap yields a BGR heatmap; use_rgb=True converts it to RGB so the
# colours are correct when rendered via PIL. Assumes img_raw is RGB (or grayscale
# replicated to 3 channels) with values in [0, 255] — TODO confirm.
visualization = show_cam_on_image(img_raw / 255,
                                  resized_cam,
                                  use_rgb=True,
                                  colormap=cv2.COLORMAP_JET)

ndarray_to_pil(visualization)

In [None]:
from timm.models.resnet import resnet18

# Scratch experiment: ResNet-18 with pretrained weights and a 10-class output.
# NOTE(review): this rebinds `model`. The CAM cells above expect the EfficientNetV2
# module (hook on `model.model.act2`), so re-run the setup cell before using them again.
model = resnet18(num_classes=10, pretrained=True)

In [None]:
from timm.models.resnet import resnet34

# Scratch experiment: ResNet-34 with pretrained weights and a 10-class output.
# NOTE(review): this rebinds `model`. The CAM cells above expect the EfficientNetV2
# module (hook on `model.model.act2`), so re-run the setup cell before using them again.
model = resnet34(num_classes=10, pretrained=True)

In [None]:
# Display the repr (module structure) of whichever network `model` is currently bound to.
model