In [None]:
import sys, os
# Make the project's local modules (data, model, utils) importable from this notebook.
# The membership check keeps the cell idempotent: re-running it does not append the
# same entry to sys.path again.
PROJECT_ROOT = "/home/ziyang/kaggle/hubmap-organ-segmentation"
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple

import timm
import torch
import torch.nn as nn
import pandas as pd

import segmentation_models_pytorch as smp


from sklearn.model_selection import StratifiedKFold
from tqdm.notebook import tqdm

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from data import LitDataModule
from model import LitModule

def train(
    cfg,
    fold: int,
    train_csv_path: str,
    test_csv_path: str,
) -> "pl.Trainer":
    """Train one model on a single cross-validation fold and return the trainer.

    Args:
        cfg: Experiment configuration. The attributes read here are ``seed``,
            ``data.spatial_size``, ``data.batch_size``, ``data.num_workers``,
            ``data.accumulate_grad_batches``, ``model.backbone``,
            ``train.checkpoint_dir``, ``train.gpus``, ``train.epochs``,
            ``train.precision``, ``logger.wandb.use`` and ``logger.wandb.project``.
        fold: Index of the fold held out for validation.
        train_csv_path: Path of the prepared training CSV.
        test_csv_path: Path of the prepared test CSV.

    Returns:
        The ``pl.Trainer`` instance after ``fit`` has completed.
    """
    pl.seed_everything(cfg.seed)

    data_module = LitDataModule(
        val_fold=fold,
        train_csv_path=train_csv_path,
        test_csv_path=test_csv_path,
        spatial_size=cfg.data.spatial_size,
        batch_size=cfg.data.batch_size,
        num_workers=cfg.data.num_workers,
    )
    data_module.setup()

    module = LitModule(cfg)

    # One identifier shared by the checkpoint file, the W&B run and the resume lookup.
    run_name = f"{module.model.__class__.__name__}_{cfg.model.backbone}_{fold}"

    # Keep the checkpoint with the highest "val_dice_th" value.
    model_checkpoint = ModelCheckpoint(
        cfg.train.checkpoint_dir,
        monitor="val_dice_th",
        mode="max",
        verbose=True,
        filename=run_name,
    )

    # Passing False disables logging in the Trainer; the `== True` comparison is kept
    # deliberately so non-boolean config values behave exactly as before.
    if cfg.logger.wandb.use == True:  # noqa: E712
        logger = WandbLogger(name=run_name, project=cfg.logger.wandb.project)
    else:
        logger = False

    trainer = pl.Trainer(
        default_root_dir=cfg.train.checkpoint_dir,
        gpus=cfg.train.gpus,
        benchmark=True,
        deterministic=False,
        callbacks=[model_checkpoint],
        limit_train_batches=1.0,
        limit_val_batches=1.0,
        log_every_n_steps=5,
        logger=logger,
        max_epochs=cfg.train.epochs,
        precision=cfg.train.precision,
        accumulate_grad_batches=cfg.data.accumulate_grad_batches,
    )

    trainer.tune(module, datamodule=data_module)

    # Resume from an existing checkpoint of the same run, if there is one.
    # NOTE(review): this directory is hard-coded and is not derived from
    # cfg.train.checkpoint_dir -- confirm the two are meant to be the same location.
    resume_path = (
        f"/home/ziyang/kaggle/hubmap-organ-segmentation/notebooks/tmp/checkpoint/{run_name}"
        + ".ckpt"
    )
    if not os.path.exists(resume_path):
        resume_path = None

    trainer.fit(module, datamodule=data_module, ckpt_path=resume_path)

    return trainer

In [None]:
# Directory layout of the project on disk.
KAGGLE_DIR = Path("/home/ziyang", "kaggle")
BASE_DIR = KAGGLE_DIR.joinpath("hubmap-organ-segmentation")

INPUT_DIR = BASE_DIR.joinpath("input")
OUTPUT_DIR = BASE_DIR.joinpath("working")
CONFIG_DIR = BASE_DIR.joinpath("config")

# Competition data lives in a sub-folder of the input directory.
COMPETITION_DATA_DIR = INPUT_DIR.joinpath("hubmap-organ-segmentation")

