In [1]:
import sys
sys.path.insert(0, "../")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import pytorch_lightning as pl

from tqdm import tqdm
from omegaconf import OmegaConf
from cai.plmodule.module import HPASSCModule
from cai.pldatamodule.datamodule import HPASSCDataModule

pl.seed_everything(42, workers=True)

Global seed set to 42


42

In [3]:
#ckpt = "../models/epoch=49-step=3400.ckpt"
#cfg = OmegaConf.to_container(OmegaConf.load("/mnt/yando/Users/yando/Experiments/cai/outputs/2022-12-12/23-25-29/.hydra/config.yaml"))

#ckpt = "../models/epoch=31-step=2048.ckpt"
#cfg = OmegaConf.to_container(OmegaConf.load("/mnt/yando/Users/yando/Experiments/cai/multirun/2022-12-13/16-52-16/0/.hydra/config.yaml"))

ckpt = "../models/epoch=26-step=1836.ckpt"
cfg = OmegaConf.to_container(OmegaConf.load("/mnt/yando/Users/yando/Experiments/cai/multirun/2022-12-13/00-17-30/0/.hydra/config.yaml"))

In [4]:
inference_transforms = [
        [
            {
                "module" : "tv_transforms",
                 "name" : "CenterCrop",
                 "kwargs" : {
                     "size" : [1568, 1568]
                 }
            },
            {
                "module" : "tv_transforms",
                "name" : "Normalize",
                "kwargs" : {
                    "mean" : [0.06438801, 0.0441467, 0.03966651, 0.06374957],
                    "std" : [0.10712028, 0.08619478, 0.11134183, 0.10635688]
                }
            },
            {
                "module" : "tv_transforms",
                "name" : "Resize",
                "kwargs" : {
                    "size" : [224, 224]
                }
            }
        ],
        [
            {
                "module" : "tv_transforms",
                 "name" : "CenterCrop",
                 "kwargs" : {
                     "size" : [1568, 1568]
                 }
            },
            {
                "module" : "tv_transforms",
                "name" : "Normalize",
                "kwargs" : {
                    "mean" : [0.06438801, 0.0441467, 0.03966651, 0.06374957],
                    "std" : [0.10712028, 0.08619478, 0.11134183, 0.10635688]
                }
            },
            {
                "module" : "nn",
                "name" : "Unfold",
                "kwargs" : {
                    "kernel_size" : [224,224],
                    "stride" : 224
                }
            }
        ]
]

In [5]:
#cfg["module"]["loss_fn"]["module"] = "nn"
cfg["datamodule"]["dataset"]["transforms"] = inference_transforms[1]
cfg["datamodule"]["dataloader"]["batch_size"]=3
cfg["datamodule"]["dataloader"]["num_workers"]=15
cfg["datamodule"]["dataloader"]["prefetch_factor"]=5
cfg["datamodule"]["use_wr_sampler"] = False
del cfg["datamodule"]["dataloader"]["sampler"]

In [6]:
device = "cuda:0"
module = HPASSCModule.load_from_checkpoint(ckpt, **cfg["module"]).to(device)
datamodule = HPASSCDataModule(**cfg["datamodule"])

In [7]:
module.eval()
datamodule.setup()

In [None]:
image, label = datamodule.train_dataset[10]

In [24]:
image = image.permute(1,0).view(49, -1, 224, 224) if len(image.shape) < 3 else image[None, :, :, :]

In [25]:
image = image.to(device)
output = module.reduce_channel(image)
output = module.model(output, output_attentions=True)
output.logits.shape, len(output.attentions), output.attentions[0].shape

(torch.Size([49, 19]), 12, torch.Size([49, 12, 197, 197]))

In [35]:
torch.softmax(output.logits, dim=1).max(dim=0).values.sort(descending=True)

torch.return_types.sort(
values=tensor([0.2268, 0.1832, 0.1250, 0.0762, 0.0750, 0.0626, 0.0621, 0.0618, 0.0609,
        0.0597, 0.0580, 0.0559, 0.0523, 0.0509, 0.0492, 0.0463, 0.0430, 0.0417,
        0.0405], device='cuda:0', grad_fn=<SortBackward0>),
indices=tensor([ 0, 16,  8, 13, 10, 17, 15,  5, 14,  4,  6,  3,  7,  1,  2, 12,  9, 18,
        11], device='cuda:0'))

