In [1]:
import sys

# Make the directory two levels above the notebook importable so that the
# `lib.*` imports in the following cells resolve (presumably the repository
# root -- confirm). The path is relative to the kernel's working directory.
# Guarded so that re-running this cell does not append duplicate entries.
_LIB_ROOT = '../..'
if _LIB_ROOT not in sys.path:
    sys.path.append(_LIB_ROOT)

In [2]:
import os
import psutil

import random
import math
from functools import partial

import torch 
from torch import optim
from torch.optim import lr_scheduler
from torch import nn
from torch.nn import functional as F

import multiprocessing.dummy as mp

from pytorch_lightning import Trainer
from pytorch_lightning.core import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger


from lib.schedulers import DelayedScheduler
from lib.datasets import (max_lbl_nums, actual_lbl_nums, 
                          patches_rgb_mean_av1, patches_rgb_std_av1, 
                          get_train_test_img_ids_split)
from lib.dataloaders import PatchesDataset, WSIPatchesDatasetRaw, WSIPatchesDummyDataloader
from lib.augmentations import augment_v1_clr_only, augment_empty_clr_only
from lib.losses import SmoothLoss
from lib.trainers import WSIModuleV1

from lib.models.unetv1 import get_model
from lib.models.features_map import FeaturesMap, TiledFeaturesMap, RearrangedFeaturesMap
from lib.models.wsi_resnets import ResnetTiled_64x8x8, Resnet_64x8x8

from sklearn.metrics import cohen_kappa_score

from tqdm.auto import tqdm

import matplotlib.pyplot as plt

In [3]:
# import cv2
import numpy as np
# import pandas as pd
# from lib.datasets import patches_csv_path, patches_path
from lib.datasets import (patches_clean90_csv_path as patches_csv_path, patches_path,
                          patches_clean90_pkl_path as patches_pkl_path)
# from lib.dataloaders import imread, get_g_score_num, get_provider_num

In [4]:
# Load the project's train/test split of image ids.
# NOTE(review): in the cells visible here these ids are only inspected, not
# passed to the loaders -- confirm they are still needed.
train_img_ids, test_img_ids = get_train_test_img_ids_split()

# Peek at the first few test ids as a quick sanity check.
test_img_ids[:4]

['e8baa3bb9dcfb9cef5ca599d62bb8046',
 '9b2948ff81b64677a1a152a1532c1a50',
 '5b003d43ec0ce5979062442486f84cf7',
 '375b2c9501320b35ceb638a3274812aa']

In [5]:
from lib.dataloaders import WSIPatchesDataloader, WSIPatchesDatasetRaw
from lib.utils import get_pretrained_model, get_features

In [6]:
# Batch size handed to both data loaders and recorded in `hparams`.
batch_size = 64

In [7]:
# Directories holding the pre-computed feature batches. The train path keeps a
# '{}' placeholder (presumably filled in by the loader, e.g. per pre-computed
# epoch -- confirm). The HDD variants below are kept for reference.
#train_batch_path = '/mnt/HDDData/pdata/processed/pretrained_64x8x8/train/{}/'
#test_batch_path = '/mnt/HDDData/pdata/processed/pretrained_64x8x8/val/'

train_batch_path = '/mnt/SSDData/pdata/processed/pretrained_64x8x8/train/{}/'
test_batch_path = '/mnt/SSDData/pdata/processed/pretrained_64x8x8/val/'

# Both loaders share every setting except the source directory and shuffling.
make_loader = partial(WSIPatchesDummyDataloader,
                      precalc_epochs=11,
                      batch_size=batch_size)

train_loader = make_loader(train_batch_path, shuffle=True)
val_loader = make_loader(test_batch_path, shuffle=False)

In [8]:
# Placeholder: the real number of steps per epoch is only known once the train
# loader exists; it is recomputed from len(train_loader) in a later cell.
steps_in_epoh = 1

epochs = 90

# No learning-rate warm-up in this run.
warmup_epochs = 0
warmup_steps = 0

