In [1]:
import sys
import os
import psutil

import random
import math

import torch 
from torch import optim
from torch.optim import lr_scheduler
from torch import nn
from torch.nn import functional as F

import multiprocessing.dummy as mp

from pytorch_lightning import Trainer
from pytorch_lightning.core import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

sys.path.append('../..')
from lib.schedulers import DelayedScheduler
from lib.datasets import (max_lbl_nums, actual_lbl_nums, 
                          patches_rgb_mean_av1, patches_rgb_std_av1, 
                          get_train_test_img_ids_split)
from lib.dataloaders import PatchesDataset, WSIPatchesDatasetRaw
from lib.augmentations import augment_v1_clr_only, augment_empty_clr_only
from lib.losses import SmoothLoss

from lib.models.unetv1 import get_model

from sklearn.metrics import cohen_kappa_score

from tqdm.auto import tqdm

import matplotlib.pyplot as plt

In [2]:
# import cv2
import numpy as np
# import pandas as pd
# from lib.datasets import patches_csv_path, patches_path
from lib.datasets import (patches_clean90_csv_path as patches_csv_path, patches_path,
                          patches_clean90_pkl_path as patches_pkl_path)
# from lib.dataloaders import imread, get_g_score_num, get_provider_num

In [3]:
class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd
    def forward(self, x):
        return self.lambd(x)
    
def get_pretrained_model(get_model_fn, checkpoint, device):
    tmp = torch.load(checkpoint, map_location=device)
    
    model = get_model_fn(actual_lbl_nums)

    module = nn.Sequential()

    module.add_module('model', model)

    module.to(device);

    module.load_state_dict(tmp['state_dict'])

    model.segmentation = False
    model.classification_head = None
    model.autodecoder = None
    module.eval();
    
    return model

def get_features(imgs, features_batch_size=512):
    model.eval()
    
    imgs = imgs if isinstance(imgs, torch.Tensor) else torch.from_numpy(imgs)
    imgs = imgs.to(patches_device)
    n_imgs = (imgs - rgb_mean) / rgb_std
    
    b_features = []
    for b in range(0, n_imgs.shape[0], features_batch_size):
        with torch.no_grad():
            features, *_ = model(n_imgs[b:b+features_batch_size], return_features=True)
            b_features.append(features)

    features = torch.cat(b_features, dim=0)
    
    return features.cpu()

In [4]:
patches_device = torch.device('cuda:0')
# patches_device = torch.device('cpu')
main_device = torch.device('cuda:1')

In [5]:
model = get_pretrained_model(get_model, "../Patches256TestRun/version_0/checkpoints/last.ckpt", patches_device)

In [6]:
# model = torch.jit.script(model)

In [7]:
model.__class__.forward

<function segmentation_models_pytorch.base.model.SegmentationModel.forward(self, x)>

In [8]:
rgb_mean, rgb_std = (torch.tensor(patches_rgb_mean_av1, dtype=torch.float32, device=patches_device), 
                     torch.tensor(patches_rgb_std_av1, dtype=torch.float32, device=patches_device))

In [9]:
train_img_ids, test_img_ids = get_train_test_img_ids_split()

test_img_ids[:4]

['e8baa3bb9dcfb9cef5ca599d62bb8046',
 '9b2948ff81b64677a1a152a1532c1a50',
 '5b003d43ec0ce5979062442486f84cf7',
 '375b2c9501320b35ceb638a3274812aa']

In [10]:
dataset = WSIPatchesDatasetRaw(train_img_ids, patches_pkl_path, 
                               scale=0.5, transform=augment_v1_clr_only)

In [11]:
process = psutil.Process(os.getpid())

