In [1]:
import sys
import os

import random
import math

import torch 
from torch import optim
from torch.optim import lr_scheduler
from torch import nn
from torch.nn import functional as F

from pytorch_lightning import Trainer
from pytorch_lightning.core import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

sys.path.append('../..')
from lib.schedulers import DelayedScheduler
from lib.datasets import (max_lbl_nums, actual_lbl_nums, 
                          patches_rgb_mean_av1, patches_rgb_std_av1, 
                          get_train_test_img_ids_split)
from lib.dataloaders import PatchesDataset, WSIPatchesDatasetRaw
from lib.augmentations import augment_v1_clr_only, augment_empty_clr_only
from lib.losses import SmoothLoss

from lib.models.unetv1 import get_model

from sklearn.metrics import cohen_kappa_score

from tqdm.auto import tqdm

In [2]:
# import cv2
import numpy as np
# import pandas as pd
# from lib.datasets import patches_csv_path, patches_path
from lib.datasets import (patches_clean90_csv_path as patches_csv_path, patches_path,
                          patches_clean90_pkl_path as patches_pkl_path)
# from lib.dataloaders import imread, get_g_score_num, get_provider_num

In [3]:
# Split the image ids into train / test lists (split logic lives in lib.datasets).
train_img_ids, test_img_ids = get_train_test_img_ids_split()

In [4]:
# Sanity check: number of training image ids.
len(train_img_ids)

8420

In [5]:
# Peek at the first few test image ids.
test_img_ids[:4]

['e8baa3bb9dcfb9cef5ca599d62bb8046',
 '9b2948ff81b64677a1a152a1532c1a50',
 '5b003d43ec0ce5979062442486f84cf7',
 '375b2c9501320b35ceb638a3274812aa']

In [6]:
class WSIPatchesDataset1D(WSIPatchesDatasetRaw):
    """``WSIPatchesDatasetRaw`` whose items always hold exactly ``max_len`` patches.

    The ``imgs``, ``ys`` and ``xs`` arrays returned by the parent are trimmed or
    padded along their first axis to ``max_len`` rows, so items from different
    slides have equal length. Padded rows are filled with ``-1``.

    Args:
        image_ids: image ids forwarded to ``WSIPatchesDatasetRaw``.
        csv_path, path, scale, transform: forwarded positionally, in this order,
            to ``WSIPatchesDatasetRaw``.
        max_len: number of rows every returned ``imgs``/``ys``/``xs`` array has.
    """

    def __init__(self, image_ids, csv_path=patches_csv_path,
                 path=patches_path, scale=1, transform=None, max_len=300):
        # Forward the caller's ``path``; the module-level ``patches_path`` was
        # previously passed here, so a non-default ``path`` was silently ignored.
        super().__init__(image_ids, csv_path, path, scale, transform)
        self.max_len = max_len

    def _fit_length(self, x):
        """Return ``x`` trimmed, or ``-1``-padded, along axis 0 to ``self.max_len`` rows."""
        n_rows = x.shape[0]
        if n_rows >= self.max_len:
            return x[:self.max_len]
        # Pad only the first axis; every other axis keeps its size.
        pad_width = ((0, self.max_len - n_rows),) + ((0, 0),) * (x.ndim - 1)
        # NOTE(review): if x has an unsigned integer dtype, -1 is not representable
        # in it — confirm the dtype the parent dataset returns.
        return np.pad(x, pad_width, constant_values=-1)

    def __getitem__(self, idx):
        """Return the parent's item with ``imgs``, ``ys`` and ``xs`` fitted to ``max_len``.

        Returns:
            ``(imgs, ys, xs, provider, isup_grade, gleason_score)``; only the
            first three entries are modified.
        """
        imgs, ys, xs, provider, isup_grade, gleason_score = super().__getitem__(idx)
        # Each array is sized from its own length, not from imgs.shape[0].
        imgs, ys, xs = (self._fit_length(arr) for arr in (imgs, ys, xs))
        return imgs, ys, xs, provider, isup_grade, gleason_score
    
    
