In [1]:
import sys
sys.path.append('../..')

In [2]:
import os
import psutil

import random
import math
from functools import partial

import torch 
from torch import optim
from torch.optim import lr_scheduler
from torch import nn
from torch.nn import functional as F

import multiprocessing.dummy as mp

from pytorch_lightning import Trainer
from pytorch_lightning.core import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger


from lib.schedulers import DelayedScheduler
from lib.datasets import (max_lbl_nums, actual_lbl_nums, 
                          patches_rgb_mean_av1, patches_rgb_std_av1, 
                          get_train_test_img_ids_split)
from lib.dataloaders import PatchesDataset, WSIPatchesDatasetRaw, WSIPatchesDummyDataloader
from lib.augmentations import augment_v1_clr_only, augment_empty_clr_only
from lib.losses import SmoothLoss
from lib.trainers import WSIModuleV1

from lib.models.unetv1 import get_model
from lib.models.features_map import FeaturesMap

from sklearn.metrics import cohen_kappa_score

from tqdm.auto import tqdm

import matplotlib.pyplot as plt

In [3]:
# import cv2
import numpy as np
# import pandas as pd
# from lib.datasets import patches_csv_path, patches_path
from lib.datasets import (patches_clean90_csv_path as patches_csv_path, patches_path,
                          patches_clean90_pkl_path as patches_pkl_path)
# from lib.dataloaders import imread, get_g_score_num, get_provider_num

In [4]:
split_ids = get_train_test_img_ids_split()
train_img_ids, test_img_ids = split_ids

# Peek at the first few held-out (test) image ids as a sanity check.
test_img_ids[:4]

['e8baa3bb9dcfb9cef5ca599d62bb8046',
 '9b2948ff81b64677a1a152a1532c1a50',
 '5b003d43ec0ce5979062442486f84cf7',
 '375b2c9501320b35ceb638a3274812aa']

In [5]:
from lib.dataloaders import WSIPatchesDataloader, WSIPatchesDatasetRaw
from lib.utils import get_pretrained_model, get_features

In [6]:
# Samples per batch; also drives the linear learning-rate scaling in `hparams`.
BATCH_SIZE = 64
batch_size = BATCH_SIZE

In [7]:
# Root directory of the pre-extracted feature batches. Set the
# PRETRAINED_FEATURES_DIR environment variable to point elsewhere instead of
# editing a hard-coded absolute path; the default is the original location.
pretrained_root = os.environ.get('PRETRAINED_FEATURES_DIR',
                                 '/mnt/SSDData/pdata/processed/pretrained')
# The '{}' placeholder in the train path is presumably formatted by the loader
# (e.g. with a pre-computed epoch index) — confirm in lib.dataloaders.
# os.path.join(..., '') keeps the trailing separator of the original paths.
train_batch_path = os.path.join(pretrained_root, 'train', '{}', '')
test_batch_path = os.path.join(pretrained_root, 'val', '')

# NOTE(review): both loaders use precalc_epochs=50, while the hparams cell logs
# 'precalc_epochs': 11 — confirm which value is correct.
train_loader = WSIPatchesDummyDataloader(train_batch_path, precalc_epochs=50, batch_size=batch_size, shuffle=True)
val_loader = WSIPatchesDummyDataloader(test_batch_path, precalc_epochs=50, batch_size=batch_size, shuffle=False)

In [8]:
# Placeholder: the real number of steps per epoch, len(train_loader), is
# computed in a later cell.
steps_in_epoh = 1

epochs = 90

# Warmup is disabled (zero epochs / zero steps).
warmup_epochs = 0
warmup_steps = 0

dataset_cfg = {
    'dataloader': 'dummy',
    'rgb_mean': patches_rgb_mean_av1,
    'rgb_std': patches_rgb_std_av1,
    'classes': max_lbl_nums,
    # NOTE(review): train_loader / val_loader are built with precalc_epochs=50;
    # confirm which value is correct so the logged hparams match the run.
    'precalc_epochs': 11,
    'train_test_split': {},
}

optimizer_cfg = {
    'name': 'torch.optim.Adam',
    'params': {'weight_decay': 1e-4},
}

# Multiply the learning rate by 0.96 once per epoch.
scheduler_cfg = {
    'name': 'torch.optim.lr_scheduler.ExponentialLR',
    'params': {'gamma': 0.96},
    'interval': 'epoch',
}

