In [1]:
import sys
import os

import random
import math

import torch 
from torch import optim
from torch.optim import lr_scheduler
from torch import nn
from torch.nn import functional as F

from pytorch_lightning import Trainer
from pytorch_lightning.core import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

sys.path.append('../..')
from lib.schedulers import DelayedScheduler
from lib.datasets import (max_lbl_nums, actual_lbl_nums, 
                          patches_rgb_mean_av1, patches_rgb_std_av1, 
                          get_train_test_img_ids_split)
from lib.dataloaders import PatchesDataset, WSIPatchesDatasetRaw
from lib.augmentations import augment_v1_clr_only, augment_empty_clr_only
from lib.losses import SmoothLoss

from lib.models.unetv1 import get_model

from sklearn.metrics import cohen_kappa_score

In [2]:
# import cv2
import numpy as np
# import pandas as pd
# from lib.datasets import patches_csv_path, patches_path
from lib.datasets import patches_clean90_csv_path as patches_csv_path, patches_path
# from lib.dataloaders import imread, get_g_score_num, get_provider_num

In [3]:
train_img_ids, test_img_ids = get_train_test_img_ids_split()

In [4]:
len(train_img_ids)

8420

In [5]:
test_img_ids[:4]

['e8baa3bb9dcfb9cef5ca599d62bb8046',
 '9b2948ff81b64677a1a152a1532c1a50',
 '5b003d43ec0ce5979062442486f84cf7',
 '375b2c9501320b35ceb638a3274812aa']

In [6]:
class WSIPatchesDataset1D(WSIPatchesDatasetRaw):
    def __init__(self, image_ids, csv_path=patches_csv_path,
                 path=patches_path, scale=1, transform=None, max_len=300):
        super().__init__(image_ids, csv_path, patches_path, scale, transform)
        self.max_len = max_len

    def __getitem__(self, idx):
        def trim(x):
            return x[:self.max_len]
        
        def pad(x):
            return np.pad(x[:self.max_len], ((0, self.max_len-p),)+((0, 0),)*(len(x.shape)-1), constant_values=-1)
        
        imgs, ys, xs, provider, isup_grade, gleason_score =\
            super().__getitem__(idx)
        
        p = imgs.shape[0]
        if p > self.max_len:
            imgs, ys, xs = trim(imgs), trim(ys), trim(xs)
        elif p < self.max_len:
            imgs, ys, xs = pad(imgs), pad(ys), pad(xs)
        else:
            pass

        return imgs, ys, xs, provider, isup_grade, gleason_score
    
    
class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd
    def forward(self, x):
        return self.lambd(x)

In [7]:
max_len = 300

In [8]:
rgb_mean, rgb_std = (torch.tensor(patches_rgb_mean_av1, dtype=torch.float32), 
                     torch.tensor(patches_rgb_std_av1, dtype=torch.float32))

In [9]:
# imgs, ys, xs, provider, isup_grade, gleason_score = next(iter(train_wsipatches_dataset))

In [10]:
patches_device = torch.device('cuda:0')
# patches_device = torch.device('cpu')
main_device = torch.device('cuda:1')

In [11]:
tmp = torch.load("../Patches256TestRun/version_0/checkpoints/last.ckpt", map_location=patches_device)

In [12]:
# model = get_model(actual_lbl_nums, decoder=False, labels=False)
model = get_model(actual_lbl_nums)

In [13]:
module = nn.Sequential()

In [14]:
module.add_module('model', model)

In [15]:
module.to(patches_device);

In [16]:
module.load_state_dict(tmp['state_dict'])

<All keys matched successfully>

In [17]:
model.segmentation = False
model.classification_head = None
model.autodecoder = None

In [18]:
module.eval();

In [19]:
rgb_mean = rgb_mean.to(patches_device)
rgb_std = rgb_std.to(patches_device)

In [20]:
from time import time

In [21]:
features_batch_size = 128

In [22]:
def get_features(imgs):
    imgs = imgs.to(patches_device)
    n_imgs = (imgs - rgb_mean) / rgb_std
    
    b_features = []
    for b in range(0, n_imgs.shape[0], features_batch_size):
        with torch.no_grad():
            features, *_ = model(n_imgs[b:b+features_batch_size], return_features=True)
            b_features.append(features)

    features = torch.cat(b_features, dim=0)
    
    return features