class LambdaLayer(nn.Module):
    """Expose an arbitrary callable as an ``nn.Module``.

    This lets plain functions be placed inside module containers such as
    ``nn.Sequential``.

    Args:
        lambd: callable applied to the input in ``forward``.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return the result of applying the wrapped callable to ``x``."""
        fn = self.lambd
        return fn(x)

In [7]:
# Maximum number of patches per slide. NOTE(review): not passed anywhere in the
# active cells here — presumably meant for WSIPatchesDataset1D(max_len=...).
max_len = 300

In [8]:
# RGB normalisation statistics from lib.datasets, converted to float32 tensors.
rgb_mean = torch.tensor(patches_rgb_mean_av1, dtype=torch.float32)
rgb_std = torch.tensor(patches_rgb_std_av1, dtype=torch.float32)

In [9]:
# imgs, ys, xs, provider, isup_grade, gleason_score = next(iter(train_wsipatches_dataset))

In [10]:
# Device that runs the patch feature extractor; use the CPU line to run without a GPU.
patches_device = torch.device('cuda', 0)
# patches_device = torch.device('cpu')
# Second GPU — presumably for the main (slide-level) model.
main_device = torch.device('cuda', 1)

In [11]:
# Load the checkpoint of an earlier patch-level run (presumably written by Lightning's
# ModelCheckpoint), mapping its tensors onto patches_device. Only tmp['state_dict'] is
# used below. torch.load unpickles the file, so only load trusted checkpoints.
tmp = torch.load("../Patches256TestRun/version_0/checkpoints/last.ckpt", map_location=patches_device)

In [12]:
# Build the patch encoder and load its trained weights from the checkpoint.
# model = get_model(actual_lbl_nums, decoder=False, labels=False)
model = get_model(actual_lbl_nums)

# Wrap the network in a Sequential under the name 'model' so its parameter names
# gain a 'model.' prefix. Presumably this matches the key layout of tmp['state_dict']
# (the Lightning module likely stored the network as self.model) — TODO confirm.
module = nn.Sequential()

module.add_module('model', model)

module.to(patches_device);

module.load_state_dict(tmp['state_dict'])

# Only features are needed from here on. After the weights are loaded, set
# model.segmentation = False and drop the classification head and autodecoder.
# This must come after load_state_dict: if the checkpoint holds weights for these
# parts, strict loading would reject them once the attributes are None.
model.segmentation = False
model.classification_head = None
model.autodecoder = None
# eval() switches every submodule to inference behaviour (e.g. dropout, batch-norm).
module.eval();

In [13]:
# Keep the normalisation statistics on the same device as the patch encoder.
rgb_mean, rgb_std = (stat.to(patches_device) for stat in (rgb_mean, rgb_std))

In [14]:
from time import time

In [15]:
# Number of patches per forward pass in get_features().
features_batch_size = 512

In [16]:
def get_features(imgs, *, net=None, mean=None, std=None, device=None, batch_size=None):
    """Run the patch encoder over ``imgs`` in mini-batches and stack the features.

    Every keyword defaults to the notebook global used before, so
    ``get_features(imgs)`` behaves as it always did.

    Args:
        imgs: tensor of patches, sliced into batches along dim 0. It must be
            broadcast-compatible with ``mean``/``std`` — TODO confirm layout.
        net: encoder called as ``net(batch, return_features=True)``; the first
            element of its result is taken as the features. Defaults to ``model``.
        mean, std: normalisation tensors. They default to ``rgb_mean`` / ``rgb_std``.
        device: compute device. Defaults to ``patches_device``.
        batch_size: patches per forward pass. Defaults to ``features_batch_size``.

    Returns:
        The per-batch features concatenated along dim 0, on ``device``.

    Raises:
        ValueError: if ``imgs`` has no patches (``torch.cat`` cannot join an
            empty list).
    """
    # Conditional expressions only touch a global when no override is given.
    net = model if net is None else net
    mean = rgb_mean if mean is None else mean
    std = rgb_std if std is None else std
    device = patches_device if device is None else device
    batch_size = features_batch_size if batch_size is None else batch_size

    if imgs.shape[0] == 0:
        raise ValueError("get_features() received an empty set of patches")

    imgs = imgs.to(device)
    mean = mean.to(device)
    std = std.to(device)

    b_features = []
    with torch.no_grad():
        for start in range(0, imgs.shape[0], batch_size):
            # Normalise per batch so only one batch's normalised copy is alive at
            # a time, instead of a full second copy of every patch.
            batch = (imgs[start:start + batch_size] - mean) / std
            features, *_ = net(batch, return_features=True)
            b_features.append(features)

    return torch.cat(b_features, dim=0)

