In [1]:
import sys
import os
import psutil

import random
import math
from functools import partial

import torch 
from torch import optim
from torch.optim import lr_scheduler
from torch import nn
from torch.nn import functional as F

import multiprocessing.dummy as mp

from pytorch_lightning import Trainer
from pytorch_lightning.core import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

sys.path.append('../..')
from lib.schedulers import DelayedScheduler
from lib.datasets import (max_lbl_nums, actual_lbl_nums, 
                          patches_rgb_mean_av1, patches_rgb_std_av1, 
                          get_train_test_img_ids_split)
from lib.dataloaders import PatchesDataset, WSIPatchesDatasetRaw, WSIPatchesDummyDataloader
from lib.augmentations import augment_v1_clr_only, augment_empty_clr_only
from lib.losses import SmoothLoss
from lib.trainers import GeneralModule

from lib.models.unetv1 import get_model
from lib.models.features_map import FeaturesMap

from sklearn.metrics import cohen_kappa_score

from tqdm.auto import tqdm

import matplotlib.pyplot as plt

In [2]:
# import cv2
import numpy as np
# import pandas as pd
# from lib.datasets import patches_csv_path, patches_path
from lib.datasets import (patches_clean90_csv_path as patches_csv_path, patches_path,
                          patches_clean90_pkl_path as patches_pkl_path)
# from lib.dataloaders import imread, get_g_score_num, get_provider_num

In [3]:
# Obtain the train/test partition of image ids from the project helper.
train_img_ids, test_img_ids = get_train_test_img_ids_split()

# Preview the first four test ids as the cell output.
test_img_ids[:4]

['e8baa3bb9dcfb9cef5ca599d62bb8046',
 '9b2948ff81b64677a1a152a1532c1a50',
 '5b003d43ec0ce5979062442486f84cf7',
 '375b2c9501320b35ceb638a3274812aa']

In [4]:
from lib.dataloaders import WSIPatchesDataloader, WSIPatchesDatasetRaw
from lib.utils import get_pretrained_model, get_features

In [5]:
# patches_device = torch.device('cuda:0')

#rgb_mean, rgb_std = (torch.tensor(patches_rgb_mean_av1, dtype=torch.float32, device=patches_device), 
#                     torch.tensor(patches_rgb_std_av1, dtype=torch.float32, device=patches_device))

#model = get_pretrained_model(get_model, {'classes': actual_lbl_nums}, 
#                             "../Patches256TestRun/version_0/checkpoints/last.ckpt", patches_device)

#get_features_fn = partial(get_features, model=model, device=patches_device, 
#                          rgb_mean=rgb_mean, rgb_std=rgb_std, 
#                          features_batch_size=512)

In [6]:
# Batch size referenced by the disabled on-the-fly WSIPatchesDataloader setup in this notebook.
main_batch_size = 64

In [7]:
# Locations of the precalculated batches.
# NOTE(review): these are absolute, machine-specific paths; consider a configurable data root.
# The train path keeps a '{}' placeholder; presumably the loader fills it in
# (e.g. with an epoch index) — TODO confirm against WSIPatchesDummyDataloader.
train_batch_path = '/mnt/SSDData/pdata/processed/pretrained/train/{}/'
test_batch_path = '/mnt/SSDData/pdata/processed/pretrained/val/'

# Training batches are shuffled, validation batches are not.
# NOTE(review): precalc_epochs=50 while the notebook trains for 90 epochs —
# confirm the loader cycles through the precalculated epochs.
train_loader = WSIPatchesDummyDataloader(train_batch_path, precalc_epochs=50, shuffle=True)
val_loader = WSIPatchesDummyDataloader(test_batch_path, precalc_epochs=50, shuffle=False)

In [8]:
'''
train_loader = WSIPatchesDataloader(WSIPatchesDatasetRaw(train_img_ids, patches_pkl_path, 
                                                         scale=0.5, transform=augment_v1_clr_only),
                                    get_features_fn, (512, 8, 8),
                                    main_batch_size, shuffle=True, num_workers=5, max_len=300)

val_loader = WSIPatchesDataloader(WSIPatchesDatasetRaw(test_img_ids, patches_pkl_path, 
                                                       scale=0.5, transform=augment_empty_clr_only),
                                    get_features_fn, (512, 8, 8),
                                    main_batch_size, shuffle=True, num_workers=5, max_len=300)
''';