# Hyper-parameters handed to WSIModuleV1 (and used for Trainer(max_epochs=...)).
hparams = {
    'batch_size': batch_size,
    # Base LR of 1e-3 scaled by batch_size / 256 (appears to follow the linear
    # LR scaling rule). Tied to `batch_size` instead of a hard-coded 64 so the
    # two cannot drift apart; for batch_size == 64 the value is unchanged.
    'learning_rate': 0.001 * batch_size / 256,
    'dataset': {
        'dataloader': 'dummy',
        'rgb_mean': patches_rgb_mean_av1,
        'rgb_std': patches_rgb_std_av1,
        'classes': max_lbl_nums,
        # NOTE(review): the loaders in the earlier cell are built with
        # precalc_epochs=11 -- confirm which value is the intended one.
        'precalc_epochs': 50,
        'train_test_split': {},
    },
    'optimizer': {
        'name': 'torch.optim.Adam',
        'params': {
            'weight_decay': 1e-4
        }
    },
    'scheduler': {
        'name': 'lib.schedulers.ExponentialLRWithMin',
        'params': {
            'gamma': 0.92,
            'eta_min': 1.25e-5
        },
        'interval': 'epoch'
    },
    'loss': {
        # Weights of the regression and classification loss terms
        # (presumably combined inside WSIModuleV1 -- confirm).
        'weights': {
            'reg': 1 / 2,
            'class': 9 / 2
        },
        'label_smoothing': 0.1
    },
    'warmup_steps': warmup_steps,
    # NOTE(review): stores the placeholder value available when this dict is
    # built, not the real loader length.
    'steps_in_epoh': steps_in_epoh,
    'epochs': epochs
}

In [9]:
# Replace the placeholder with the real number of batches per training epoch.
steps_in_epoh = len(train_loader)

In [10]:
# Display the number of batches per training epoch.
steps_in_epoh

132

In [11]:
# Record the real number of steps per epoch now that the loader length is
# known. The original cell only wrote the new key 'steps_in_batch', leaving
# hparams['steps_in_epoh'] at its placeholder value; both keys are written so
# that whichever one the training module reads is up to date.
# NOTE(review): confirm which key WSIModuleV1 actually consumes and drop the other.
hparams['steps_in_batch'] = steps_in_epoh
hparams['steps_in_epoh'] = steps_in_epoh

# Schedulers configured with a 'T_max' parameter get the total number of
# post-warm-up optimisation steps; other schedulers are left untouched.
scheduler_params = hparams['scheduler']['params']
if 'T_max' in scheduler_params:
    scheduler_params['T_max'] = epochs * steps_in_epoh - warmup_steps

In [12]:
from lib.models.tresnet_models.tresnet.tresnet import TResNet, TResnetM

In [13]:
#backbone = TResnetM({'num_classes': 10, 'remove_aa_jit': False})

In [14]:
#device = 'cuda:1'

In [15]:
#backbone.to(device);

In [16]:
#sum([p.data.numel() for p in model.backbone.parameters()])

In [17]:
from torchvision import models

In [18]:
from inplace_abn import ABN, InPlaceABN

In [19]:
from lib.models.abn_models.models import net_resnet18

In [20]:
class RearrangedResnet_64x8x8(nn.Module):
    """Torchvision backbone applied to a rearranged map of patch features.

    The per-patch features and their (ys, xs) positions are first passed
    through ``RearrangedFeaturesMap`` to build a single 64-channel map, which
    is then fed to the backbone. The backbone is modified as follows:

    * ``conv1`` is replaced by a BatchNorm followed by three 1x1 convolutions
      (64 -> 64 channels) with BatchNorm/ReLU in between, so the network
      accepts a 64-channel input and its stem does not reduce resolution;
    * ``maxpool`` is replaced by ``nn.Identity``;
    * ``fc`` is replaced by ``nn.Linear(backbone_features, 512)``.

    Two linear heads sit on top of the resulting 512-d embedding: a
    single-output regression head and a ``classes``-way classification head.

    Args:
        backbone: Name of a model constructor in ``torchvision.models``
            (e.g. ``'resnet18'``); instantiated without pretrained weights.
        backbone_features: Input width of the backbone's final ``fc`` layer;
            must match the chosen backbone.
        classes: Number of outputs of the classification head.
        features_do: Channel-dropout probability applied to the feature map
            before the backbone. Values ``<= 0`` disable dropout.
        h, w: Forwarded to ``RearrangedFeaturesMap`` (presumably the size of
            the patch grid -- confirm).
    """

    def __init__(self, backbone, backbone_features, classes, features_do, h=20, w=20):
        super().__init__()

        self.rf_map = RearrangedFeaturesMap(False, 64, f_size=8, h=h, w=w)

        # `features_do` used to be accepted but silently ignored. The dropout
        # is kept outside `backbone.conv1` so that the indices (and therefore
        # the state_dict keys) of the stem layers stay unchanged; it owns no
        # parameters or buffers, so existing checkpoints still load.
        self.features_dropout = (nn.Dropout2d(features_do)
                                 if features_do > 0 else nn.Identity())

        # self.backbone = net_resnet18(classes=512)
        # self.backbone.mod1 = nn.Sequential()
        self.backbone = getattr(models, backbone)(pretrained=False)
        # 1x1-conv stem over the 64 input channels. It ends with a bare conv;
        # in torchvision ResNets the backbone's own bn1/relu follow conv1.
        self.backbone.conv1 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 1),
        )

        # The backbone now emits a 512-d embedding instead of class logits.
        self.backbone.fc = nn.Linear(backbone_features, 512)
        # Skip the stem max-pooling.
        self.backbone.maxpool = nn.Identity()

        self.reg_linear = nn.Linear(512, 1)
        self.class_linear = nn.Linear(512, classes)

    def forward(self, features, ys, xs):
        """Return ``(regression_output, class_logits)`` for a batch.

        ``features``, ``ys`` and ``xs`` are passed unchanged to
        ``RearrangedFeaturesMap``; their exact layout is defined there.
        """
        f_map = self.rf_map(features, ys, xs)
        f_map = self.features_dropout(f_map)
        x = self.backbone(f_map)
        return self.reg_linear(x), self.class_linear(x)