In [23]:
class MainBatchGenerator1D:
    def __init__(self, img_ids, batch_size, patches_batch_size, shuffle, num_workers, max_len, 
                 patches_csv_path, scale, transform):
        self.batch_size = batch_size
        self.patches_batch_size = patches_batch_size
        
        train_wsipatches_dataset = WSIPatchesDataset1D(img_ids, patches_csv_path, 
                                                       patches_path, scale=scale, 
                                                       transform=transform,
                                                       max_len=max_len)

        self.train_loader = torch.utils.data.DataLoader(
            train_wsipatches_dataset,
            batch_size=patches_batch_size, shuffle=shuffle,
            num_workers=num_workers, pin_memory=True,
        )
        
    def __len__(self):
        return math.ceil(len(self.train_loader) * self.patches_batch_size / self.batch_size)
        
    def __iter__(self):
        patches_iter = iter(self.train_loader)
        stop_iteration = False
        
        while not stop_iteration:
            main_features = []
            main_isup_grade = []
            for _ in range(0, self.batch_size, self.patches_batch_size):
                try:
                    imgs, ys, xs, provider, isup_grade, gleason_score = next(patches_iter)
                except StopIteration:
                    print("Epoch end")
                    stop_iteration = True
                    break

                real_mask = ys != -1
                features = torch.zeros((imgs.shape[0], max_len, 512, 8, 8), 
                                       dtype=torch.float32, device=patches_device)
                if real_mask.any():
                    imgs = imgs[real_mask]   

                    real_features = get_features(imgs)
                    features[real_mask] = real_features

                #main_features.append(features.cpu())
                #main_isup_grade.append(isup_grade.cpu())
                main_features.append(features)
                main_isup_grade.append(isup_grade)
                
            main_features = torch.cat(main_features)
            main_isup_grade = torch.cat(main_isup_grade)

            yield main_features, main_isup_grade

In [24]:
f_d_rate = 0.5
d_rate = 0.4

main_model = nn.Sequential(
    #nn.Dropout(f_d_rate),
    
    nn.Conv1d(512, 64, 1),
    #nn.Conv1d(512*8*8, 64, 1),
    nn.ReLU(inplace=True),
    #nn.Dropout(d_rate),
    nn.BatchNorm1d(64),
    
    nn.Conv1d(64, 64, 1),
    nn.ReLU(inplace=True),
    #nn.Dropout(d_rate),
    nn.BatchNorm1d(64),
    
    nn.AdaptiveMaxPool1d(1),
    LambdaLayer(lambda x: x.view(-1, 64)),
    
    nn.Linear(64, 64),
    nn.ReLU(inplace=True),
    #nn.Dropout(d_rate),
    nn.BatchNorm1d(64),
    
    nn.Linear(64, max_lbl_nums)
).to(main_device)

In [25]:
sum([p.data.numel() for p in main_model.parameters()])

41926

In [26]:
optimizer = optim.Adam(main_model.parameters(), lr=0.001, weight_decay=0)
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=1.0, last_epoch=-1)

criterion = SmoothLoss(lambda o, t: F.kl_div(o, t), smoothing=0.1, one_hot_target=True)

In [27]:
model.training

False

In [28]:
from tqdm.auto import tqdm

In [29]:
# for e in range(1, 100):
# save_batch_path = f'/mnt/SSDData/pdata/processed/tmp{e}/'
save_batch_path = f'/mnt/SSDData/pdata/processed/tmp/tmp_wo_aug/'
os.makedirs(save_batch_path, exist_ok=True)

train_loader = MainBatchGenerator1D(train_img_ids, batch_size=64, patches_batch_size=2, shuffle=True, 
                                    num_workers=6, max_len=300, patches_csv_path=patches_csv_path,
                                    scale=0.5, transform=augment_empty_clr_only)
                                    # scale=0.5, transform=augment_v1_clr_only)

n = 0
for features, target in tqdm(train_loader, total=len(train_loader)):
    features = features.mean(dim=-1).mean(dim=-1).transpose(1, -1)
    torch.save(features, os.path.join(save_batch_path, f"features_{n}.pth"))
    torch.save(target, os.path.join(save_batch_path, f"target_{n}.pth"))    
    n += 1

del train_loader

HBox(children=(FloatProgress(value=0.0, max=132.0), HTML(value='')))

