In [1]:
import os
# Restrict which physical GPUs this process may see. Set before importing
# torch so the variable is in place before CUDA is initialized.
cuda_devices = '0, 1, 2, 3, 4, 5, 6, 7'
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_devices

import torch
import gc
# Presumably frees GPU memory / Python objects left over from a previous run
# in the same kernel when this cell is re-executed.
torch.cuda.empty_cache()
gc.collect()

import pytorch_lightning as pl

# Project-local modules. NOTE(review): SaveImagesCallback, LoadImageToTensorBoard,
# BCEDiceLoss, TverskyLoss, FocalTverskyLoss, ResNet and nn are not used in the
# cells visible here — confirm before removing.
from callbacks import SaveImagesCallback, LoadImageToTensorBoard

from datasets_m import HwrOnPrintedDataModule
from losses import DiceLoss, BCEDiceLoss, TverskyLoss, FocalTverskyLoss
from models import Unet, ResNet
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root folder for model artefacts (logs, checkpoints) produced by this project.
model_data_path = '/home/kudrjavtseviv/DL-framework/.data'
# Dataset root; the data-module cell reads its train/val/test "result" subfolders.
ds_path = '/home/kudrjavtseviv/data/dataset_1'

checkpoints_path = os.path.join(model_data_path, 'checkpoints')

In [13]:
class HP:
    """Hyper-parameter bundle for a single training / evaluation run.

    Instantiates the loss as ``loss(**loss_params)`` and derives a readable
    ``model_name`` that is used, e.g., as the checkpoint sub-directory name.

    Args:
        model_type: Architecture label such as ``'Unet'``; only used in names/logs.
        loss: Loss class (any callable); called as ``loss(**loss_params)``.
        loss_params: Keyword arguments for ``loss``; ``None`` is treated as ``{}``.
        batch_size: Batch size forwarded to the data module.
        image_height: Image height forwarded to the data module.
        image_width: Image width forwarded to the data module.
        ds_part: Dataset-subset setting forwarded to the data module as ``ds_part``.
        epoches: Maximum number of epochs for the Trainer.
        features: Sequence passed to the model as ``features``; its last and
            first entries appear in ``model_name``. Must be non-empty.
    """

    def __init__(self, model_type, loss, loss_params, batch_size, image_height, image_width, ds_part, epoches, features) -> None:
        if loss_params is None:
            loss_params = {}
        self.model_type = model_type
        self.loss = loss(**loss_params)
        self.loss_params = loss_params
        # Some loss implementations expose `__name__` on the instance (e.g. as a
        # property); ordinary instances do not, so fall back to the class name
        # rather than raising AttributeError.
        self.loss_name = getattr(self.loss, '__name__', type(self.loss).__name__)
        self.batch_size = batch_size
        self.image_height = image_height
        self.image_width = image_width
        self.ds_part = ds_part
        self.epoches = epoches
        self.features = features
        self.model_name = f'{model_type}-{features[-1]}_{features[0]}-{image_height}_{image_width}-{self.loss_name}'

    def dict_for_log(self):
        """Return the hyper-parameters as a flat dict for logger.log_hyperparams.

        Non-scalar values (loss params, features) are stringified.
        """
        return {
            'MODEL_NAME': self.model_name,
            'LOSS': self.loss_name,
            'LOSS_PARAMS': str(self.loss_params),
            'EPOCHES': self.epoches,
            'FEATURES': str(self.features),
            'DS_SIZE': self.ds_part,
            'IMAGE_HEIGHT': self.image_height,
            'IMAGE_WIDTH': self.image_width
        }

    def __str__(self):
        """String form of :meth:`dict_for_log`."""
        return str(self.dict_for_log())

In [18]:
# Run-wide settings shared by every experiment config below.
DS_PART = None  # forwarded to the data module as ds_part (TODO confirm None means "whole dataset")
IMAGE_HEIGHT = 800
IMAGE_WIDTH = 600
BATCH_SIZE = 2
EPOCHES = 25

# One HP entry per experiment; the evaluation cell below uses configs[0].
configs = [
    HP(
        model_type='Unet',
        loss=DiceLoss,
        loss_params={'activation': None},
        batch_size=BATCH_SIZE,
        image_height=IMAGE_HEIGHT,
        image_width=IMAGE_WIDTH,
        ds_part=DS_PART,
        epoches=EPOCHES,
        features=[32, 64, 128, 256],
    ),
]

# Checkpoint to evaluate (produced by an earlier training run).
checkpoint_path = (
    '/home/kudrjavtseviv/DL-framework/.data/logs/lightning_logs/version_124/'
    'Unet-256_32-800_600-dice_loss/epoch23-val_loss0.07.ckpt'
)

In [19]:
# Evaluate a saved checkpoint on the test split and log the run's hyper-parameters.
hparams = configs[0]

# MODEL
# Map checkpoint tensors to the CPU: the Trainer moves the model onto the GPU it
# is given (gpus=[1] below). Mapping to torch.device('cuda') would instead load
# the weights onto cuda:0 first and create a CUDA context on that GPU too.
model = Unet.load_from_checkpoint(
    checkpoint_path,
    map_location='cpu',
    features=hparams.features,
    loss=hparams.loss,
)

# DATA MODULE
# ds_path / model_data_path come from the paths cell at the top of the notebook.
dm = HwrOnPrintedDataModule(
    path={
        'train': os.path.join(ds_path, 'train/result'),
        'val': os.path.join(ds_path, 'val/result'),
        'test': os.path.join(ds_path, 'test/result'),
    },
    batch_size=hparams.batch_size,
    num_workers=10,
    # Pinned host memory only helps when batches are copied to a GPU.
    pin_memory=torch.cuda.is_available(),
    image_height=hparams.image_height,
    image_width=hparams.image_width,
    ds_part=hparams.ds_part,
)

logs_path = os.path.join(model_data_path, 'logs')

# UTILS
tb_logger = pl.loggers.TensorBoardLogger(save_dir=logs_path)
# Keeps the 3 best checkpoints by validation loss; it saves nothing during
# trainer.test and is kept so the same setup can be used for training.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='_loss/val',
    dirpath=os.path.join(tb_logger.log_dir, hparams.model_name),
    filename='epoch{epoch:02d}-val_loss{_loss/val:.2f}',
    auto_insert_metric_name=False,
    save_top_k=3
)

# TRAINER
trainer = pl.Trainer(
    gpus=[1],
    max_epochs=hparams.epoches,
    log_every_n_steps=2,
    callbacks=[checkpoint_callback],
    logger=tb_logger,
)

# Evaluation only. To train instead, run: trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)

tb_logger.log_hyperparams(hparams.dict_for_log())
tb_logger.finalize("success")


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0, 1, 2, 3, 4, 5, 6, 7]


Testing DataLoader 0: 100%|██████████| 1008/1008 [00:49<00:00, 20.57it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        _acc/test           0.9791485667228699
       _loss/test           0.07107832282781601
     _precision/test        0.9492456316947937
      _recall/test          0.9099568128585815
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
