In [1]:
import sys
sys.path.append('../..')

In [2]:
import os
import psutil

import random
import math
from functools import partial

import torch 
from torch import optim
from torch.optim import lr_scheduler
from torch import nn
from torch.nn import functional as F

import multiprocessing.dummy as mp

from pytorch_lightning import Trainer
from pytorch_lightning.core import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger


from lib.schedulers import DelayedScheduler
from lib.datasets import (max_lbl_nums, actual_lbl_nums, 
                          patches_rgb_mean_av1, patches_rgb_std_av1, 
                          get_train_test_img_ids_split)
from lib.dataloaders import PatchesDataset, WSIPatchesDatasetRaw, WSIPatchesDummyDataloader
from lib.augmentations import augment_v1_clr_only, augment_empty_clr_only
from lib.losses import SmoothLoss
from lib.trainers import WSIModuleV1

from lib.models.unetv1 import get_model
from lib.models.features_map import FeaturesMap

from sklearn.metrics import cohen_kappa_score

from tqdm.auto import tqdm

import matplotlib.pyplot as plt

In [3]:
# import cv2
import numpy as np
# import pandas as pd
# from lib.datasets import patches_csv_path, patches_path
from lib.datasets import (patches_clean90_csv_path as patches_csv_path, patches_path,
                          patches_clean90_pkl_path as patches_pkl_path)
# from lib.dataloaders import imread, get_g_score_num, get_provider_num

In [4]:
# Project-defined train / held-out split of image ids (lib.datasets).
# NOTE(review): these id lists are not passed to the loaders below, which read
# precomputed batches from disk; presumably those batches were built from this
# same split — confirm.
train_img_ids, test_img_ids = get_train_test_img_ids_split()

# Peek at a few held-out ids.
test_img_ids[:4]

['e8baa3bb9dcfb9cef5ca599d62bb8046',
 '9b2948ff81b64677a1a152a1532c1a50',
 '5b003d43ec0ce5979062442486f84cf7',
 '375b2c9501320b35ceb638a3274812aa']

In [5]:
from lib.dataloaders import WSIPatchesDataloader, WSIPatchesDatasetRaw
from lib.utils import get_pretrained_model, get_features

In [6]:
# Batch size shared by the train/val loaders and recorded in hparams.
batch_size = 16

In [7]:
# Previously used HDD location of the precomputed batches (kept for reference).
#train_batch_path = '/mnt/HDDData/pdata/processed/pretrained_64x8x8/train/{}/'
#test_batch_path = '/mnt/HDDData/pdata/processed/pretrained_64x8x8/val/'

# Precomputed feature batches. The train path is a format template ('{}'),
# presumably filled per precalc epoch by WSIPatchesDummyDataloader — confirm in
# lib.dataloaders. The val path has no placeholder.
# NOTE(review): hard-coded absolute paths; consider a configurable data root.
train_batch_path = '/mnt/SSDData/pdata/processed/pretrained_64x8x8/train/{}/'
test_batch_path = '/mnt/SSDData/pdata/processed/pretrained_64x8x8/val/'

# Train batches are shuffled, validation batches are not.
# NOTE(review): precalc_epochs=6 here, but hparams['dataset']['precalc_epochs']
# is logged as 50 — keep the logged value in sync with what the loaders use.
train_loader = WSIPatchesDummyDataloader(train_batch_path, precalc_epochs=6, batch_size=batch_size, shuffle=True)
val_loader = WSIPatchesDummyDataloader(test_batch_path, precalc_epochs=6, batch_size=batch_size, shuffle=False)

In [8]:
# Schedule constants. `steps_in_epoh` (sic) keeps its spelling because the same
# name is used as an hparams key below.
steps_in_epoh = 1

epochs = 90

# NOTE(review): warmup_epochs is not referenced in the visible cells; only
# warmup_steps is logged and used for T_max.
warmup_epochs = 0
warmup_steps = 0

