In [1]:
import sys

# Put the parent directory first on the import path -- presumably so the
# project-local `cai` package imported in the next cell resolves; confirm.
sys.path[:0] = ["../"]

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt

from tqdm import tqdm
from omegaconf import OmegaConf
from cai.plmodule.module import HPASSCModule
from cai.pldatamodule.datamodule import HPASSCDataModule

In [3]:
# Checkpoint handed to HPASSCModule.load_from_checkpoint further down.
ckpt = "../models/epoch=49-step=3400.ckpt"

# Load the Hydra config saved with the run, then convert it to plain Python
# containers so the cells below can mutate it with ordinary dict operations.
# NOTE(review): this is an absolute, machine-specific path -- the notebook
# will not run elsewhere until it is made configurable.
cfg = OmegaConf.load("/mnt/yando/Users/yando/Experiments/cai/outputs/2022-12-12/23-25-29/.hydra/config.yaml")
cfg = OmegaConf.to_container(cfg)

In [4]:
def _transform_spec(module, name, **kwargs):
    """Build one transform entry in the {"module", "name", "kwargs"} layout."""
    return {"module": module, "name": name, "kwargs": kwargs}


def _crop_and_normalize():
    """Return fresh CenterCrop + Normalize entries (the prefix of both pipelines).

    New dict/list objects are created on every call so the two pipelines never
    share mutable state.
    """
    return [
        # 1568 = 7 * 224, i.e. a 7 x 7 grid of 224-pixel tiles.
        _transform_spec("tv_transforms", "CenterCrop", size=[1568, 1568]),
        # Per-channel statistics for the 4-channel input.
        _transform_spec(
            "tv_transforms",
            "Normalize",
            mean=[0.06438801, 0.0441467, 0.03966651, 0.06374957],
            std=[0.10712028, 0.08619478, 0.11134183, 0.10635688],
        ),
    ]


# Two alternative inference pipelines:
#   [0] crop -> normalize -> resize the whole crop down to 224 x 224
#   [1] crop -> normalize -> unfold into non-overlapping 224 x 224 tiles
inference_transforms = [
    _crop_and_normalize()
    + [_transform_spec("tv_transforms", "Resize", size=[224, 224])],
    _crop_and_normalize()
    + [_transform_spec("nn", "Unfold", kernel_size=[224, 224], stride=224)],
]

In [5]:
# Adapt the training-time config for inference in this notebook.
cfg["module"]["loss_fn"]["module"] = "nn"

# Use the tiling pipeline (CenterCrop -> Normalize -> Unfold).
cfg["datamodule"]["dataset"]["transforms"] = inference_transforms[1]

dataloader_cfg = cfg["datamodule"]["dataloader"]
dataloader_cfg["batch_size"] = 4
dataloader_cfg["num_workers"] = 16
dataloader_cfg["prefetch_factor"] = 4

# Drop the sampler entry -- presumably the training sampler is not wanted for
# inference; confirm. Guarded so that re-running this cell on an
# already-modified cfg does not raise KeyError.
if "sampler" in dataloader_cfg:
    del dataloader_cfg["sampler"]

In [6]:
# Prefer the first GPU; fall back to CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing at `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Restore the trained module from the checkpoint, passing the (overridden)
# module config as constructor keyword arguments, and move it to `device`.
module = HPASSCModule.load_from_checkpoint(ckpt, **cfg["module"]).to(device)

# Build the data module from the (overridden) datamodule config.
datamodule = HPASSCDataModule(**cfg["datamodule"])

In [7]:
# Put the module in evaluation mode before running inference.
module.eval()
# Prepare the data module; presumably this creates `train_dataset` and the
# dataloaders used by the cells below -- confirm against HPASSCDataModule.
datamodule.setup()

In [23]:
# Fetch a single (image, label) pair -- sample index 10 of the training
# dataset -- to inspect in the cells below.
image, label = datamodule.train_dataset[10]