loss_cfg = {
    'weights': {'reg': 1 / 2, 'class': 9 / 2},
    'label_smoothing': 0.1,
}

hparams = {
    'batch_size': batch_size,
    # Linear LR scaling: 1e-3 per 256 samples in a batch.
    'learning_rate': 0.001 * batch_size / 256,
    'dataset': dataset_cfg,
    'optimizer': optimizer_cfg,
    'scheduler': scheduler_cfg,
    'loss': loss_cfg,
    'warmup_steps': warmup_steps,
    'steps_in_epoh': steps_in_epoh,
    'epochs': epochs,
}

In [9]:
# Now that the training loader exists, use its real length (batches per epoch).
n_train_batches = len(train_loader)
steps_in_epoh = n_train_batches

In [10]:
# Display the number of training batches per epoch.
steps_in_epoh

132

In [11]:
# Record the real number of steps per epoch now that len(train_loader) is known.
# The hparams cell stored the placeholder value 1 under 'steps_in_epoh', so
# overwrite that key; 'steps_in_batch' is still written for anything that
# already reads it.
hparams['steps_in_epoh'] = steps_in_epoh
hparams['steps_in_batch'] = steps_in_epoh
if 'T_max' in hparams['scheduler']['params']:
    # Schedulers configured with a T_max (e.g. CosineAnnealingLR) get their
    # horizon in steps, excluding the warmup steps.
    hparams['scheduler']['params']['T_max'] = (epochs * steps_in_epoh -
                                               warmup_steps)

In [12]:
# tmp[0].shape
# torch.Size([64, 300, 64, 8, 8])

In [13]:
class Dense_512x1x1(nn.Module):
    """Per-sample head over a bag of 512-d patch feature vectors.

    Every valid patch vector goes through a stack of pointwise (kernel_size=1)
    Conv1d / BatchNorm1d / ReLU layers. The results are max-pooled over the
    patches of each batch sample and fed to a regression head and a
    classification head.

    Args:
        classes: number of logits produced by the classification head.
        features_do: dropout probability applied to the input features;
            values ``<= 0`` disable dropout.
    """

    def __init__(self, classes, features_do):
        super().__init__()

        self.reduce = nn.Sequential(
            nn.Dropout1d(features_do) if features_do > 0 else nn.Identity(),
            nn.BatchNorm1d(512),
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 512, 1),
        )

        self.reg_linear = nn.Linear(512, 1)
        self.class_linear = nn.Linear(512, classes)

    def forward(self, features, ys, xs):
        """Max-pool the valid patches of every sample and apply both heads.

        Args:
            features: ``(b, 512, n)`` patch features. The transpose followed by
                ``(b, n)`` mask indexing below requires this layout.
            ys: ``(b, n)`` per-patch values; entries equal to ``-1`` mark
                padding slots and are ignored.
            xs: unused here.

        Returns:
            ``(reg, logits)`` with shapes ``(b, 1)`` and ``(b, classes)``: one
            row per batch sample, in batch order.

        Raises:
            ValueError: if a sample has no valid patch (all of its ``ys`` are
                ``-1``).
        """
        b = features.shape[0]
        r_mask = ys > -1
        # (b, 512, n) -> (b, n, 512) -> (M, 512) valid patches -> (M, 512, 1)
        x = features.transpose(-1, -2)[r_mask][..., None]
        # Batch index of every valid patch, aligned with the rows of `x`.
        f_ns = torch.arange(b, device=features.device)[:, None].expand(b, ys.shape[1])[r_mask]

        out = self.reduce(x)
        pooled = []
        # Loop over the whole batch (not f_ns.max() + 1). Otherwise samples
        # without patches at the end of the batch would be silently dropped,
        # misaligning the outputs with the targets.
        for i in range(b):
            patches = out[f_ns == i]  # (m_i, 512, 1)
            if patches.shape[0] == 0:
                raise ValueError(
                    f'batch sample {i} has no valid patches (all ys == -1)')
            # (m_i, 512, 1) -> (1, 512, m_i). Use an explicit permute instead
            # of `.T`, which is deprecated for tensors that are not 2-D.
            pooled.append(F.adaptive_max_pool1d(patches.permute(2, 1, 0), 1)[..., 0])
        out = torch.cat(pooled)

        return self.reg_linear(out), self.class_linear(out)