# Full experiment configuration handed to WSIModuleV1. Keys and nesting are
# consumed by that module, so they must not be renamed here.
hparams = {
    'batch_size': batch_size,
    # NOTE(review): reads like linear LR scaling for batch 32 (0.001 per 256),
    # but batch_size is 16 — confirm which was intended.
    'learning_rate': 0.001 * 32 / 256,
    'dataset': {
        'dataloader': 'dummy',
        'rgb_mean': patches_rgb_mean_av1,
        'rgb_std': patches_rgb_std_av1,
        'classes': max_lbl_nums,
        # NOTE(review): the loaders above are built with precalc_epochs=6.
        'precalc_epochs': 50,
        'train_test_split': {},
    },
    'optimizer': {
        'name': 'torch.optim.Adam',
        'params': {
            'weight_decay': 1e-4
        }
    },
    # Presumably multiplies the LR by `gamma` each interval with a floor of
    # `eta_min` (inferred from the name/params — see lib.schedulers).
    'scheduler': {
        'name': 'lib.schedulers.ExponentialLRWithMin',
        'params': {
            'gamma': 0.92,
            'eta_min': 1.25e-5
        },
        'interval': 'epoch'
    },
    # Relative weights of the two heads' losses (MainModel returns a regression
    # output and a classification output), plus label smoothing.
    'loss': {
        'weights': {
            'reg': 1 / 2,
            'class': 9 / 2
        },
        'label_smoothing': 0.1
    },
    'warmup_steps': warmup_steps,
    'steps_in_epoh': steps_in_epoh,
    'epochs': epochs
}

In [9]:
# Mirror the steps-per-epoch value under the 'steps_in_batch' key as well.
hparams.update(steps_in_batch=steps_in_epoh)

# Schedulers parameterised by a period `T_max` get it set to the number of
# post-warmup steps; other schedulers are left untouched.
if 'T_max' in hparams['scheduler']['params']:
    hparams['scheduler']['params'].update(
        T_max=epochs * steps_in_epoh - warmup_steps)

In [10]:
# tmp[0].shape
# torch.Size([64, 300, 64, 8, 8])