In [None]:
n_tiles, nh, _, _ = output.attentions[-1].shape
attentions = output.attentions[-1][:, :, 0, 1:]
attentions = (attentions * (attentions > 0.003)).view(n_tiles, nh, -1)
for head in range(nh):
    fig, ax = plt.subplots(1, 2)
    grid = vutils.make_grid(attentions[:, head,:].view(-1, 1, 14, 14), nrow=7, ncol=7, padding=0, normalize=True)
    image_grid = vutils.make_grid(image, nrow=7, ncol=7).permute(1,2,0).mean(dim=-1).cpu()
    grid = transforms.functional.resize(grid, (1568, 1568)).cpu().permute(1,2,0)
    ax[1].imshow(grid[:,:,0], cmap=plt.cm.rainbow)
    ax[0].imshow(image_grid)
    ax[0].set_axis_off()
    ax[1].set_axis_off()
    plt.show()

In [None]:
from itertools import product
from tqdm import tqdm
num_classes = 19
num_layers = 12
num_tiles = 49
grad = torch.zeros((num_layers, *output.attentions[-1].shape)).to(device)
for cla, layer, tile in tqdm(product(range(num_classes), range(num_layers), range(num_tiles)), total=num_classes * num_layers * num_tiles):
    #grad[cla, layer] += torch.autograd.grad(output.logits[tile, cla], output.attentions[layer], retain_graph=True)[0] 
    pass

In [None]:
threshold = 0.0001
n_tiles, nh, _, _ = output.attentions[-1].shape
attentions = (output.attentions[-1] * grad)[:, :, 0, 1:]
attentions = (attentions * (attentions > threshold)).view(n_tiles, nh, -1)

fig, ax = plt.subplots(1, 2)
grid = vutils.make_grid(attentions.max(dim=1).values.view(-1, 1, 14, 14), nrow=7, ncol=7, padding=0).cpu().permute(1,2,0)
image_grid = vutils.make_grid(image.cpu(), nrow=7, ncol=7).permute(1,2,0).mean(dim=-1)
ax[1].imshow(grid[:,:,0], cmap=plt.cm.rainbow, interpolation="bilinear")
ax[0].imshow(image_grid)
ax[0].set_axis_off()
ax[1].set_axis_off()
plt.show()

In [13]:
from torchmetrics.classification import MultilabelPrecisionRecallCurve, MultilabelF1Score
mlprc = MultilabelPrecisionRecallCurve(19, thresholds=100).to(device)
dataloader = datamodule.train_dataloader()

for batch in tqdm(dataloader):
    image, label = batch
    n_batch = image.shape[0]
    image = image.permute(0,2,1).view(n_batch, 49, -1, 224, 224).to(device)
    output = module.reduce_channel(image.reshape(-1, 4, 224, 224))
    output = module.model(output)
    output.logits = output.logits.reshape(n_batch, 49, -1)
    mlprc.update(torch.sigmoid(output.logits).max(dim=1).values, label.to(device))

  4%|▍         | 261/5815 [03:44<1:19:28,  1.16it/s]


KeyboardInterrupt: 

In [14]:
precision_, recall_, thresholds_ = mlprc.compute()
f1_ = 2 * precision_ * recall_ / (precision_ + recall_ + torch.finfo(torch.float32).eps)
best_f1 = torch.max(f1_, dim=1)
best_arg = torch.argmax(f1_,dim=1)
print(best_f1, thresholds_[best_arg.T])
threshold = thresholds_[best_arg.T]