In [14]:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free pointwise ``nn.Conv1d`` (kernel_size=1).

    Despite its name, this is *not* a 3x3 convolution. The heads in this
    notebook feed it tensors of length 1 (one vector per patch, see the
    ``[..., None]`` in their ``forward``), and a 1-wide kernel needs no padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (previously accepted but silently ignored).
        groups: number of blocked connections; must divide both channel counts.
        dilation: kernel dilation. It has no effect on a 1-wide kernel but is
            forwarded for consistency.

    Returns:
        The configured ``nn.Conv1d`` module.
    """
    return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride,
                     groups=groups, dilation=dilation, bias=False)

class ResBlock(nn.Module):
    """Residual block of two pointwise convolutions with BatchNorm1d.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    When ``inplanes == planes`` the shortcut is the identity, and the registered
    sub-modules (and therefore state_dict keys) are exactly conv1/bn1/relu/
    conv2/bn2. Otherwise a bias-free kernel_size=1 Conv1d + BatchNorm1d
    projection matches the channel count; previously that case failed with a
    shape mismatch at the residual addition.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
    """

    def __init__(self, inplanes, planes):
        super().__init__()
        norm_layer = nn.BatchNorm1d

        self.conv1 = conv3x3(inplanes, planes)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)

        # Register a projection only when it is needed, so equal-width blocks
        # keep their original parameters and existing checkpoints still load.
        if inplanes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, bias=False),
                norm_layer(planes),
            )
        else:
            self.shortcut = None

    def forward(self, x):
        """Apply the residual block to ``x`` of shape ``(N, inplanes, L)``."""
        identity = x if self.shortcut is None else self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

In [15]:
class ResNet1D_512x1x1(nn.Module):
    """Per-sample head: pointwise residual stack over 512-d patch features.

    Valid patch vectors are batch-normalised, passed through three
    ``ResBlock(512, 512)`` blocks, and max-pooled over the patches of each
    sample. The pooled vector feeds a regression head and a classification head.

    Args:
        classes: number of logits produced by the classification head.
        features_do: accepted but currently unused. Unlike ``Dense_512x1x1``,
            no input dropout is applied.
    """

    def __init__(self, classes, features_do):
        super().__init__()

        self.reduce = nn.Sequential(
            nn.BatchNorm1d(512),

            ResBlock(512, 512),
            ResBlock(512, 512),
            ResBlock(512, 512),
        )

        self.reg_linear = nn.Linear(512, 1)
        self.class_linear = nn.Linear(512, classes)

    def forward(self, features, ys, xs):
        """Max-pool the valid patches of every sample and apply both heads.

        Args:
            features: ``(b, 512, n)`` patch features. The transpose followed by
                ``(b, n)`` mask indexing below requires this layout.
            ys: ``(b, n)`` per-patch values; entries equal to ``-1`` mark
                padding slots and are ignored.
            xs: unused here.

        Returns:
            ``(reg, logits)`` with shapes ``(b, 1)`` and ``(b, classes)``: one
            row per batch sample, in batch order.

        Raises:
            ValueError: if a sample has no valid patch (all of its ``ys`` are
                ``-1``).
        """
        b = features.shape[0]
        r_mask = ys > -1
        # (b, 512, n) -> (b, n, 512) -> (M, 512) valid patches -> (M, 512, 1)
        x = features.transpose(-1, -2)[r_mask][..., None]
        # Batch index of every valid patch, aligned with the rows of `x`.
        f_ns = torch.arange(b, device=features.device)[:, None].expand(b, ys.shape[1])[r_mask]

        out = self.reduce(x)
        pooled = []
        # Loop over the whole batch (not f_ns.max() + 1). Otherwise samples
        # without patches at the end of the batch would be silently dropped,
        # misaligning the outputs with the targets.
        for i in range(b):
            patches = out[f_ns == i]  # (m_i, 512, 1)
            if patches.shape[0] == 0:
                raise ValueError(
                    f'batch sample {i} has no valid patches (all ys == -1)')
            # (m_i, 512, 1) -> (1, 512, m_i). Use an explicit permute instead
            # of `.T`, which is deprecated for tensors that are not 2-D.
            pooled.append(F.adaptive_max_pool1d(patches.permute(2, 1, 0), 1)[..., 0])
        out = torch.cat(pooled)

        return self.reg_linear(out), self.class_linear(out)

In [16]:
from lib.models.wsi_resnets import Resnet_512x1x1

In [17]:
# NOTE(review): positional args of lib.models.wsi_resnets.Resnet_512x1x1.
# 'resnet18' looks like a backbone name and 6 like the number of classes
# (cf. hparams['dataset']['classes']) — confirm against its signature.
resnet_args = ('resnet18', 512, 6, 0)
model = Resnet_512x1x1(*resnet_args)



In [18]:
#tmp = torch.load('./lightning_logs/version_1/checkpoints/last.ckpt')

#module = nn.Sequential()
#module.add_module('model', model)

# module.load_state_dict(tmp['state_dict']);

In [19]:
# Pull a single training batch for interactive inspection.
sample_batch = next(iter(train_loader))
features, ys, xs, provider, isup_grade, gleason_score = sample_batch

In [20]:
# Disabled debugging helper, kept as an unused string literal. When enabled, it
# registers a forward hook on every Conv2d in `model`; each hook call appends
# that layer's output to `save_output.outputs`.
'''
class SaveOutput:
    def __init__(self):
        self.outputs = []
        
    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)
        
    def clear(self):
        self.outputs = []
        
save_output = SaveOutput()

hook_handles = []

for layer in model.modules():
    if isinstance(layer, torch.nn.modules.conv.Conv2d):
        handle = layer.register_forward_hook(save_output)
        hook_handles.append(handle)
''';        

In [21]:
#device = torch.device('cuda:0')

In [22]:
#with torch.no_grad():
#    # x, f_mask = model(features, ys, xs)
#    tmp = model(features.to(device), ys.to(device), xs.to(device))

In [23]:
#x.shape, f_mask.shape

In [24]:
#xs = torch.stack([x[b].permute(1, 2, 0)[f_mask[b]].mean(0)
#                  for b in range(x.shape[0])])

In [25]:
#plt.imshow(x[3, 0]);

In [26]:
# [o.shape for o in save_output.outputs]

In [27]:
# model.f_map.backend_feature = model.f_map.backend_feature.to(torch.device('cuda:0'));

In [28]:
# features.shape

In [29]:
# np.save("val_lbl_data.npy", preds.cpu().numpy())

In [30]:
# Wrap the model and the hyper-parameters in the project's Lightning module.
module = WSIModuleV1(
    model,
    hparams,
    log_train_every_batch=False,
)

In [31]:
#module.load_from_checkpoint('./lightning_logs/version_3/checkpoints/last.ckpt', 
#                            hparams=hparams)

In [32]:
trainer = Trainer(
    max_epochs=hparams['epochs'],
    gpus=[0],                # train on the first GPU only
    fast_dev_run=False,
    num_sanity_val_steps=0,  # skip the pre-training validation sanity check
)

INFO:lightning:GPU available: True, used: True
INFO:lightning:CUDA_VISIBLE_DEVICES: [0]


In [None]:
trainer.fit(module, train_loader, val_loader)

# After training, also write the final state as "last.ckpt" inside the
# checkpoint callback's directory (the path is resolved after fit, as before).
last_ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, "last.ckpt")
trainer.save_checkpoint(last_ckpt_path)

INFO:lightning:
   | Name                                 | Type           | Params
--------------------------------------------------------------------
0  | model                                | Resnet_512x1x1 | 2 M   
1  | model.f_map                          | FeaturesMap    | 512   
2  | model.backbone                       | ResNet         | 1 M   
3  | model.backbone.conv1                 | Sequential     | 174 K 
4  | model.backbone.conv1.0               | Identity       | 0     
5  | model.backbone.conv1.1               | BatchNorm2d    | 1 K   
6  | model.backbone.conv1.2               | Conv2d         | 131 K 
7  | model.backbone.conv1.3               | BatchNorm2d    | 512   
8  | model.backbone.conv1.4               | ReLU           | 0     
9  | model.backbone.conv1.5               | Conv2d         | 32 K  
10 | model.backbone.conv1.6               | BatchNorm2d    | 256   
11 | model.backbone.conv1.7               | ReLU           | 0     
12 | model.backbone.conv1.8    

[Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 0.00025
    lr: 0.00025
    weight_decay: 0.0001
)] [{'interval': 'epoch', 'scheduler': <lib.schedulers.DelayedScheduler object at 0x7faace9f58d0>}]




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…