In [11]:
class TiledFeaturesMap(nn.Module):
    """Scatter per-patch feature maps onto a 2-D grid and cut it into overlapping tiles.

    Every sample in the batch is a set of patches with integer grid coordinates
    ``(ys[b, i], xs[b, i])``; entries whose ``ys`` value is negative are padding
    and are ignored. Valid patches are placed on a dense grid (anchored at the
    bounding box of the valid patches), the grid is unfolded into
    ``t_sz x t_sz`` windows with stride ``t_step``, and only windows whose
    centre (the window minus a ``t_cut``-wide border) contains at least one
    patch are kept.

    Args:
        f_channels: channels of each patch feature map.
        f_size: spatial side length of each (square) patch feature map.
        t_sz: tile side length, in patches.
        t_step: stride between tiles, in patches; also the upper bound
            (inclusive) of the random grid shift applied in training.
        t_cut: border width, in patches, ignored when deciding whether a tile
            is non-empty. ``0`` means the whole tile is inspected.
    """

    def __init__(self, f_channels=512, f_size=1,
                 t_sz=9, t_step=6, t_cut=2):
        super().__init__()

        self.f_size = f_size
        self.f_channels = f_channels
        self.t_sz = t_sz
        self.t_step = t_step
        self.t_cut = t_cut

    def forward(self, features, ys, xs, validation=None):
        """Build tiles for every sample in the batch.

        Args:
            features: ``(B, N, f_channels, f_size, f_size)`` patch features.
            ys, xs: ``(B, N)`` integer grid coordinates; negative ``ys`` marks
                padding entries.
            validation: when true, disable the random grid shift and the
                flip/transpose augmentation. Defaults to ``not self.training``.

        Returns:
            ``(f_ns, f_tiles)``: ``f_tiles`` is
            ``(T, f_channels, t_sz * f_size, t_sz * f_size)`` and ``f_ns`` is a
            ``(T,)`` long tensor (on ``features.device``) holding the batch
            index each tile came from. Samples without valid patches
            contribute no tiles.
        """
        if validation is None:
            validation = not self.training

        # Explicit end index: `t_cut:-t_cut` would be an empty slice for t_cut=0.
        inner = slice(self.t_cut, self.t_sz - self.t_cut)

        f_tiles = []
        f_ns = []

        for b in range(features.shape[0]):
            r_mask = ys[b] > -1
            if not bool(r_mask.any()):
                # Nothing to place; min()/max() of an empty tensor would raise.
                continue
            ys_b, xs_b = ys[b, r_mask], xs[b, r_mask]

            # Bounding box of the *valid* patches only, so padding entries (-1)
            # cannot shift the grid and change the tiling.
            y_min, x_min = ys_b.min(), xs_b.min()
            y_max, x_max = ys_b.max(), xs_b.max()

            if not validation:
                y_rnd = random.randint(0, self.t_step)
                x_rnd = random.randint(0, self.t_step)
            else:
                y_rnd, x_rnd = 0, 0

            map_h = int(y_max - y_min) + 1 + self.t_sz + y_rnd
            map_w = int(x_max - x_min) + 1 + self.t_sz + x_rnd
            rows = ys_b - y_min + y_rnd
            cols = xs_b - x_min + x_rnd

            x_map = torch.zeros((map_h, map_w,
                                 self.f_channels, self.f_size, self.f_size),
                                dtype=features.dtype, device=features.device)
            x_map[rows, cols] = features[b, r_mask]

            # Occupancy grid: tile selection must not depend on feature values
            # (a real patch whose features sum to <= 0 would otherwise be dropped).
            occupied = torch.zeros((map_h, map_w),
                                   dtype=features.dtype, device=features.device)
            occupied[rows, cols] = 1

            x_tiles = x_map.unfold(0, self.t_sz, self.t_step).unfold(1, self.t_sz, self.t_step)
            occ_tiles = occupied.unfold(0, self.t_sz, self.t_step).unfold(1, self.t_sz, self.t_step)

            keep = occ_tiles[..., inner, inner].reshape(occ_tiles.shape[:2] + (-1,)).sum(-1) > 0

            f_tiles.append(x_tiles[keep])
            f_ns.extend([b] * int(keep.sum()))

        if f_tiles:
            f_tiles = torch.cat(f_tiles, dim=0)
        else:
            # No sample had any valid patch: return an empty, correctly shaped batch.
            f_tiles = torch.zeros((0, self.f_channels, self.f_size, self.f_size,
                                   self.t_sz, self.t_sz),
                                  dtype=features.dtype, device=features.device)
        f_ns = torch.tensor(f_ns, dtype=torch.long, device=features.device)

        # (T, C, f, f, ty, tx) -> (T, C, ty, f, tx, f) -> (T, C, ty*f, tx*f)
        f_tiles = f_tiles.permute(0, 1, 4, 2, 5, 3).reshape(f_tiles.shape[:2] +
                                                            (self.t_sz*self.f_size,
                                                             self.t_sz*self.f_size))

        if not validation:
            # Random dihedral augmentation per tile. Results are collected into a
            # new tensor: writing a transposed view of f_tiles[n] back into
            # f_tiles[n] would be a self-overlapping copy.
            augmented = []
            for f_tile in f_tiles:
                if random.random() > 0.5:
                    f_tile = torch.flip(f_tile, [-1])

                if random.random() > 0.5:
                    f_tile = torch.flip(f_tile, [-2])

                if random.random() > 0.5:
                    f_tile = f_tile.transpose(-1, -2)
                augmented.append(f_tile)
            if augmented:
                f_tiles = torch.stack(augmented)

        return f_ns, f_tiles

In [12]:
import torchvision.models as models

class LambdaLayer(nn.Module):
    """Adapt an arbitrary callable into an ``nn.Module`` layer."""

    def __init__(self, lambd):
        super().__init__()
        # Callable applied to the input in `forward`.
        self.lambd = lambd

    def forward(self, x):
        """Return the wrapped callable applied to ``x``."""
        fn = self.lambd
        result = fn(x)
        return result

# NOTE(review): neither rate is referenced by MainModel or any other visible
# cell — leftover knobs from an earlier variant; consider removing.
f_d_rate = 0.0
d_rate = 0.0