Epoch end



In [29]:
        coefficients = optimized_rounder.coefficients()
        final_preds = optimized_rounder.predict(preds, coefficients)

In [30]:
optimized_rounder = OptimizedRounder()

HBox(children=(FloatProgress(value=0.0, max=132.0), HTML(value='')))

Epoch end



In [24]:
model.training

False

In [25]:
save_batch_path = '/mnt/SSDData/pdata/processed/val_tmp/'

In [26]:
os.makedirs(save_batch_path, exist_ok=True)

In [27]:
val_loader = MainBatchGenerator1D(test_img_ids, batch_size=64, patches_batch_size=2, shuffle=False, 
                                  num_workers=8, max_len=300, patches_csv_path=patches_csv_path,
                                  scale=0.5, transform=augment_empty_clr_only)

In [28]:
save_batch_path

'/mnt/SSDData/pdata/processed/val_tmp/'

In [29]:
os.makedirs(save_batch_path, exist_ok=True)

In [32]:
n = 0
for features, target in tqdm(val_loader, total=len(val_loader)):
    features = features.mean(dim=-1).mean(dim=-1).transpose(1, -1)
    torch.save(features, os.path.join(save_batch_path, f"features_{n}.pth"))
    torch.save(target, os.path.join(save_batch_path, f"target_{n}.pth"))    
    n += 1

HBox(children=(FloatProgress(value=0.0, max=33.0), HTML(value='')))

Epoch end



In [35]:
64 * 300 * 512 * 8 * 8 * 4

2516582400

In [36]:
2.5 * 132

330.0

In [37]:
!df -H | grep mnt

/dev/nvme0n1p1  503G  286G  192G  60% /mnt/SSDData
/dev/sda1       4.1T  3.7T  351G  92% /mnt/HDDData


In [29]:
n = 0
for epoch in range(10):
    for features, target in tqdm(train_loader, total=len(train_loader)):
        b = features.shape[0]

        features, target = features.to(main_device), target.to(main_device)
        
        # features = features.view(b, max_len, -1).transpose(1, -1)
        features = features.mean(dim=-1).mean(dim=-1).transpose(1, -1)

        optimizer.zero_grad()

        output = main_model(features)
        output = F.log_softmax(output, dim=-1)
        loss = criterion(output, target)    

        loss.backward()
        # torch.nn.utils.clip_grad_norm_(main_model.parameters(), 0.01)
        optimizer.step()
        lr_scheduler.step()

        acc = output.argmax(dim=-1).eq(target).float().mean()
        
        o_output = output.argmax(dim=-1)
        qwk = cohen_kappa_score(o_output.cpu().numpy(), target.cpu().numpy(), weights="quadratic")

        print(f"{n}. loss: {loss.item():.4f}, acc: {acc.item():.4f}, qwk: {qwk:.4f}, lr: {optimizer.param_groups[0]['lr']:.5f}")

        n += 1

HBox(children=(FloatProgress(value=0.0, max=132.0), HTML(value='')))



0. loss: 0.2220, acc: 0.1875, qwk: 0.3540, lr: 0.00100
1. loss: 0.2414, acc: 0.1719, qwk: 0.0078, lr: 0.00100
2. loss: 0.2180, acc: 0.2969, qwk: 0.2317, lr: 0.00100
3. loss: 0.2230, acc: 0.2812, qwk: 0.2266, lr: 0.00100
4. loss: 0.2267, acc: 0.2188, qwk: 0.2150, lr: 0.00100
5. loss: 0.2156, acc: 0.2188, qwk: 0.2237, lr: 0.00100
6. loss: 0.2173, acc: 0.3438, qwk: 0.4303, lr: 0.00100
7. loss: 0.2315, acc: 0.1719, qwk: -0.0081, lr: 0.00100
8. loss: 0.2201, acc: 0.2812, qwk: -0.0116, lr: 0.00100
9. loss: 0.2223, acc: 0.2812, qwk: 0.3463, lr: 0.00100
10. loss: 0.1905, acc: 0.3750, qwk: 0.4988, lr: 0.00100
11. loss: 0.2019, acc: 0.3906, qwk: 0.2900, lr: 0.00100
12. loss: 0.2017, acc: 0.3906, qwk: 0.4337, lr: 0.00100
13. loss: 0.1885, acc: 0.4062, qwk: 0.3425, lr: 0.00100