In [None]:
for _ in range(5):        
    memory = []

    main_batch_size = 64

    max_len = 300

    idxs = list(range(len(dataset)))
    random.shuffle(idxs)

    def process_item(idx):
        item_data = dataset[idx]
        return item_data

    b_features = torch.zeros((main_batch_size, max_len, 512, 8, 8), dtype=torch.float32)
    b_ys = torch.zeros((main_batch_size, max_len), dtype=torch.int64)
    b_xs = torch.zeros((main_batch_size, max_len), dtype=torch.int64)
    b_provider = torch.zeros((main_batch_size), dtype=torch.int64)
    b_isup_grade = torch.zeros((main_batch_size), dtype=torch.int64)
    b_gleason_score = torch.zeros((main_batch_size), dtype=torch.int64)

    batch = [b_features, b_ys, b_xs, b_provider, 
             b_isup_grade, b_gleason_score]

    def clean_batch():
        for a in batch:
            a.fill_(-1)

    clean_batch()

    c_iter = 0
    with mp.Pool(processes=6) as pool:
        for item_data in tqdm(pool.imap_unordered(process_item, idxs), total=len(dataset)):
        # for item_data in pool.imap_unordered(process_item, idxs):
            imgs, ys, xs, provider, isup_grade, gleason_score = item_data
            # imgs = torch.from_numpy(imgs).to(patches_device)
            features = get_features(imgs)

            b_iter = c_iter % main_batch_size
            p = ys.shape[0]

            b_features[b_iter, :p] = features[:max_len]
            b_ys[b_iter, :p] = torch.from_numpy(ys)[:max_len]        
            b_xs[b_iter, :p] = torch.from_numpy(xs)[:max_len]
            b_provider[b_iter] = provider
            b_isup_grade[b_iter] = isup_grade        
            b_gleason_score[b_iter] = gleason_score

            if (c_iter + 1) % main_batch_size == 0:
                #process batch
                #yield batch

                # clean batch data
                clean_batch()

            c_iter += 1

            memory.append(process.memory_info().rss)

    plt.axes().ticklabel_format(style='sci', scilimits=(9, 9))
    plt.plot(memory);
    plt.show()

In [12]:
class WSIPatchesDataloader():
    def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0, max_len=300):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.max_len = max_len
    
    def __len__(self):
        return math.ceil(len(self.dataset) / self.batch_size)
    
    def __iter__(self):
        return self.produce_batches()
    
    def produce_batches(self):
        def process_item(idx):
            item_data = self.dataset[idx]
            return item_data

        def clean_batch():
            for a in batch:
                a.fill_(-1)        
        
        idxs = list(range(len(self.dataset)))
        
        if self.shuffle:
            random.shuffle(idxs)
            
        max_len = self.max_len

        b_features = torch.zeros((main_batch_size, max_len, 512, 8, 8), dtype=torch.float32)
        b_ys = torch.zeros((main_batch_size, max_len), dtype=torch.int64)
        b_xs = torch.zeros((main_batch_size, max_len), dtype=torch.int64)
        b_provider = torch.zeros((main_batch_size), dtype=torch.int64)
        b_isup_grade = torch.zeros((main_batch_size), dtype=torch.int64)
        b_gleason_score = torch.zeros((main_batch_size), dtype=torch.int64)

        batch = [b_features, b_ys, b_xs, b_provider, 
                 b_isup_grade, b_gleason_score]

        clean_batch()

        c_iter = 0
        with mp.Pool(processes=self.num_workers) as pool:
            # for item_data in tqdm(pool.imap_unordered(process_item, idxs), total=len(dataset)):
            for item_data in pool.imap_unordered(process_item, idxs):
                imgs, ys, xs, provider, isup_grade, gleason_score = item_data
                features = get_features(imgs)

                b_iter = c_iter % main_batch_size
                p = ys.shape[0]

                b_features[b_iter, :p] = features[:max_len]
                b_ys[b_iter, :p] = torch.from_numpy(ys)[:max_len]        
                b_xs[b_iter, :p] = torch.from_numpy(xs)[:max_len]
                b_provider[b_iter] = provider
                b_provider[b_isup_grade] = isup_grade        
                b_provider[b_gleason_score] = gleason_score

                if (c_iter + 1) % main_batch_size == 0:
                    #process batch
                    yield batch

                    # clean batch data
                    clean_batch()

                c_iter += 1

        if c_iter % main_batch_size != 0:        
            yield [a[:c_iter % main_batch_size] for a in batch]