In [17]:
# model.training

In [18]:
#train_loader = MainBatchGenerator1D(train_img_ids, batch_size=32, patches_batch_size=2, shuffle=True, 
#                                    num_workers=6, max_len=300, patches_csv_path=patches_csv_path,
#                                    scale=0.5, transform=augment_empty_clr_only)
#                                    # scale=0.5, transform=augment_v1_clr_only)

In [19]:
# Training-set dataset at scale=0.5 with the augment_v1_clr_only transform. Its items are
# unpacked below as `imgs, *_ = item`; WSIPatchesDataset1D shows the parent's item layout
# is (imgs, ys, xs, provider, isup_grade, gleason_score).
# NOTE(review): the second positional argument here is patches_pkl_path, while
# WSIPatchesDataset1D passes a CSV path in the same position — confirm WSIPatchesDatasetRaw
# accepts both.
# NOTE(review): this transform is applied during feature extraction; if it is random,
# the extracted features will differ between runs — confirm that is intended.
train_wsipatches_dataset = WSIPatchesDatasetRaw(train_img_ids, patches_pkl_path, 
                                                scale=0.5, transform=augment_v1_clr_only)

In [20]:
#for features, target in tqdm(train_loader, total=len(train_loader)):
#    pass

In [21]:
# from multiprocessing import Pool
import multiprocessing.dummy as mp
# import torch.multiprocessing as mp

In [22]:
#mp.set_start_method('spawn')

In [None]:
# Throughput benchmark. Recorded timing:
# multiprocessing.dummy, processes=6, imgs = torch.from_numpy(imgs).to(patches_device) ~ 19 minutes
# NOTE(review): that timing describes a from_numpy(...).to(patches_device) variant; this
# cell as written does neither (see process_item below), so the figure may not apply to it.

import multiprocessing.dummy as mp

idxs = list(range(len(train_wsipatches_dataset)))

def process_item(idx):
    """Load training item ``idx``, discard it, and return a small placeholder.

    The real result is commented out (``#item_data``), so the pool loop below
    mainly measures dataset loading across the worker threads, not data transfer.
    """
    item_data = train_wsipatches_dataset[idx]
    return np.array([1.0, 2.0], dtype=np.float32), 500 #item_data

# multiprocessing.dummy.Pool runs the workers as threads in this process.
with mp.Pool(processes=6) as pool:
    for item_data in tqdm(pool.imap_unordered(process_item, idxs), total=len(idxs)):
        imgs, *_ = item_data
            
        imgs = torch.tensor(imgs)
        #imgs = torch.from_numpy(imgs) #.half()
        #imgs = torch.from_numpy(imgs).to(patches_device)
        #featurtes = get_features(imgs).cpu()

In [168]:
# Synthetic input for the copy benchmark: 200 random float32 arrays of shape (3, 256, 256).
a = np.random.random((200, 3, 256, 256)).astype(np.float32)

In [169]:
# CPU tensor copy of `a` (torch.tensor copies the data); the source of the timed copies below.
at = torch.tensor(a)

In [195]:
%%time
for _ in range(10):
    # b = torch.from_numpy(a).to(patches_device)
    # b = torch.tensor(a, device=patches_device)
    b = torch.zeros_like(b)
    b[...] = at[...]
    b.sum().item()

CPU times: user 149 ms, sys: 61 µs, total: 149 ms
Wall time: 148 ms


In [None]:
# Indices of every item in the training dataset, fed to the worker pool below.
idxs = list(range(len(train_wsipatches_dataset)))