KeyboardInterrupt: 

In [None]:
#main_batch_size = 64

#patches_batch_size = 2
#num_workers = 11

In [None]:
output, target

In [44]:
o_output = output.argmax(dim=-1)

In [45]:
o_output

tensor([1, 0, 3, 5, 1, 1, 2, 1, 0, 0, 1, 0, 5, 4, 4, 4, 5, 1, 3, 0, 2, 3, 0, 0,
        1, 3, 0, 0, 1, 1, 1, 1, 0, 1, 1, 3, 4, 3, 0, 4, 1, 0, 2, 1, 5, 5, 1, 3,
        0, 0, 5, 1, 0, 1, 5, 5, 3, 1, 1, 0, 0, 0, 0, 4], device='cuda:1')

In [46]:
target

tensor([5, 0, 3, 4, 1, 4, 3, 0, 0, 0, 1, 3, 5, 3, 4, 5, 5, 1, 4, 0, 3, 5, 0, 1,
        0, 1, 0, 0, 1, 2, 2, 1, 0, 2, 1, 3, 3, 4, 0, 0, 1, 2, 1, 1, 5, 5, 4, 4,
        1, 5, 5, 2, 0, 1, 3, 5, 5, 1, 1, 0, 4, 0, 0, 4], device='cuda:1')

In [49]:
cohen_kappa_score(o_output.cpu().numpy(), target.cpu().numpy(), weights="quadratic")

0.6770040959625512

In [None]:
oh_target = torch.tensor((b, max_lbl_nums), device=main_device)
oh_o_output = torch.tensor((b, max_lbl_nums), device=main_device)

In [None]:
o_output

In [None]:
oh_target[target] = 1
oh_o_output[o_output] = 1

In [52]:
cm = torch.matmul(oh_target[:, :, None], oh_o_output[:, None, :]).sum(0)

In [None]:
@dataclass
class KappaScore(ConfusionMatrix):
    "Computes the rate of agreement (Cohens Kappa)."
    weights:Optional[str]=None      # None, `linear`, or `quadratic`

    def on_epoch_end(self, last_metrics, **kwargs):
        sum0 = self.cm.sum(dim=0)
        sum1 = self.cm.sum(dim=1)
        expected = torch.einsum('i,j->ij', (sum0, sum1)) / sum0.sum()
        if self.weights is None:
            w = torch.ones((self.n_classes, self.n_classes))
            w[self.x, self.x] = 0
        elif self.weights == "linear" or self.weights == "quadratic":
            w = torch.zeros((self.n_classes, self.n_classes))
            w += torch.arange(self.n_classes, dtype=torch.float)
            w = torch.abs(w - torch.t(w)) if self.weights == "linear" else (w - torch.t(w)) ** 2
        else: raise ValueError('Unknown weights. Expected None, "linear", or "quadratic".')
        k = torch.sum(w * self.cm) / torch.sum(w * expected)
        return add_metrics(last_metrics, 1-k)

In [31]:
#empty_mask = (main_features.view(b, max_len, -1) == 0).all(dim=-1)[..., None, None, None]
#dummy_feature = nn.Parameter(torch.ones((1, 1, 512, 8, 8), dtype=torch.float32, device=main_device))
#main_features = main_features + empty_mask * dummy_feature.expand(b, 1, 512, 8, 8)

In [None]:
# 8*55 - 200 сек

In [None]:
# imgs, ys, xs, provider, isup_grade, gleason_score

In [None]:
import matplotlib.pyplot as plt

In [None]:
train_wsipatches_dataset_iter = iter(train_wsipatches_dataset)

In [None]:
fig, axs = plt.subplots(6, 6, figsize=(18, 18))
axs = axs.ravel()
n = 0
for imgs, ys, xs, provider, isup_grade, gleason_score in train_wsipatches_dataset_iter:
    imgs = imgs.transpose([0, 2, 3, 1])
    height = ys.max() + 1
    width = xs.max() + 1
    wsi_img = np.zeros((height, width, 3), dtype=np.float32)
    for y, x, img in zip(ys, xs, imgs):
        wsi_img[y, x] = img.reshape(-1, 3).mean(0)
        
    axs[n].imshow(wsi_img);
    axs[n].set_title(f"{n}: {provider}")
    
    n += 1
    
    if n == 36:
        break