In [None]:
main_batch_size = 64

In [None]:
train_loader = WSIPatchesDataloader(WSIPatchesDatasetRaw(train_img_ids, patches_pkl_path, 
                                                         scale=0.5, transform=augment_v1_clr_only), 
                                    main_batch_size, shuffle=True, num_workers=6, max_len=300)

In [None]:
memory = []
for data in tqdm(train_loader, total=len(train_loader)):
    memory.append(process.memory_info().rss)

In [None]:
plt.axes().ticklabel_format(style='sci', scilimits=(9, 9))
plt.plot(memory);

In [None]:
memory = []
for data in tqdm(train_loader, total=len(train_loader)):
    memory.append(process.memory_info().rss)

In [None]:
plt.axes().ticklabel_format(style='sci', scilimits=(9, 9))
plt.plot(memory);

In [None]:
memory = []
for data in tqdm(train_loader, total=len(train_loader)):
    memory.append(process.memory_info().rss)

In [None]:
plt.axes().ticklabel_format(style='sci', scilimits=(9, 9))
plt.plot(memory);

In [12]:
def get_batches():
    main_batch_size = 64

    max_len = 300

    idxs = list(range(len(dataset)))
    random.shuffle(idxs)

    def process_item(idx):
        item_data = dataset[idx]
        return item_data

    b_features = torch.zeros((main_batch_size, max_len, 512, 8, 8), dtype=torch.float32)
    b_ys = torch.zeros((main_batch_size, max_len), dtype=torch.int64)
    b_xs = torch.zeros((main_batch_size, max_len), dtype=torch.int64)
    b_provider = torch.zeros((main_batch_size), dtype=torch.int64)
    b_isup_grade = torch.zeros((main_batch_size), dtype=torch.int64)
    b_gleason_score = torch.zeros((main_batch_size), dtype=torch.int64)

    batch = [b_features, b_ys, b_xs, b_provider, 
             b_isup_grade, b_gleason_score]

    def clean_batch():
        for a in batch:
            a.fill_(-1)

    clean_batch()

    c_iter = 0
    with mp.Pool(processes=6) as pool:
        # for item_data in tqdm(pool.imap_unordered(process_item, idxs), total=len(dataset)):
        for item_data in pool.imap_unordered(process_item, idxs):
            imgs, ys, xs, provider, isup_grade, gleason_score = item_data
            imgs = torch.from_numpy(imgs).to(patches_device)
            features = get_features(imgs)

            b_iter = c_iter % main_batch_size
            p = ys.shape[0]

            b_features[b_iter, :p] = features[:max_len]
            b_ys[b_iter, :p] = torch.from_numpy(ys)[:max_len]        
            b_xs[b_iter, :p] = torch.from_numpy(xs)[:max_len]
            b_provider[b_iter] = provider
            b_provider[b_isup_grade] = isup_grade        
            b_provider[b_gleason_score] = gleason_score

            if (c_iter + 1) % main_batch_size == 0:
                #process batch
                yield batch

                # clean batch data
                clean_batch()

            c_iter += 1

    if c_iter % main_batch_size != 0:        
        yield [a[:c_iter % main_batch_size] for a in batch]

In [None]:
memory = []
for data in tqdm(get_batches(), total=132):
    memory.append(process.memory_info().rss)
    features, ys, xs, provider, isup_grade, gleason_score = data 
    print(isup_grade)

HBox(children=(FloatProgress(value=0.0, max=132.0), HTML(value='')))

tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])
tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])
tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])
tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -

In [None]:
plt.axes().ticklabel_format(style='sci', scilimits=(9, 9))
plt.plot(memory);