In [24]:
# Bring the sample into a batched (N, C, H, W) layout.
# Dispatching on rank keeps this cell idempotent: re-running it on an already
# batched 4-D tensor used to prepend yet another axis.
if image.dim() == 2:
    # Assumes the 2-D output of the Unfold pipeline, shaped
    # (C * 224 * 224, 49): one column per tile. Transpose to one row per tile
    # and split each row back into (C, 224, 224).
    image = image.permute(1, 0).view(49, -1, 224, 224)
elif image.dim() == 3:
    # Single (C, H, W) image: add a leading batch axis.
    image = image[None, :, :, :]
elif image.dim() != 4:
    raise ValueError(
        f"expected a 2-D, 3-D or 4-D image tensor, got shape {tuple(image.shape)}"
    )

In [25]:
# Forward pass on the tiles: channel reduction first, then the backbone with
# attention maps requested. Autograd stays enabled so that gradients w.r.t.
# the attentions can be taken afterwards.
image = image.to(device)
output = module.model(module.reduce_channel(image), output_attentions=True)

# Shapes at a glance: logits, number of attention tensors, one attention tensor.
(output.logits.shape, len(output.attentions), output.attentions[0].shape)

(torch.Size([49, 19]), 12, torch.Size([49, 12, 197, 197]))

In [35]:
# Softmax over the class axis for each tile, then the maximum over tiles for
# every class, sorted from most to least confident.
# NOTE(review): the PR-curve cell in this notebook scores the same logits with
# sigmoid (multilabel) -- confirm that softmax is what is wanted here.
torch.sort(
    torch.max(F.softmax(output.logits, dim=1), dim=0).values,
    descending=True,
)

torch.return_types.sort(
values=tensor([0.2268, 0.1832, 0.1250, 0.0762, 0.0750, 0.0626, 0.0621, 0.0618, 0.0609,
        0.0597, 0.0580, 0.0559, 0.0523, 0.0509, 0.0492, 0.0463, 0.0430, 0.0417,
        0.0405], device='cuda:0', grad_fn=<SortBackward0>),
indices=tensor([ 0, 16,  8, 13, 10, 17, 15,  5, 14,  4,  6,  3,  7,  1,  2, 12,  9, 18,
        11], device='cuda:0'))

In [None]:
# For every attention head of the last layer, show the input tiles next to
# the attention that query token 0 pays to the remaining tokens.
n_tiles, nh, _, _ = output.attentions[-1].shape

# Row 0 = query token 0; columns 1: = every key except token 0. Token 0 is
# presumably the [CLS] token (197 = 1 + 14 * 14) -- confirm.
attentions = output.attentions[-1][:, :, 0, 1:]
# Zero out weights at or below 0.003 to suppress background noise.
attentions = (attentions * (attentions > 0.003)).view(n_tiles, nh, -1)

# The input mosaic does not depend on the head, so build it once.
# Channels are averaged into a single grayscale plane for display.
image_grid = vutils.make_grid(image, nrow=7).permute(1, 2, 0).mean(dim=-1).cpu()

for head in range(nh):
    fig, ax = plt.subplots(1, 2)

    # One 14 x 14 attention map per tile, laid out 7 tiles per row.
    # (`make_grid` has no `ncol` argument; the column count follows from `nrow`.)
    grid = vutils.make_grid(
        attentions[:, head, :].view(-1, 1, 14, 14),
        nrow=7,
        padding=0,
        normalize=True,
    )
    # Upsample the 98 x 98 mosaic to the 1568 x 1568 crop size.
    grid = transforms.functional.resize(grid, (1568, 1568)).cpu().permute(1, 2, 0)

    ax[0].imshow(image_grid)
    ax[0].set_title("input tiles")
    ax[0].set_axis_off()

    ax[1].imshow(grid[:, :, 0], cmap=plt.cm.rainbow)
    ax[1].set_title(f"last-layer attention, head {head}")
    ax[1].set_axis_off()

    plt.show()
    # Release the figure; one is created per head.
    plt.close(fig)

In [None]:
# These imports are not used in this cell any more; they are kept because
# other cells may rely on the names.
from itertools import product
from tqdm import tqdm

num_classes = 19
num_layers = 12
num_tiles = 49

# Make sure the constants match the forward pass this cell builds on.
assert output.logits.shape == (num_tiles, num_classes)
assert len(output.attentions) == num_layers

