In [1]:
import sys
import os
import psutil

import random
import math
from functools import partial

import torch 
from torch import optim
from torch.optim import lr_scheduler
from torch import nn
from torch.nn import functional as F

import multiprocessing.dummy as mp

from pytorch_lightning import Trainer
from pytorch_lightning.core import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

sys.path.append('../..')
from lib.schedulers import DelayedScheduler
from lib.datasets import (max_lbl_nums, actual_lbl_nums, 
                          patches_rgb_mean_av1, patches_rgb_std_av1, 
                          get_train_test_img_ids_split)
from lib.dataloaders import PatchesDataset, WSIPatchesDatasetRaw
from lib.augmentations import augment_v1_clr_only, augment_empty_clr_only
from lib.losses import SmoothLoss

from lib.models.unetv1 import get_model

from sklearn.metrics import cohen_kappa_score

from tqdm.auto import tqdm

import matplotlib.pyplot as plt

In [2]:
# import cv2
import numpy as np
# import pandas as pd
# from lib.datasets import patches_csv_path, patches_path
from lib.datasets import (patches_clean90_csv_path as patches_csv_path, patches_path,
                          patches_clean90_pkl_path as patches_pkl_path)
# from lib.dataloaders import imread, get_g_score_num, get_provider_num

In [3]:
# Device placement: patch feature extraction runs on `patches_device`;
# `main_device` is presumably meant for a downstream model (not used in the
# cells shown here). Fall back to fewer devices when fewer GPUs are visible,
# instead of failing later on the first tensor allocation.
_n_gpus = torch.cuda.device_count()
patches_device = torch.device('cuda:0' if _n_gpus >= 1 else 'cpu')
main_device = torch.device('cuda:1') if _n_gpus >= 2 else patches_device

In [4]:
# RGB normalisation statistics for the patch images, as float32 tensors on
# `patches_device` (the same device that get_features is given below).
rgb_mean = torch.tensor(patches_rgb_mean_av1, dtype=torch.float32, device=patches_device)
rgb_std = torch.tensor(patches_rgb_std_av1, dtype=torch.float32, device=patches_device)

In [5]:
# Image-level train/test split; only `train_img_ids` is fed to the loader below.
train_img_ids, test_img_ids = get_train_test_img_ids_split()

# Leakage guard: no image id may appear on both sides of the split.
assert not set(train_img_ids) & set(test_img_ids), 'train/test image ids overlap'

test_img_ids[:4]

['e8baa3bb9dcfb9cef5ca599d62bb8046',
 '9b2948ff81b64677a1a152a1532c1a50',
 '5b003d43ec0ce5979062442486f84cf7',
 '375b2c9501320b35ceb638a3274812aa']

In [6]:
from lib.dataloaders import WSIPatchesDataloader, WSIPatchesDatasetRaw
from lib.utils import get_pretrained_model, get_features

In [7]:
# Patch-level model restored from the last checkpoint of the Patches256TestRun
# experiment and placed on `patches_device`; it is bound into get_features_fn below.
PATCHES_MODEL_CKPT = "../Patches256TestRun/version_0/checkpoints/last.ckpt"
patches_model_kwargs = {'classes': actual_lbl_nums}
model = get_pretrained_model(get_model, patches_model_kwargs, PATCHES_MODEL_CKPT, patches_device)

In [8]:
# Pre-bind the model, device and normalisation stats to get_features so the
# loader only has to supply the remaining argument(s).
# features_batch_size presumably bounds how many patches go through the model
# per forward pass — TODO confirm against lib.utils.get_features.
_get_features_kwargs = dict(model=model, device=patches_device,
                            rgb_mean=rgb_mean, rgb_std=rgb_std,
                            features_batch_size=512)
get_features_fn = partial(get_features, **_get_features_kwargs)

In [9]:
# psutil handle on this kernel's own process (no pid => current process); used to read RSS below.
process = psutil.Process()

In [10]:
# Batch size handed to WSIPatchesDataloader when building train_loader.
main_batch_size: int = 64

In [11]:
# Raw per-image patch dataset for the training split. scale=0.5 is presumably a
# resize factor and augment_v1_clr_only a colour-only augmentation — TODO confirm.
train_wsi_dataset = WSIPatchesDatasetRaw(train_img_ids, patches_pkl_path,
                                         scale=0.5, transform=augment_v1_clr_only)

# Loader that applies get_features_fn to the dataset's patches.
# NOTE(review): (512, 8, 8) looks like the per-image feature shape and
# max_len=300 like a cap on patches per image — confirm in lib.dataloaders.
train_loader = WSIPatchesDataloader(
    train_wsi_dataset,
    get_features_fn,
    (512, 8, 8),
    main_batch_size,
    shuffle=True,
    num_workers=6,
    max_len=300,
)

In [None]:
# Memory profiling: record this process's resident set size (RSS) after every
# batch, for several full passes over train_loader, and plot one curve per pass.
# A curve that keeps climbing across passes points to a leak in the loader.
# NOTE(review): only this process's RSS is measured — if WSIPatchesDataloader
# spawns worker *processes*, their memory is not included; confirm.
n_memory_passes = 5

for pass_idx in range(n_memory_passes):
    rss_bytes = []
    for _batch in tqdm(train_loader, desc=f'pass {pass_idx + 1}/{n_memory_passes}'):
        rss_bytes.append(process.memory_info().rss)
    # Release the last batch so it is not still alive while the next pass starts.
    _batch = None

    fig, ax = plt.subplots()
    ax.plot([b / 1e9 for b in rss_bytes])
    ax.set(title=f'Main-process RSS per batch, pass {pass_idx + 1}/{n_memory_passes}',
           xlabel='batch index', ylabel='RSS (GB)')
    plt.show()
    plt.close(fig)

HBox(children=(FloatProgress(value=0.0, max=132.0), HTML(value='')))