In [1]:
import sys, os
sys.path.append("/home/ziyang/kaggle/hubmap-organ-segmentation")

In [2]:
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple

import timm
import torch
import torch.nn as nn
import pandas as pd

import segmentation_models_pytorch as smp


from sklearn.model_selection import StratifiedKFold
from tqdm.notebook import tqdm

In [3]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from data import LitDataModule
from model import LitModule

def train(
    cfg,
    fold: int,
    train_csv_path: str,
    test_csv_path: str,
) -> pl.Trainer:
    """Train a single cross-validation fold and return the fitted Trainer.

    Builds the data module with ``fold`` held out for validation and keeps the
    best ``val_dice_th`` checkpoint in ``cfg.train.checkpoint_dir``. If a
    checkpoint for this run already exists there, training resumes from it.

    Args:
        cfg: Experiment config. Reads ``seed``, ``data.*``, ``train.*``,
            ``model.backbone`` and ``logger.wandb.*``.
        fold: Index of the fold used as the validation split.
        train_csv_path: Path to the prepared training CSV.
        test_csv_path: Path to the prepared test CSV.

    Returns:
        The ``pl.Trainer`` after ``fit`` has completed.
    """
    pl.seed_everything(cfg.seed)

    data_module = LitDataModule(
        val_fold=fold,
        train_csv_path=train_csv_path,
        test_csv_path=test_csv_path,
        spatial_size=cfg.data.spatial_size,
        batch_size=cfg.data.batch_size,
        num_workers=cfg.data.num_workers,
    )
    data_module.setup()

    module = LitModule(cfg)

    # One name shared by the checkpoint file and the W&B run so they line up.
    run_name = f"{module.model.__class__.__name__}_{cfg.model.backbone}_{fold}"

    model_checkpoint = ModelCheckpoint(
        cfg.train.checkpoint_dir,
        monitor="val_dice_th",
        mode="max",
        verbose=True,
        filename=run_name,
    )

    logger = (
        WandbLogger(name=run_name, project=cfg.logger.wandb.project)
        if cfg.logger.wandb.use
        else False
    )

    trainer = pl.Trainer(
        default_root_dir=cfg.train.checkpoint_dir,
        gpus=cfg.train.gpus,
        benchmark=True,
        deterministic=False,
        callbacks=[model_checkpoint],
        limit_train_batches=1.0,
        limit_val_batches=1.0,
        log_every_n_steps=5,
        logger=logger,
        max_epochs=cfg.train.epochs,
        precision=cfg.train.precision,
        accumulate_grad_batches=cfg.data.accumulate_grad_batches,
    )

    trainer.tune(module, datamodule=data_module)

    # Resume from the exact file ModelCheckpoint writes for this run (derived
    # from the configured checkpoint_dir rather than a hard-coded absolute path).
    # Note: this is the best-val_dice_th checkpoint, not necessarily the last epoch.
    resume_path = os.path.join(cfg.train.checkpoint_dir, f"{run_name}.ckpt")
    trainer.fit(
        module,
        datamodule=data_module,
        ckpt_path=resume_path if os.path.exists(resume_path) else None,
    )

    # WandbLogger reuses an already-active wandb run. Finish it so the next
    # fold gets its own run instead of logging into this one.
    if isinstance(logger, WandbLogger):
        logger.experiment.finish()

    return trainer

  stdout_func(


In [4]:
# Root of the local Kaggle workspace. Set the HUBMAP_KAGGLE_DIR environment
# variable to run the notebook from a different machine or checkout.
KAGGLE_DIR = Path(os.environ.get("HUBMAP_KAGGLE_DIR", "/home/ziyang/kaggle"))
BASE_DIR = KAGGLE_DIR / "hubmap-organ-segmentation"

INPUT_DIR = BASE_DIR / "input"
OUTPUT_DIR = BASE_DIR / "working"
CONFIG_DIR = BASE_DIR / "config"

# Raw competition files (train.csv, test.csv, <stage>_images/ ...).
COMPETITION_DATA_DIR = INPUT_DIR / "hubmap-organ-segmentation"

# File names of the prepared CSVs (relative to the notebook's working directory).
TRAIN_PREPARED_CSV_PATH = "train_prepared.csv"
VAL_PRED_PREPARED_CSV_PATH = "val_pred_prepared.csv"
TEST_PREPARED_CSV_PATH = "test_prepared.csv"

CONFIG_YAML_PATH = CONFIG_DIR / "default.yaml"

In [5]:
def add_path_to_df(df: pd.DataFrame, data_dir: Path, type_: str, stage: str) -> pd.DataFrame:
    """Add a column named ``type_`` containing one file path per row of ``df``.

    Images resolve to ``<data_dir>/<stage>_images/<id>.tiff``. Any other
    ``type_`` (e.g. masks) resolves to the *relative* path
    ``<stage>_<type_>s/<id>.npy``, without the ``data_dir`` prefix.
    ``df`` is modified in place and also returned.
    """
    folder = f"{stage}_{type_}s"
    if type_ == "image":
        prefix, suffix = str(data_dir / folder), ".tiff"
    else:
        prefix, suffix = folder, ".npy"

    df[type_] = prefix + "/" + df["id"].astype(str) + suffix
    return df


def add_paths_to_df(df: pd.DataFrame, data_dir: Path, stage: str) -> pd.DataFrame:
    """Add the ``image`` column, then the ``mask`` column, to ``df`` for ``stage``."""
    for column in ("image", "mask"):
        df = add_path_to_df(df, data_dir, column, stage)
    return df


def create_folds(df: pd.DataFrame, n_splits: int, random_seed: int) -> pd.DataFrame:
    """Assign each row a cross-validation fold, stratified on ``df["organ"]``.

    Args:
        df: Training table; must contain an ``organ`` column. A ``fold``
            column is written into it in place.
        n_splits: Number of folds.
        random_seed: Seed for the shuffled stratified split.

    Returns:
        ``df`` with an integer ``fold`` column holding values in ``[0, n_splits)``.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_seed)

    # split() yields *positional* indices. Assign through iloc so the folds
    # land on the right rows even when df lacks a default RangeIndex.
    # Starting from an int Series also keeps the column integer-typed.
    folds = pd.Series(-1, index=df.index, dtype=int)
    for fold, (_, val_idx) in enumerate(skf.split(X=df, y=df["organ"])):
        folds.iloc[val_idx] = fold

    df["fold"] = folds
    return df


def prepare_data(
    data_dir: Path,
    stage: str,
    n_splits: int,
    random_seed: int,
    output_dir: Optional[Path] = None,
) -> pd.DataFrame:
    """Build ``<stage>_prepared.csv`` from the competition's ``<stage>.csv``.

    Adds image/mask path columns. For the ``"train"`` stage only, it also adds
    a stratified ``fold`` column; ``n_splits`` and ``random_seed`` are unused
    for other stages.

    Args:
        data_dir: Competition data directory containing ``<stage>.csv``.
        stage: Split name, e.g. ``"train"`` or ``"test"``.
        n_splits: Number of CV folds (train stage only).
        random_seed: Seed for fold assignment (train stage only).
        output_dir: Directory the CSV is written to. Defaults to the current
            working directory.

    Returns:
        The prepared DataFrame, which is also written to disk.
    """
    df = pd.read_csv(data_dir / f"{stage}.csv")
    df = add_paths_to_df(df, data_dir, stage)

    if stage == "train":
        df = create_folds(df, n_splits, random_seed)

    filename = f"{stage}_prepared.csv"
    output_path = Path(output_dir) / filename if output_dir is not None else Path(filename)
    df.to_csv(output_path, index=False)

    print(f"Created {filename} with shape {df.shape}")

    return df

In [6]:
from utils import EasyConfig

# Load the experiment configuration (config/default.yaml).
cfg = EasyConfig()
cfg.load(CONFIG_YAML_PATH)

# prepare_data writes its CSVs into the working directory by default;
# record their absolute locations in the config.
cfg.data.train_csv_path = os.path.join(os.getcwd(), TRAIN_PREPARED_CSV_PATH)
cfg.data.test_csv_path = os.path.join(os.getcwd(), TEST_PREPARED_CSV_PATH)

# Short aliases for the config sections used in this and later cells.
cfg_train, cfg_data, cfg_model = cfg.train, cfg.data, cfg.model

prepare_data(COMPETITION_DATA_DIR, "train", cfg_data.n_split, cfg.seed)
# Left as the last expression so the notebook displays the prepared test frame.
prepare_data(COMPETITION_DATA_DIR, "test", cfg_data.n_split, cfg.seed)

Created train_prepared.csv with shape (351, 13)
Created test_prepared.csv with shape (1, 9)


Unnamed: 0,id,organ,data_source,img_height,img_width,pixel_size,tissue_thickness,image,mask
0,10078,spleen,Hubmap,2023,2023,0.4945,4,/home/ziyang/kaggle/hubmap-organ-segmentation/input/hubmap-organ-segmentation/test_images/10078.tiff,test_masks/10078.npy


In [7]:
# Train each cross-validation fold in turn. `trainer` is rebound every
# iteration, so after the loop it holds only the last fold's Trainer.
for fold in range(cfg_data.n_split):
    trainer = train(cfg, fold, cfg_data.train_csv_path, cfg_data.test_csv_path)

Global seed set to 42
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mziyangye[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                       | Params
---------------------------------------------------------
0 | model     | UnetPlusPlus_with_ASPP_FPN | 85.0 M
1 | loss_fn   | SymmetricLovaszLoss        | 0     
2 | dice_soft | Dice_soft                  | 0     
3 | dice_th   | Dice_threshold             | 0     
---------------------------------------------------------
85.0 M    Trainable params
0         Non-trainable params
85.0 M    Total params
169.961   Total estimated model params size

Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 0, global step 16: 'val_dice_th' reached 0.24851 (best 0.24851), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 1, global step 32: 'val_dice_th' reached 0.27325 (best 0.27325), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 2, global step 48: 'val_dice_th' reached 0.29979 (best 0.29979), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 3, global step 64: 'val_dice_th' reached 0.33077 (best 0.33077), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 4, global step 80: 'val_dice_th' reached 0.36643 (best 0.36643), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 5, global step 96: 'val_dice_th' reached 0.38305 (best 0.38305), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 6, global step 112: 'val_dice_th' reached 0.41278 (best 0.41278), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 7, global step 128: 'val_dice_th' reached 0.43521 (best 0.43521), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 8, global step 144: 'val_dice_th' reached 0.45153 (best 0.45153), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 9, global step 160: 'val_dice_th' reached 0.46790 (best 0.46790), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 10, global step 176: 'val_dice_th' reached 0.48217 (best 0.48217), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 11, global step 192: 'val_dice_th' reached 0.49689 (best 0.49689), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 12, global step 208: 'val_dice_th' reached 0.51479 (best 0.51479), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 13, global step 224: 'val_dice_th' reached 0.53133 (best 0.53133), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 14, global step 240: 'val_dice_th' reached 0.54751 (best 0.54751), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 15, global step 256: 'val_dice_th' reached 0.56502 (best 0.56502), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 16, global step 272: 'val_dice_th' reached 0.58279 (best 0.58279), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 17, global step 288: 'val_dice_th' reached 0.59801 (best 0.59801), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 18, global step 304: 'val_dice_th' reached 0.61043 (best 0.61043), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 19, global step 320: 'val_dice_th' reached 0.62183 (best 0.62183), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 20, global step 336: 'val_dice_th' reached 0.62996 (best 0.62996), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 21, global step 352: 'val_dice_th' reached 0.64013 (best 0.64013), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 22, global step 368: 'val_dice_th' reached 0.65175 (best 0.65175), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 23, global step 384: 'val_dice_th' reached 0.66142 (best 0.66142), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 24, global step 400: 'val_dice_th' reached 0.66795 (best 0.66795), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 25, global step 416: 'val_dice_th' reached 0.67293 (best 0.67293), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 26, global step 432: 'val_dice_th' reached 0.68018 (best 0.68018), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 27, global step 448: 'val_dice_th' reached 0.68603 (best 0.68603), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 28, global step 464: 'val_dice_th' reached 0.68995 (best 0.68995), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 29, global step 480: 'val_dice_th' reached 0.69566 (best 0.69566), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 30, global step 496: 'val_dice_th' reached 0.69942 (best 0.69942), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 31, global step 512: 'val_dice_th' reached 0.70383 (best 0.70383), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 32, global step 528: 'val_dice_th' reached 0.70792 (best 0.70792), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 33, global step 544: 'val_dice_th' reached 0.71402 (best 0.71402), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 34, global step 560: 'val_dice_th' reached 0.71763 (best 0.71763), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 35, global step 576: 'val_dice_th' reached 0.72028 (best 0.72028), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 36, global step 592: 'val_dice_th' reached 0.72342 (best 0.72342), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 37, global step 608: 'val_dice_th' reached 0.72545 (best 0.72545), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 38, global step 624: 'val_dice_th' reached 0.72962 (best 0.72962), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 39, global step 640: 'val_dice_th' reached 0.73206 (best 0.73206), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 40, global step 656: 'val_dice_th' reached 0.73438 (best 0.73438), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 41, global step 672: 'val_dice_th' reached 0.73731 (best 0.73731), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 42, global step 688: 'val_dice_th' reached 0.74058 (best 0.74058), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 43, global step 704: 'val_dice_th' reached 0.74392 (best 0.74392), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 44, global step 720: 'val_dice_th' reached 0.74667 (best 0.74667), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 45, global step 736: 'val_dice_th' reached 0.74999 (best 0.74999), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 46, global step 752: 'val_dice_th' reached 0.75213 (best 0.75213), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 47, global step 768: 'val_dice_th' reached 0.75493 (best 0.75493), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 48, global step 784: 'val_dice_th' reached 0.75773 (best 0.75773), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 49, global step 800: 'val_dice_th' reached 0.75991 (best 0.75991), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 50, global step 816: 'val_dice_th' reached 0.76012 (best 0.76012), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 51, global step 832: 'val_dice_th' reached 0.76105 (best 0.76105), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 52, global step 848: 'val_dice_th' reached 0.76284 (best 0.76284), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 53, global step 864: 'val_dice_th' reached 0.76316 (best 0.76316), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 54, global step 880: 'val_dice_th' reached 0.76512 (best 0.76512), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 55, global step 896: 'val_dice_th' reached 0.76614 (best 0.76614), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 56, global step 912: 'val_dice_th' reached 0.76803 (best 0.76803), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 57, global step 928: 'val_dice_th' reached 0.76855 (best 0.76855), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 58, global step 944: 'val_dice_th' reached 0.76955 (best 0.76955), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 59, global step 960: 'val_dice_th' reached 0.76979 (best 0.76979), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 60, global step 976: 'val_dice_th' reached 0.77040 (best 0.77040), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 61, global step 992: 'val_dice_th' reached 0.77162 (best 0.77162), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 62, global step 1008: 'val_dice_th' reached 0.77284 (best 0.77284), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 63, global step 1024: 'val_dice_th' reached 0.77426 (best 0.77426), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 64, global step 1040: 'val_dice_th' reached 0.77516 (best 0.77516), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 65, global step 1056: 'val_dice_th' reached 0.77540 (best 0.77540), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 66, global step 1072: 'val_dice_th' reached 0.77607 (best 0.77607), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 67, global step 1088: 'val_dice_th' reached 0.77686 (best 0.77686), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 68, global step 1104: 'val_dice_th' reached 0.77731 (best 0.77731), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 69, global step 1120: 'val_dice_th' reached 0.77769 (best 0.77769), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 70, global step 1136: 'val_dice_th' reached 0.77814 (best 0.77814), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 71, global step 1152: 'val_dice_th' reached 0.77819 (best 0.77819), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 72, global step 1168: 'val_dice_th' reached 0.77870 (best 0.77870), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 73, global step 1184: 'val_dice_th' reached 0.77930 (best 0.77930), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 74, global step 1200: 'val_dice_th' reached 0.78021 (best 0.78021), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 75, global step 1216: 'val_dice_th' reached 0.78112 (best 0.78112), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 76, global step 1232: 'val_dice_th' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 77, global step 1248: 'val_dice_th' reached 0.78210 (best 0.78210), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 78, global step 1264: 'val_dice_th' reached 0.78259 (best 0.78259), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 79, global step 1280: 'val_dice_th' reached 0.78357 (best 0.78357), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 80, global step 1296: 'val_dice_th' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 81, global step 1312: 'val_dice_th' reached 0.78430 (best 0.78430), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 82, global step 1328: 'val_dice_th' reached 0.78516 (best 0.78516), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 83, global step 1344: 'val_dice_th' reached 0.78597 (best 0.78597), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 84, global step 1360: 'val_dice_th' reached 0.78615 (best 0.78615), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 85, global step 1376: 'val_dice_th' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 86, global step 1392: 'val_dice_th' reached 0.78667 (best 0.78667), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 87, global step 1408: 'val_dice_th' reached 0.78740 (best 0.78740), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 88, global step 1424: 'val_dice_th' reached 0.78845 (best 0.78845), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 89, global step 1440: 'val_dice_th' reached 0.78883 (best 0.78883), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 90, global step 1456: 'val_dice_th' reached 0.78934 (best 0.78934), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 91, global step 1472: 'val_dice_th' reached 0.78948 (best 0.78948), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 92, global step 1488: 'val_dice_th' reached 0.79000 (best 0.79000), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 93, global step 1504: 'val_dice_th' reached 0.79040 (best 0.79040), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 94, global step 1520: 'val_dice_th' reached 0.79145 (best 0.79145), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 95, global step 1536: 'val_dice_th' reached 0.79211 (best 0.79211), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 96, global step 1552: 'val_dice_th' reached 0.79275 (best 0.79275), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 97, global step 1568: 'val_dice_th' reached 0.79382 (best 0.79382), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 98, global step 1584: 'val_dice_th' reached 0.79424 (best 0.79424), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 99, global step 1600: 'val_dice_th' reached 0.79495 (best 0.79495), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 100, global step 1616: 'val_dice_th' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 101, global step 1632: 'val_dice_th' reached 0.79517 (best 0.79517), saving model to '/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b7_0.ckpt' as top 1


In [None]:
import monai
from utils import mask2rle
@torch.no_grad()
def create_pred_df(module, dataloader, threshold):
    """Predict a binary mask for every sample in ``dataloader`` and RLE-encode it.

    For each sample, the predicted map is resized (nearest) back to its
    ``img_height`` x ``img_width``, passed through a sigmoid and thresholded.
    Its first channel is then encoded with ``mask2rle``.

    Args:
        module: Trained model. It is called on the batch images after they are
            moved to ``module.device``.
        dataloader: Yields dicts with ``"id"``, ``"img_height"``,
            ``"img_width"`` and ``"image"`` entries.
        threshold: Cut-off applied after the sigmoid.

    Returns:
        DataFrame with one row per sample and columns ``id`` and ``rle``.
    """
    ids = []
    rles = []
    for batch in tqdm(dataloader):
        batch_ids = batch["id"].numpy()
        heights = batch["img_height"].numpy()
        widths = batch["img_width"].numpy()

        # NOTE(review): assumes the module returns logits indexed by batch
        # first, i.e. (B, C, H, W), as the original `[0]` indexing implied.
        # Confirm in LitModule.forward.
        outputs = module(batch["image"].to(module.device))

        # Handle every sample in the batch, not just the first, so that
        # batch_size > 1 does not silently drop predictions.
        for i, (id_, height, width) in enumerate(zip(batch_ids, heights, widths)):
            # Rebuilt per sample because the target size differs per image.
            post_pred_transform = monai.transforms.Compose(
                [
                    monai.transforms.Resize(spatial_size=(height, width), mode="nearest"),
                    monai.transforms.Activations(sigmoid=True),
                    monai.transforms.AsDiscrete(threshold=threshold),
                ]
            )

            # Take the first channel as the 2-D mask.
            mask = post_pred_transform(outputs[i]).to(torch.uint8).cpu().detach().numpy()[0]

            ids.append(id_)
            rles.append(mask2rle(mask))

    return pd.DataFrame({"id": ids, "rle": rles})


def infer(
    checkpoint_path: str,
    spatial_size: int,
    num_workers: int,
    device: str = 'cuda',
    train_csv_path: str = TRAIN_PREPARED_CSV_PATH,
    test_csv_path: str = TEST_PREPARED_CSV_PATH,
    threshold: float = 0.5,
    val_fold: int = 0,
    config=None,
):
    """Load a checkpoint and predict RLE masks for one validation fold and the test set.

    Args:
        checkpoint_path: Path of the checkpoint to load.
        spatial_size: Input size passed to ``LitDataModule``.
        num_workers: DataLoader worker count.
        device: Device the model is loaded onto.
        train_csv_path: Prepared training CSV (source of the validation split).
        test_csv_path: Prepared test CSV.
        threshold: Cut-off applied after the sigmoid.
        val_fold: Fold whose validation rows are predicted. This should match
            the fold the checkpoint was trained with as its held-out split;
            otherwise the model is scored on rows it trained on.
        config: Config used to build ``LitModule``. Defaults to the
            notebook-level ``cfg``.

    Returns:
        ``(val_pred_df, test_pred_df)``: DataFrames with ``id`` and ``rle`` columns.
    """
    module_cfg = config if config is not None else cfg
    module = LitModule(module_cfg).load_eval_checkpoint(checkpoint_path, device)

    data_module = LitDataModule(
        train_csv_path=train_csv_path,
        test_csv_path=test_csv_path,
        spatial_size=spatial_size,
        val_fold=val_fold,
        batch_size=1,
        num_workers=num_workers,
    )
    data_module.setup()

    val_dataloader = data_module.val_dataloader()
    test_dataloader = data_module.test_dataloader()

    val_pred_df = create_pred_df(module, val_dataloader, threshold)
    test_pred_df = create_pred_df(module, test_dataloader, threshold)

    return val_pred_df, test_pred_df