# grad[layer] = d(sum of all tile/class logits) / d(output.attentions[layer]),
# giving a tensor of shape (num_layers, *output.attentions[-1].shape).
#
# By linearity of differentiation this equals accumulating
#   torch.autograd.grad(output.logits[tile, cla], output.attentions[layer])
# over every (class, tile) pair into grad[layer], but it needs a single
# backward pass instead of num_classes * num_layers * num_tiles of them.
# retain_graph=True keeps the graph alive so the cell can be re-run.
grad = torch.stack(
    torch.autograd.grad(output.logits.sum(), output.attentions, retain_graph=True)
).to(device)


In [None]:
# Gradient-weighted attention of the last layer, maximised over heads.
threshold = 0.0001
n_tiles, nh, _, _ = output.attentions[-1].shape

last_attention = output.attentions[-1]
# `grad` is allocated with a leading per-layer axis; take the last layer's
# slice so it lines up element-wise with `last_attention`. (Multiplying by the
# full 5-D tensor would broadcast and scramble the axes in the view below.)
# A `grad` that already has the attention's rank is used as is.
last_grad = grad[-1] if grad.dim() == last_attention.dim() + 1 else grad

# Row 0 = query token 0; columns 1: = every key except token 0.
attentions = (last_attention * last_grad)[:, :, 0, 1:]
# Keep only contributions above the threshold.
attentions = (attentions * (attentions > threshold)).view(n_tiles, nh, -1)

fig, ax = plt.subplots(1, 2)
# Max over heads -> one 14 x 14 map per tile, laid out 7 tiles per row.
# (`make_grid` has no `ncol` argument; the column count follows from `nrow`.)
grid = vutils.make_grid(
    attentions.max(dim=1).values.view(-1, 1, 14, 14),
    nrow=7,
    padding=0,
).cpu().permute(1, 2, 0)
# Input mosaic, channels averaged into one grayscale plane.
image_grid = vutils.make_grid(image.cpu(), nrow=7).permute(1, 2, 0).mean(dim=-1)
ax[1].imshow(grid[:, :, 0], cmap=plt.cm.rainbow, interpolation="bilinear")
ax[0].imshow(image_grid)
ax[0].set_axis_off()
ax[1].set_axis_off()
plt.show()

In [75]:
from torchmetrics.classification import MultilabelPrecisionRecallCurve

# Binned (100 thresholds) precision-recall curve over the 19 labels.
mlprc = MultilabelPrecisionRecallCurve(19, thresholds=100).to(device)
dataloader = datamodule.train_dataloader()

# Pure inference: no gradients are needed, so do not build the autograd graph
# for the 49 tiles of every image in every batch.
with torch.no_grad():
    for batch in tqdm(dataloader):
        image, label = batch
        n_batch = image.shape[0]
        # Assumes Unfold output of shape (n_batch, C * 224 * 224, 49):
        # transpose to one row per tile, then split into (49, C, 224, 224).
        image = image.permute(0, 2, 1).view(n_batch, 49, -1, 224, 224).to(device)
        # Fold the tile axis into the batch axis for the forward pass.
        output = module.reduce_channel(image.reshape(-1, 4, 224, 224))
        output = module.model(output)
        # Restore the (image, tile, class) layout.
        output.logits = output.logits.reshape(n_batch, 49, -1)
        # Image-level score = mean over tiles of the per-tile sigmoid probability.
        mlprc.update(torch.sigmoid(output.logits).mean(dim=1), label.to(device))

  0%|          | 0/4361 [00:00<?, ?it/s]

: 

: 

In [19]:
# Finalise the accumulated metric into its precision, recall and threshold
# tensors.
precision_, recall_, thresholds_ = mlprc.compute()

In [70]:
# F1 = 2PR / (P + R); float32 machine epsilon in the denominator avoids 0/0
# where precision and recall are both zero.
f1_ = torch.div(
    2 * precision_ * recall_,
    precision_ + recall_ + torch.finfo(torch.float32).eps,
)
# Maximum F1 along dim 1 of the curve, with the index at which it occurs
# (`.values`, `.indices`).
best_f1 = f1_.max(dim=1)