In [9]:
class WSIModule1DV1(GeneralModule):
    """Training/validation logic for a model that predicts the ISUP grade twice.

    The wrapped ``model`` is called as ``model(features, ys, xs)`` and must
    return a pair ``(regression_output, class_logits)``. Both heads are trained
    against ``isup_grade``: the regression head with MSE and the classification
    head with a label-smoothed KL-divergence loss. The two losses are combined
    with the weights found in ``hparams['loss']['weights']``.
    """

    def __init__(self, model, hparams, log_train_every_batch=False):
        """Store the model and hyper-parameters and build the loss functions.

        Args:
            model: network returning ``(regression_output, class_logits)``.
            hparams: nested dict; this class reads ``loss.label_smoothing``,
                ``loss.weights`` (keys ``reg`` and ``class``),
                ``dataset.rgb_mean``, ``dataset.rgb_std`` and ``dataset.classes``.
            log_train_every_batch: when true, training steps return the
                ``{'loss': ..., 'log': ...}`` form instead of a flat dict.
        """
        super().__init__(model, hparams, log_train_every_batch)
        self.hparams = hparams
        self.model = model
        self.log_train_every_batch = log_train_every_batch

        self.reg_loss = nn.MSELoss()

        label_smoothing = self.hparams['loss']['label_smoothing']
        # NOTE(review): nn.KLDivLoss() uses reduction='mean' by default, which
        # averages over every element rather than per sample. Left unchanged
        # because the loss weights were tuned against this scale — confirm
        # whether 'batchmean' was intended.
        self.class_loss = SmoothLoss(nn.KLDivLoss(), smoothing=label_smoothing,
                                     one_hot_target=True)

        self.loss_weights = self.hparams['loss']['weights']

        # Plain tensor attributes (not registered buffers); see _apply below.
        self.rgb_mean = torch.tensor(hparams['dataset']['rgb_mean'],
                                     dtype=torch.float32)
        self.rgb_std = torch.tensor(hparams['dataset']['rgb_std'],
                                    dtype=torch.float32)

        # Handle on the current process, used to log resident memory per step.
        self.process = psutil.Process(os.getpid())

        self.max_lbl_nums = hparams['dataset']['classes']

    @classmethod
    def _accuracy(cls, output, target):
        """Return the fraction of entries in ``output`` equal to ``target``."""
        pred = output
        eq = pred.eq(target.view_as(pred))
        return eq.float().mean()

    def step(self, batch, batch_idx, is_train):
        """Run one forward pass and compute losses and metrics.

        Args:
            batch: tuple ``(features, ys, xs, provider, isup_grade,
                gleason_score)``; only ``features``, ``ys``, ``xs`` and
                ``isup_grade`` are used here.
            batch_idx: index of the batch (unused).
            is_train: selects the metric-name prefix (``''`` for training,
                ``'val_'`` otherwise) and the return format.

        Returns:
            ``{'loss': loss, 'log': metrics}`` when training with
            ``log_train_every_batch`` enabled, otherwise the flat metrics dict
            (which itself contains the ``loss`` entry under its prefixed name).
        """
        features, ys, xs, provider, isup_grade, gleason_score = batch

        labels = isup_grade

        labels_reg = labels[:, None].float()
        labels_class = labels

        o_labels_reg, o_labels_class = self.model(features, ys, xs)

        # Sigmoid maps to (0, 1); scaling and shifting gives the open interval
        # (-0.5, max_lbl_nums - 0.5), so rounding lands on 0..max_lbl_nums-1.
        o_labels_reg = torch.sigmoid(o_labels_reg) * self.max_lbl_nums - 0.5
        o_labels_class = F.log_softmax(o_labels_class, dim=-1)

        reg_loss = self.reg_loss(o_labels_reg, labels_reg)
        class_loss = self.class_loss(o_labels_class, labels_class)

        loss = (self.loss_weights['reg'] * reg_loss +
                self.loss_weights['class'] * class_loss)

        # Turn both heads into hard label predictions for the metrics.
        o_labels_reg = o_labels_reg.round().long().clamp(0, self.max_lbl_nums-1)
        o_labels_class = o_labels_class.argmax(dim=-1)

        acc_reg = self._accuracy(o_labels_reg, labels)
        acc_class = self._accuracy(o_labels_class, labels)

        # Pass the full label set explicitly: without it sklearn builds the
        # quadratic weight matrix from the rank of the labels that happen to
        # be present in this batch, which distorts the score whenever a batch
        # misses some grades.
        kappa_labels = list(range(self.max_lbl_nums))
        labels_np = labels.cpu().numpy()
        qwk_reg = cohen_kappa_score(o_labels_reg.cpu().numpy(), labels_np,
                                    labels=kappa_labels, weights="quadratic")
        qwk_class = cohen_kappa_score(o_labels_class.cpu().numpy(), labels_np,
                                      labels=kappa_labels, weights="quadratic")

        lr = self.optimizer.param_groups[0]['lr']

        pr = '' if is_train else 'val_'

        # 'loss' stays a tensor so it can be back-propagated; the remaining
        # entries are converted to plain Python numbers.
        log_dict = {
            pr+'loss': loss,
            pr+'reg_loss': float(reg_loss),
            pr+'class_loss': float(class_loss),
            pr+'acc_reg': float(acc_reg),
            pr+'acc_class': float(acc_class),
            pr+'qwk_reg': float(qwk_reg),
            pr+'qwk_class': float(qwk_class),
            pr+'lr': lr,
            pr+'memory': self.process.memory_info().rss
        }

        if is_train and self.log_train_every_batch:
            return {'loss': loss, 'log': log_dict}
        else:
            return log_dict

    def training_step(self, batch, batch_idx):
        """Training hook: delegate to ``step`` with ``is_train=True``."""
        return self.step(batch, batch_idx, True)

    def validation_step(self, batch, batch_idx):
        """Validation hook: delegate to ``step`` with ``is_train=False``."""
        return self.step(batch, batch_idx, False)

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` to the module and to the normalisation tensors.

        ``rgb_mean`` and ``rgb_std`` are plain attributes rather than
        registered buffers, so they are converted here explicitly. Extra
        arguments are forwarded unchanged so the override stays compatible
        with PyTorch versions whose ``Module._apply`` accepts ``recurse``.
        """
        super()._apply(fn, *args, **kwargs)
        self.rgb_mean = fn(self.rgb_mean)
        self.rgb_std = fn(self.rgb_std)

        return self

In [10]:
# Placeholder only: the real number of batches per epoch is not known yet.
steps_in_epoh = 1

epochs = 90

# Warm-up is switched off for this run (commented-out alternative: 132 steps).
warmup_epochs = 0
warmup_steps = 0
batch_size = 64

hparams = {
    'batch_size': batch_size,
    # Base rate 1e-3, scaled linearly with the batch size relative to 256.
    # Commented-out alternatives: base 0.1, 0.1 / 2.5, 0.01, 0.01 * 4,
    # or a fixed 1.25e-5.
    'learning_rate': 0.001 * batch_size / 256,
    'dataset': {
        'rgb_mean': patches_rgb_mean_av1,
        'rgb_std': patches_rgb_std_av1,
        'classes': max_lbl_nums,
    },
    'optimizer': {
        'name': 'torch.optim.Adam',
        # Commented-out alternatives: 'momentum': 0.9; 'weight_decay': 2e-3 or 0.
        'params': {'weight_decay': 1e-4},
    },
    'scheduler': {
        # Commented-out alternatives: torch.optim.lr_scheduler.ExponentialLR,
        # torch.optim.lr_scheduler.CosineAnnealingLR (with T_max).
        'name': 'lib.schedulers.ExponentialLRWithMin',
        'params': {
            'gamma': 0.92,       # alternative: 1.0
            'eta_min': 1.25e-5,  # alternative: 1e-3
        },
        'interval': 'epoch',     # alternative: 'step'
    },
    'loss': {
        'weights': {'reg': 0.5, 'class': 4.5},
        'label_smoothing': 0.1,
    },
    'warmup_steps': warmup_steps,
    'steps_in_epoh': steps_in_epoh,
    'epochs': epochs,
    # Disabled: 'source_code': open(__file__, 'rt').read()
}

In [11]:
# Replace the placeholder with the length reported by the training loader
# (presumably the number of batches per epoch — TODO confirm).
steps_in_epoh = len(train_loader)

In [12]:
# Show the measured loader length as the cell output.
steps_in_epoh

132

In [13]:
# hparams was built while steps_in_epoh still held its placeholder value, so
# refresh the stored entry with the length measured from the training loader.
hparams['steps_in_epoh'] = steps_in_epoh
# Kept for backward compatibility: the value used to be stored only under this key.
hparams['steps_in_batch'] = steps_in_epoh
#hparams['scheduler']['params']['T_max'] = (epochs * steps_in_epoh -
#                                           warmup_steps)

In [14]:
import torchvision.models as models

class LambdaLayer(nn.Module):
    """Module adapter that applies an arbitrary callable in ``forward``."""

    def __init__(self, lambd):
        super().__init__()
        # Callable invoked on every forward pass.
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)

# Presumably dropout rates; not referenced by the code visible in this cell.
f_d_rate = d_rate = 0.0

# Spatial extent passed to FeaturesMap when the model is built.
max_height, max_width = 70, 40

class MainModel(nn.Module):
    """Slide-level network: feature map -> adapted ResNet-18 -> two linear heads.

    ``forward`` returns ``(regression_output, class_logits)``, with one
    regression value and ``max_lbl_nums`` class logits per sample.
    """

    def __init__(self):
        super().__init__()

        # Arranges the per-patch features on a max_height x max_width grid
        # (semantics of FeaturesMap's arguments are defined in lib.models).
        self.f_map = FeaturesMap(True, 512, max_height, max_width)

        resnet = models.resnet18(pretrained=False)

        # Swap the stock stem convolution for a stack of 1x1 convolutions that
        # brings the 512 input channels down to the 64 the backbone expects.
        stem_layers = [
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, 1),
        ]
        resnet.conv1 = nn.Sequential(*stem_layers)

        # The final layer now emits a 512-dim embedding, and the max-pool
        # after the stem is turned into a no-op.
        resnet.fc = nn.Linear(512, 512)
        resnet.maxpool = nn.Identity()
        self.backbone = resnet

        # Output heads sharing the backbone embedding.
        self.reg_linear = nn.Linear(512, 1)
        self.class_linear = nn.Linear(512, max_lbl_nums)

    def forward(self, features, ys, xs):
        grid = self.f_map(features, ys, xs)
        embedding = self.backbone(grid)
        reg_out = self.reg_linear(embedding)
        class_out = self.class_linear(embedding)
        return reg_out, class_out

# Build the network that the training module will wrap.
model = MainModel()

In [15]:
# model.backbone

In [16]:
#from torchsummary import summary

In [17]:
#summary(model.backbone, (512, 70, 40), device='cpu')

In [18]:
# Wrap the network and hyper-parameters in the training module.
module = WSIModule1DV1(model, hparams, log_train_every_batch=False)

In [19]:
# Train for hparams['epochs'] epochs on GPU index 1; the pre-training sanity validation is disabled.
trainer = Trainer(max_epochs=hparams['epochs'], gpus=[1,], fast_dev_run=False, num_sanity_val_steps=0)

INFO:lightning:GPU available: True, used: True
INFO:lightning:CUDA_VISIBLE_DEVICES: [1]


In [None]:
# Run training with the prepared loaders, then write a final checkpoint named
# "last.ckpt" into the checkpoint callback's directory.
trainer.fit(module, train_loader, val_loader)
trainer.save_checkpoint(os.path.join(trainer.checkpoint_callback.dirpath,
                                     "last.ckpt"))

[Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 0.00025
    lr: 0.00025
    weight_decay: 0.0001
)] [{'interval': 'epoch', 'scheduler': <lib.schedulers.DelayedScheduler object at 0x7fd4856822d0>}]


INFO:lightning:
   | Name                                 | Type              | Params
-----------------------------------------------------------------------
0  | model                                | MainModel         | 13 M  
1  | model.f_map                          | FeaturesMap       | 1 M   
2  | model.backbone                       | ResNet            | 11 M  
3  | model.backbone.conv1                 | Sequential        | 174 K 
4  | model.backbone.conv1.0               | BatchNorm2d       | 1 K   
5  | model.backbone.conv1.1               | Conv2d            | 131 K 
6  | model.backbone.conv1.2               | BatchNorm2d       | 512   
7  | model.backbone.conv1.3               | ReLU              | 0     
8  | model.backbone.conv1.4               | Conv2d            | 32 K  
9  | model.backbone.conv1.5               | BatchNorm2d       | 256   
10 | model.backbone.conv1.6               | ReLU              | 0     
11 | model.backbone.conv1.7               | Conv2d          

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…