torch.return_types.max(
values=tensor([0.5676, 0.0920, 0.1731, 0.2080, 0.1830, 0.1724, 0.1395, 0.1910, 0.1884,
        0.2017, 0.1416, 0.0000, 0.1699, 0.3972, 0.1931, 0.1250, 0.5537, 0.0839,
        0.0070], device='cuda:0'),
indices=tensor([97, 87, 61, 93, 93, 95, 90, 92, 86, 94, 95,  0, 90, 98, 95, 87, 98, 85,
        67], device='cuda:0')) tensor([0.9798, 0.8788, 0.6162, 0.9394, 0.9394, 0.9596, 0.9091, 0.9293, 0.8687,
        0.9495, 0.9596, 0.0000, 0.9091, 0.9899, 0.9596, 0.8788, 0.9899, 0.8586,
        0.6768], device='cuda:0')


  """


In [23]:
precision_.clone().cpu()

tensor([[ 5.6052e-45,  0.0000e+00,  0.0000e+00,  ...,  7.9569e+00,
          4.5814e-41, -8.1149e-26],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  4.5815e-41,
          7.9569e+00,  4.5814e-41],
        [ 0.0000e+00,  0.0000e+00,         nan,  ..., -8.1145e-26,
          4.5815e-41,  7.9569e+00],
        ...,
        [ 4.5837e-38,  0.0000e+00,  2.8026e-45,  ...,  7.9569e+00,
          4.5814e-41, -8.1151e-26],
        [ 0.0000e+00,  2.8026e-45,  0.0000e+00,  ...,  4.5815e-41,
          7.9569e+00,  4.5814e-41],
        [ 2.8026e-45,  0.0000e+00,  0.0000e+00,  ..., -8.1147e-26,
          4.5815e-41,  7.9569e+00]])

In [8]:
temp = (torch.tensor([0.4042, 0.0560, 0.1129, 0.0568, 0.0651, 0.0822, 0.0477, 0.0850, 0.0549,
         0.0596, 0.0629, 0.0033, 0.0813, 0.1860, 0.0923, 0.0192, 0.3265, 0.0273,
         0.0015]),
 torch.tensor([1.0000, 0.3432, 1.0000, 1.0000, 1.0000, 1.0000, 0.3867, 1.0000, 0.4852,
         0.3834, 0.2565, 1.0000, 1.0000, 0.5404, 1.0000, 0.5309, 0.7257, 0.4401,
         1.0000]),
 torch.tensor([0.5757, 0.0962, 0.2029, 0.1075, 0.1223, 0.1519, 0.0850, 0.1567, 0.0987,
         0.1032, 0.1010, 0.0066, 0.1505, 0.2768, 0.1690, 0.0371, 0.4504, 0.0515,
         0.0031]),
 torch.tensor([0.0000, 0.0101, 0.0000, 0.0000, 0.0000, 0.0000, 0.0202, 0.0000, 0.0202,
        0.0101, 0.0505, 0.0000, 0.0000, 0.0202, 0.0000, 0.0202, 0.0202, 0.0101,
        0.0000]))

In [13]:
precision_, recall_, thresholds_ = mlprc.compute()
f1_ = 2 * precision_ * recall_ / (precision_ + recall_ + torch.finfo(torch.float32).eps)
best_f1 = torch.max(f1_, dim=1)
best_arg = torch.argmax(f1_,dim=1)

In [11]:
threshold = torch.tensor([0.9798, 0.8788, 0.6162, 0.9394, 0.9394, 0.9596, 0.9091, 0.9293, 0.8687,
        0.9495, 0.9596, 0.0000, 0.9091, 0.9899, 0.9596, 0.8788, 0.9899, 0.8586,
        0.6768]).to(device)

In [12]:
from torchmetrics.classification import MultilabelPrecisionRecallCurve, MultilabelF1Score
mlprc = MultilabelF1Score(19, average='none').to(device)
dataloader = datamodule.val_dataloader()

for batch in tqdm(dataloader):
    image, label = batch
    n_batch = image.shape[0]
    image = image.permute(0,2,1).view(n_batch, 49, -1, 224, 224).to(device)
    output = module.reduce_channel(image.reshape(-1, 4, 224, 224))
    output = module.model(output)
    output.logits = output.logits.reshape(n_batch, 49, -1)
    logits = torch.sigmoid(output.logits).max(dim=1).values
    mlprc.update((logits > threshold).to(torch.int), label.to(device))

 23%|██▎       | 335/1453 [03:55<13:05,  1.42it/s]  


KeyboardInterrupt: 

In [13]:
torch.tensor([precision_[i, best_f1.indices[i]].cpu().item() for i in range(19)]), torch.tensor([recall_[i, best_f1.indices[i]].cpu().item() for i in range(19)]),best_f1.values

NameError: name 'precision_' is not defined

In [14]:
ret = mlprc.compute()

In [15]:
ret

tensor([0.5946, 0.1198, 0.1805, 0.1471, 0.1731, 0.2012, 0.0702, 0.1746, 0.0503,
        0.2222, 0.0945, 0.0099, 0.1525, 0.3878, 0.1312, 0.1818, 0.5328, 0.0494,
        0.0027], device='cuda:0')