# File names (relative, no directory component) of the prepared CSV files.
TRAIN_PREPARED_CSV_PATH = "train_prepared.csv"
VAL_PRED_PREPARED_CSV_PATH = "val_pred_prepared.csv"
TEST_PREPARED_CSV_PATH = "test_prepared.csv"

# YAML file holding the experiment configuration.
CONFIG_YAML_PATH = CONFIG_DIR.joinpath("default.yaml")

In [None]:
def add_path_to_df(df: pd.DataFrame, data_dir: Path, type_: str, stage: str) -> pd.DataFrame:
    """Add a column named ``type_`` holding one file path per row, built from ``df["id"]``.

    For ``type_ == "image"`` the path is ``<data_dir>/<stage>_images/<id>.tiff``.
    For any other ``type_`` the path is the relative ``<stage>_<type_>s/<id>.npy``
    (``data_dir`` is not used in that case).

    The column is written into ``df`` itself, and that same frame is returned.
    """
    folder_name = f"{stage}_{type_}s"

    if type_ == "image":
        suffix = ".tiff"
        prefix = str(data_dir / folder_name)
    else:
        suffix = ".npy"
        prefix = folder_name

    df[type_] = prefix + "/" + df["id"].astype(str) + suffix
    return df


def add_paths_to_df(df: pd.DataFrame, data_dir: Path, stage: str) -> pd.DataFrame:
    """Add the "image" and then the "mask" path column to ``df`` and return the frame."""
    for path_kind in ("image", "mask"):
        df = add_path_to_df(df, data_dir, path_kind, stage)
    return df