In [24]:
def process_item(idx):
    """Load training item ``idx`` and return ``(idx, imgs)`` with ``imgs`` as a tensor.

    The index is returned with the data because ``imap_unordered`` yields results
    in completion order, not in ``idxs`` order.
    """
    imgs, *_ = train_wsipatches_dataset[idx]
    return idx, torch.tensor(imgs)

# Extract patch features for every training item and keep them as
# (dataset index, CPU feature tensor) pairs. Previously new_rows was never filled,
# and each result was overwritten in a misspelled variable and lost.
# `mp` is multiprocessing.dummy (a thread pool) as imported earlier in the notebook.
# NOTE(review): this keeps every item's features in host memory — check that
# n_items x n_patches x feature_dim fits before a full run.
new_rows = []
with mp.Pool(processes=6) as pool:
    for idx, imgs in tqdm(pool.imap_unordered(process_item, idxs), total=len(idxs)):
        # get_features() moves imgs to patches_device itself.
        features = get_features(imgs).cpu()
        new_rows.append((idx, features))

HBox(children=(FloatProgress(value=0.0, max=8420.0), HTML(value='')))

Process ForkPoolWorker-5:
Process ForkPoolWorker-4:
Process ForkPoolWorker-2:
Traceback (most recent call last):
Process ForkPoolWorker-3:
  File "/home/ruslan/anaconda3/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Process ForkPoolWorker-6:
  File "/home/ruslan/anaconda3/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/ruslan/anaconda3/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/ruslan/anaconda3/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ruslan/anaconda3/lib/python3.7/multiprocessing/pool.py", line 121, in worker
    result = (True, func(*args, **kwds))
  File "<ipython-input-24-0ded585516ed>", line 2, in process_item
    item_data = train_

  File "/home/ruslan/anaconda3/lib/python3.7/site-packages/imgaug/augmenters/meta.py", line 431, in augment_images
    hooks=hooks
  File "/home/ruslan/anaconda3/lib/python3.7/site-packages/imgaug/augmenters/arithmetic.py", line 242, in _augment_images
    samples = np.tile(samples, (1, 1, nb_channels))
  File "<__array_function__ internals>", line 6, in tile
  File "/home/ruslan/anaconda3/lib/python3.7/site-packages/numpy/lib/shape_base.py", line 1242, in tile
    c = c.reshape(-1, n).repeat(nrep, 0)
KeyboardInterrupt
  File "/home/ruslan/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/reductions.py", line 333, in reduce_storage
    fd, size = storage._share_fd_()
KeyboardInterrupt
  File "/home/ruslan/anaconda3/lib/python3.7/site-packages/albumentations/augmentations/functional.py", line 723, in median_blur
    return blur_fn(img)
  File "/home/ruslan/anaconda3/lib/python3.7/site-packages/albumentations/core/transforms_interface.py", line 87, in __call__
    return self.a




KeyboardInterrupt: 

In [None]:
# Throughput benchmark. Recorded timing for this configuration:
# multiprocessing.dummy, processes=6, imgs = torch.from_numpy(imgs).to(patches_device) ~ 19 minutes

from multiprocessing.dummy import Pool

idxs = list(range(len(train_wsipatches_dataset)))

def process_item(idx):
    """Return training item ``idx`` unchanged; runs in a pool worker thread."""
    item_data = train_wsipatches_dataset[idx]
    return item_data

# multiprocessing.dummy.Pool is a thread pool, so items are handed back without pickling.
# Items arrive in completion order (imap_unordered). Only the last item's features
# survive in `featurtes`; presumably this cell only measures run time.
with Pool(processes=6) as pool:
    for item_data in tqdm(pool.imap_unordered(process_item, idxs), total=len(idxs)):
        imgs, *_ = item_data
            
        imgs = torch.from_numpy(imgs).to(patches_device)
        featurtes = get_features(imgs).cpu()

In [None]:
#for imgs, *_ in tqdm(train_wsipatches_dataset, total=len(train_wsipatches_dataset)):
#    imgs = torch.from_numpy(imgs).to(patches_device)
#    # imgs = torch.tensor(imgs) # .to(patches_device)
#    featurtes = get_features(imgs).cpu()

In [None]:
#imgs.shape