In [40]:
# Network configuration: a ResNet-18 backbone whose final layer takes 512
# features, one classification output per label, and no feature dropout.
model_kwargs = dict(
    backbone='resnet18',
    backbone_features=512,
    classes=max_lbl_nums,
    features_do=0,
)
model = RearrangedResnet_64x8x8(**model_kwargs)

In [22]:
# model.backbone

In [23]:
#from torchsummary import summary

In [24]:
# summary(model.backbone.cuda(), (64, 7*8, 7*8), -1, 'cuda')

In [25]:
#batch = next(iter(train_loader))

In [26]:
#features, ys, xs, provider, isup_grade, gleason_score = batch

In [27]:
#with torch.no_grad():
#    tmp = model.rf_map(features, ys, xs)

In [28]:
#tmp.shape

In [29]:
#device = torch.device('cuda:0')

In [30]:
#model.to(device);

In [31]:
#train_loader = WSIPatchesDummyDataloader(train_batch_path, precalc_epochs=6, 
#                                         batch_size=64, shuffle=True)

In [32]:
#batch = next(iter(train_loader))

In [33]:
#features, ys, xs, provider, isup_grade, gleason_score = batch

In [34]:
#features, ys, xs = features.to(device), ys.to(device), xs.to(device)

In [35]:
#tmp = model(features, ys, xs)

In [36]:
#xxx

In [37]:
#f_ns, f_tiles = model.tf_map(features, ys, xs)

In [38]:
#summary(model.backbone, (64, 70*8, 40*8), device='cpu')

In [41]:
# Wrap the network and hyper-parameters in the project's Lightning module;
# per-batch training logging is switched off.
module = WSIModuleV1(model, hparams, log_train_every_batch=False)

In [42]:
# Train for the configured number of epochs on GPU 0, with a full run
# (fast_dev_run off) and without the pre-training validation sanity check.
trainer = Trainer(
    max_epochs=hparams['epochs'],
    gpus=[0],
    fast_dev_run=False,
    num_sanity_val_steps=0,
)

INFO:lightning:GPU available: True, used: True
INFO:lightning:CUDA_VISIBLE_DEVICES: [0]


In [43]:
# Run training, then write a final checkpoint into the directory used by the
# trainer's checkpoint callback.
trainer.fit(module, train_loader, val_loader)

last_ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, "last.ckpt")
trainer.save_checkpoint(last_ckpt_path)

INFO:lightning:
   | Name                                 | Type                    | Params
-----------------------------------------------------------------------------
0  | model                                | RearrangedResnet_64x8x8 | 11 M  
1  | model.rf_map                         | RearrangedFeaturesMap   | 0     
2  | model.backbone                       | ResNet                  | 11 M  
3  | model.backbone.conv1                 | Sequential              | 12 K  
4  | model.backbone.conv1.0               | BatchNorm2d             | 128   
5  | model.backbone.conv1.1               | Conv2d                  | 4 K   
6  | model.backbone.conv1.2               | BatchNorm2d             | 128   
7  | model.backbone.conv1.3               | ReLU                    | 0     
8  | model.backbone.conv1.4               | Conv2d                  | 4 K   
9  | model.backbone.conv1.5               | BatchNorm2d             | 128   
10 | model.backbone.conv1.6               | ReLU           

[Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 0.00025
    lr: 0.00025
    weight_decay: 0.0001
)] [{'interval': 'epoch', 'scheduler': <lib.schedulers.DelayedScheduler object at 0x7fca9bbe1310>}]




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

INFO:lightning:Detected KeyboardInterrupt, attempting graceful shutdown...





In [None]:
# NOTE(review): leftover scratch arithmetic with no visible link to the rest of
# the notebook -- consider removing this cell.
165*3