def create_folds(df: pd.DataFrame, n_splits: int, random_seed: int) -> pd.DataFrame:
    """Assign every row a validation-fold number, stratified by the "organ" column.

    Args:
        df: Frame with an "organ" column; a "fold" column is written into it.
        n_splits: Number of folds.
        random_seed: Seed for the shuffled stratified split.

    Returns:
        The same ``df`` object, now carrying the "fold" column.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_seed)

    # The splitter yields *positional* row indices, so they are written by position.
    # Writing them through label-based ``df.loc`` only works for a default RangeIndex.
    # The column is float64, matching what the label-based assignment produced.
    fold_by_position = pd.Series(float("nan"), index=range(len(df)), dtype="float64")
    for fold, (_, val_idx) in enumerate(skf.split(X=df, y=df["organ"])):
        fold_by_position.iloc[val_idx] = fold

    # Assign the raw values so pandas does not re-align them against df's index labels.
    df["fold"] = fold_by_position.to_numpy()

    return df


def prepare_data(data_dir: Path, stage: str, n_splits: int, random_seed: int) -> pd.DataFrame:
    """Build the prepared CSV for one stage and return the resulting frame.

    Reads ``<data_dir>/<stage>.csv``, adds the image and mask path columns, and
    adds a "fold" column when ``stage == "train"``. The result is written to
    ``<stage>_prepared.csv``, a relative path resolved against the current
    working directory.

    Args:
        data_dir: Directory containing ``<stage>.csv``.
        stage: Stage name, used for the input file, the path columns and the output file.
        n_splits: Number of folds; only used for the "train" stage.
        random_seed: Seed for fold creation; only used for the "train" stage.

    Returns:
        The prepared DataFrame that was written to disk.
    """
    df = pd.read_csv(data_dir / f"{stage}.csv")
    df = add_paths_to_df(df, data_dir, stage)

    # Only the training stage is split into folds.
    if stage == "train":
        df = create_folds(df, n_splits, random_seed)

    filename = f"{stage}_prepared.csv"
    df.to_csv(filename, index=False)

    print(f"Created {filename} with shape {df.shape}")

    return df

In [None]:
from utils import EasyConfig

# Load the experiment configuration from the YAML file.
cfg = EasyConfig()
cfg.load(CONFIG_YAML_PATH)

# The prepared CSVs are written to the current working directory, so the config is
# pointed at absolute paths inside it.
working_dir = os.getcwd()
cfg.data.train_csv_path = os.path.join(working_dir, TRAIN_PREPARED_CSV_PATH)
cfg.data.test_csv_path = os.path.join(working_dir, TEST_PREPARED_CSV_PATH)

# Short aliases for the config sections.
cfg_train, cfg_data, cfg_model = cfg.train, cfg.data, cfg.model

# Generate train_prepared.csv (with folds) and test_prepared.csv.
prepare_data(COMPETITION_DATA_DIR, "train", cfg_data.n_split, cfg.seed)
prepare_data(COMPETITION_DATA_DIR, "test", cfg_data.n_split, cfg.seed)

In [None]:
# Train one model per cross-validation fold. `trainer` is rebound on every iteration,
# so after the loop it refers to the trainer of the last fold only.
for fold in range(cfg_data.n_split):
    trainer = train(cfg, fold, cfg_data.train_csv_path, cfg_data.test_csv_path)

In [None]:
import monai
from utils import mask2rle
@torch.no_grad()
def create_pred_df(module, dataloader, threshold):
    """Run ``module`` over ``dataloader`` and collect one RLE-encoded mask per batch.

    Only the first element of each batch is used (id, height, width and model output
    are all indexed with ``[0]``), so this presumably expects batch size 1 -- confirm
    against the dataloader.

    Args:
        module: Callable model exposing a ``device`` attribute.
        dataloader: Iterable of dict batches with the keys "id", "img_height",
            "img_width" and "image".
        threshold: Cut-off applied after the sigmoid to binarize the prediction.

    Returns:
        A DataFrame with the columns "id" and "rle".
    """
    sample_ids = []
    encoded_masks = []

    for batch in tqdm(dataloader):
        sample_id = batch["id"].numpy()[0]
        orig_height = batch["img_height"].numpy()[0]
        orig_width = batch["img_width"].numpy()[0]

        logits = module(batch["image"].to(module.device))[0]

        # The target size differs per sample, so the pipeline is rebuilt every time:
        # resize to (height, width), apply sigmoid, then threshold.
        resize_and_binarize = monai.transforms.Compose(
            [
                monai.transforms.Resize(spatial_size=(orig_height, orig_width), mode="nearest"),
                monai.transforms.Activations(sigmoid=True),
                monai.transforms.AsDiscrete(threshold=threshold),
            ]
        )

        binarized = resize_and_binarize(logits)
        # The trailing [0] selects the first entry of the leading axis
        # (presumably the single channel -- TODO confirm).
        mask = binarized.to(torch.uint8).cpu().detach().numpy()[0]

        sample_ids.append(sample_id)
        encoded_masks.append(mask2rle(mask))

    return pd.DataFrame({"id": sample_ids, "rle": encoded_masks})


def infer(
    checkpoint_path: str,
    spatial_size: int,
    num_workers: int,
    device: str = 'cuda',
    train_csv_path: str = TRAIN_PREPARED_CSV_PATH,
    test_csv_path: str = TEST_PREPARED_CSV_PATH,
    threshold: float = 0.5,
    val_fold: int = 0,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Predict RLE masks for the validation and test sets from a saved checkpoint.

    Args:
        checkpoint_path: Checkpoint passed to ``LitModule.load_eval_checkpoint``.
        spatial_size: Spatial size forwarded to the data module.
        num_workers: Number of dataloader workers.
        device: Device passed to ``load_eval_checkpoint``.
        train_csv_path: Prepared training CSV (source of the validation split).
        test_csv_path: Prepared test CSV.
        threshold: Cut-off used to binarize predictions.
        val_fold: Fold used as the validation split. Defaults to 0, the value that
            was previously hard-coded; pass the fold the checkpoint was trained on
            so validation predictions are made on its held-out data.

    Returns:
        ``(val_pred_df, test_pred_df)``, each with the columns "id" and "rle".
    """
    # NOTE: the module is built from the notebook-level `cfg`, not from an argument.
    module = LitModule(cfg).load_eval_checkpoint(checkpoint_path, device)

    # batch_size=1 because create_pred_df only reads the first element of each batch.
    data_module = LitDataModule(
        train_csv_path=train_csv_path,
        test_csv_path=test_csv_path,
        spatial_size=spatial_size,
        val_fold=val_fold,
        batch_size=1,
        num_workers=num_workers,
    )
    data_module.setup()

    val_dataloader = data_module.val_dataloader()
    test_dataloader = data_module.test_dataloader()

    val_pred_df = create_pred_df(module, val_dataloader, threshold)
    test_pred_df = create_pred_df(module, test_dataloader, threshold)

    return val_pred_df, test_pred_df