class MainModel(nn.Module):
    """Tile-level ResNet-18 over precomputed patch features, max-pooled per sample.

    Patch features are arranged into tiles by ``TiledFeaturesMap``. Each tile
    goes through a ResNet-18 whose stem is replaced with 1x1 convolutions over
    the 64 feature channels and whose max-pool is removed. Tile embeddings are
    max-pooled per batch sample and fed to a regression head (1 output) and a
    classification head (``max_lbl_nums`` outputs).
    """

    def __init__(self):
        super().__init__()

        self.tf_map = TiledFeaturesMap(f_channels=64, f_size=8, t_sz=9, t_step=6, t_cut=2)

        self.backbone = models.resnet18(pretrained=False)

        # Replace the 7x7/stride-2 RGB stem with a resolution-preserving 1x1
        # stem over the 64 precomputed feature channels.
        self.backbone.conv1 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 1),
        )

        self.backbone.fc = nn.Linear(512, 512)
        self.backbone.maxpool = nn.Identity()

        self.reg_linear = nn.Linear(512, 1)
        self.class_linear = nn.Linear(512, max_lbl_nums)

    def forward(self, features, ys, xs):
        """Return ``(regression_logits, class_logits)``, one row per batch sample.

        Iterates over the batch size (not ``f_ns.max() + 1``), so the output
        always has ``features.shape[0]`` rows and stays aligned with the
        targets. A sample that produced no tiles gets a zero embedding instead
        of crashing the max-pool on an empty tensor.
        """
        f_ns, f_tiles = self.tf_map(features, ys, xs)

        b_out = self.backbone(f_tiles)
        f_ns = f_ns.to(b_out.device)

        pooled = []
        for i in range(features.shape[0]):
            sample_out = b_out[f_ns == i]
            if sample_out.shape[0] == 0:
                pooled.append(b_out.new_zeros(b_out.shape[1]))
            else:
                # Max over this sample's tiles (same as adaptive_max_pool1d(..., 1)).
                pooled.append(sample_out.max(dim=0)[0])
        out = torch.stack(pooled)
        return self.reg_linear(out), self.class_linear(out)

# Build the network; the ResNet-18 backbone is created with pretrained=False.
model = MainModel()

In [13]:
#summary(model.backbone, (64, 70*8, 40*8), device='cpu')

In [14]:
# Wrap the model in the project's Lightning module. The optimizer/scheduler
# printed during fit (Adam, lr 0.000125, weight_decay 1e-4) come from hparams;
# presumably the loss settings do too — see lib.trainers.
module = WSIModuleV1(model, hparams, log_train_every_batch=False)

In [15]:
# Single-GPU (device 0) run for hparams['epochs'] epochs, with no sanity-check
# validation steps before training.
# NOTE(review): `gpus=` is the legacy Lightning argument; newer releases use
# `accelerator=`/`devices=`.
trainer = Trainer(max_epochs=hparams['epochs'], gpus=[0,], fast_dev_run=False, num_sanity_val_steps=0)

INFO:lightning:GPU available: True, used: True
INFO:lightning:CUDA_VISIBLE_DEVICES: [0]


In [None]:
# Train, then also save the final state as last.ckpt in the checkpoint
# callback's directory.
trainer.fit(module, train_loader, val_loader)
trainer.save_checkpoint(os.path.join(trainer.checkpoint_callback.dirpath,
                                     "last.ckpt"))

INFO:lightning:
   | Name                                 | Type              | Params
-----------------------------------------------------------------------
0  | model                                | MainModel         | 11 M  
1  | model.tf_map                         | TiledFeaturesMap  | 0     
2  | model.backbone                       | ResNet            | 11 M  
3  | model.backbone.conv1                 | Sequential        | 8 K   
4  | model.backbone.conv1.0               | BatchNorm2d       | 128   
5  | model.backbone.conv1.1               | Conv2d            | 4 K   
6  | model.backbone.conv1.2               | BatchNorm2d       | 128   
7  | model.backbone.conv1.3               | ReLU              | 0     
8  | model.backbone.conv1.4               | Conv2d            | 4 K   
9  | model.backbone.bn1                   | BatchNorm2d       | 128   
10 | model.backbone.relu                  | ReLU              | 0     
11 | model.backbone.maxpool               | Identity        

[Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 0.000125
    lr: 0.000125
    weight_decay: 0.0001
)] [{'interval': 'epoch', 'scheduler': <lib.schedulers.DelayedScheduler object at 0x7f4b694f8650>}]




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [None]:
# NOTE(review): leftover scratch arithmetic with no effect on training —
# consider deleting this cell.
165*3