In [None]:
import sys, os
from pathlib import Path
# Project layout, resolved relative to the notebook's working directory:
# ROOT_DIR is the parent of the current directory and BASE_DIR is the
# "pytorch-lightning" folder inside it.
ROOT_DIR = Path(os.path.abspath(os.path.join(os.getcwd(), "..")))
BASE_DIR = ROOT_DIR / "pytorch-lightning"

# sys.path entries must be plain strings: the import system ignores any other
# type, so appending Path objects would silently leave the project modules
# unimportable. Skipping entries that are already present keeps sys.path from
# growing when this cell is re-run.
sys.path.extend(
    entry for entry in (str(ROOT_DIR), str(BASE_DIR)) if entry not in sys.path
)

In [None]:
import monai
import torch
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from tqdm.notebook import tqdm

from data import LitDataModule
from model import LitModule
from utils import mask2rle

In [None]:

@torch.no_grad()
def create_pred_df(module, dataloader, threshold):
    """Run ``module`` over ``dataloader`` and collect one RLE mask per sample.

    Every sample of every batch is decoded, so the result is complete for any
    dataloader batch size (not only batch size 1).

    Args:
        module: Callable model exposing a ``device`` attribute. It is called
            with the batch images; the result is assumed to be a batch-first
            tensor of per-pixel logits, channel-first per sample
            -- TODO confirm against ``LitModule``.
        dataloader: Iterable of dict batches providing the tensors ``"id"``,
            ``"img_height"``, ``"img_width"`` and ``"image"``.
        threshold: Cut-off passed to ``AsDiscrete`` after the sigmoid.

    Returns:
        ``pd.DataFrame`` with the columns ``"id"`` and ``"rle"``, one row per
        sample in iteration order.
    """
    ids = []
    rles = []
    for batch in tqdm(dataloader):
        batch_ids = batch["id"].numpy()
        heights = batch["img_height"].numpy()
        widths = batch["img_width"].numpy()

        images = batch["image"].to(module.device)
        outputs = module(images)

        # Iterating the output yields one prediction per sample along dim 0.
        for id_, height, width, output in zip(batch_ids, heights, widths, outputs):
            # The target size differs per sample, so the post-processing
            # pipeline is rebuilt for each one: resize to the size recorded in
            # the batch, then sigmoid, then binarise at ``threshold``.
            post_pred_transform = monai.transforms.Compose(
                [
                    monai.transforms.Resize(
                        spatial_size=(int(height), int(width)), mode="nearest"
                    ),
                    monai.transforms.Activations(sigmoid=True),
                    monai.transforms.AsDiscrete(threshold=threshold),
                ]
            )

            # Index 0 keeps the first channel of the binarised prediction.
            mask = post_pred_transform(output).to(torch.uint8).cpu().detach().numpy()[0]

            ids.append(id_)
            rles.append(mask2rle(mask))

    return pd.DataFrame({"id": ids, "rle": rles})

In [None]:
def infer(
    checkpoint_path,
    cfg,
    train_csv_path: str = None,
    test_csv_path: str = None,
    threshold: float = 0.5,
    fold: int = 0,
    device: str = 'cuda',
):
    """Load a checkpoint, export its weights and predict on the test set.

    Args:
        checkpoint_path: Checkpoint handed to ``LitModule.load_eval_checkpoint``.
        cfg: Configuration object; ``cfg.data.spatial_size``,
            ``cfg.data.batch_size``, ``cfg.data.num_workers`` and
            ``cfg.model.backbone`` are read here, and the whole object is
            passed to ``LitModule``.
        train_csv_path: Forwarded to ``LitDataModule``.
        test_csv_path: Forwarded to ``LitDataModule``.
        threshold: Binarisation threshold forwarded to ``create_pred_df``.
        fold: Passed to ``LitDataModule`` as ``val_fold``.
        device: Device argument forwarded to ``load_eval_checkpoint``.

    Returns:
        The DataFrame produced by ``create_pred_df`` for the test dataloader.

    Side effects:
        Saves the module's ``state_dict`` as a ``.pth`` file in the
        notebook-level ``OUTPUT_DIR``.
    """
    data_module = LitDataModule(
        val_fold=fold,
        train_csv_path=train_csv_path,
        test_csv_path=test_csv_path,
        spatial_size=cfg.data.spatial_size,
        batch_size=cfg.data.batch_size,
        num_workers=cfg.data.num_workers,
    )
    data_module.setup()

    module = LitModule(cfg).load_eval_checkpoint(checkpoint_path, device)

    # torch.save does not create missing parent directories, so make sure the
    # output folder exists before exporting the weights.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    weights_name = f"{module.model.__class__.__name__}_{cfg.model.backbone}.pth"
    torch.save(module.state_dict(), OUTPUT_DIR / weights_name)

    test_dataloader = data_module.test_dataloader()
    return create_pred_df(module, test_dataloader, threshold)

In [None]:
# Directory layout derived from the project roots defined earlier.
INPUT_DIR = BASE_DIR.joinpath("input")
OUTPUT_DIR = ROOT_DIR.joinpath("working")
CONFIG_DIR = BASE_DIR.joinpath("config")

# Folder holding the competition data.
COMPETITION_DATA_DIR = INPUT_DIR.joinpath("hubmap-organ-segmentation")

# Bare file names of the prepared CSVs (no directory component).
TRAIN_PREPARED_CSV_PATH = "train_prepared.csv"
VAL_PRED_PREPARED_CSV_PATH = "val_pred_prepared.csv"
TEST_PREPARED_CSV_PATH = "test_prepared.csv"

# YAML file loaded into the run configuration.
CONFIG_YAML_PATH = CONFIG_DIR.joinpath("default.yaml")

In [None]:
def add_path_to_df(df: pd.DataFrame, data_dir: Path, type_: str, stage: str) -> pd.DataFrame:
    """Add a ``type_`` column of per-row file paths built from the ``id`` column.

    For ``type_ == "image"`` the paths are ``<data_dir>/<stage>_images/<id>.tiff``.
    For any other ``type_`` they are the relative ``<stage>_<type_>s/<id>.npy``
    (``data_dir`` is not used in that case).

    The frame is modified in place and also returned.
    """
    folder = f"{stage}_{type_}s"

    if type_ == "image":
        prefix = str(data_dir / folder)
        suffix = ".tiff"
    else:
        prefix = folder
        suffix = ".npy"

    id_strings = df["id"].astype(str)
    df[type_] = prefix + "/" + id_strings + suffix
    return df


def add_paths_to_df(df: pd.DataFrame, data_dir: Path, stage: str) -> pd.DataFrame:
    """Add both the ``image`` and the ``mask`` path columns to ``df``."""
    # Image column first, then mask, via add_path_to_df.
    for kind in ("image", "mask"):
        df = add_path_to_df(df, data_dir, kind, stage)
    return df


def create_folds(df: pd.DataFrame, n_splits: int, random_seed: int) -> pd.DataFrame:
    """Assign a stratified cross-validation fold to every row of ``df``.

    Rows are stratified on the ``organ`` column. The fold number of the split
    in which a row is part of the validation set is stored in a ``fold``
    column.

    Args:
        df: Frame with an ``organ`` column; modified in place.
        n_splits: Number of folds.
        random_seed: Seed for the shuffled split.

    Returns:
        The same ``df`` with the ``fold`` column set.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_seed)

    # Start from NaN so the column keeps the float dtype it had before.
    folds = pd.Series(float("nan"), index=df.index)
    for fold, (_, val_idx) in enumerate(skf.split(X=df, y=df["organ"])):
        # split() yields positional row indices, so they must be applied with
        # .iloc; label-based .loc is only equivalent for a default RangeIndex.
        folds.iloc[val_idx] = fold

    df["fold"] = folds
    return df


def prepare_data(data_dir: Path, stage: str, n_splits: int, random_seed: int) -> pd.DataFrame:
    """Build and save the prepared CSV for one stage.

    Reads ``<data_dir>/<stage>.csv``, adds the image and mask path columns,
    adds fold assignments when ``stage == "train"``, and writes the result to
    ``<stage>_prepared.csv`` in the current working directory.

    Args:
        data_dir: Directory containing ``<stage>.csv``.
        stage: Stage name, used for the input and output file names; only
            ``"train"`` triggers fold creation.
        n_splits: Number of folds (used for the train stage only).
        random_seed: Seed for the fold split (used for the train stage only).

    Returns:
        The prepared DataFrame that was written to disk.
    """
    df = pd.read_csv(data_dir / f"{stage}.csv")
    df = add_paths_to_df(df, data_dir, stage)

    if stage == "train":
        df = create_folds(df, n_splits, random_seed)

    # Relative name: the file is written to the current working directory.
    filename = f"{stage}_prepared.csv"
    df.to_csv(filename, index=False)

    print(f"Created {filename} with shape {df.shape}")

    return df

In [None]:
from utils import EasyConfig
# Load the YAML configuration.
cfg = EasyConfig()
cfg.load(CONFIG_YAML_PATH)

# Point the config at the prepared CSVs, which prepare_data writes into the
# current working directory.
working_dir = os.getcwd()
cfg.data.train_csv_path = os.path.join(working_dir, TRAIN_PREPARED_CSV_PATH)
cfg.data.test_csv_path = os.path.join(working_dir, TEST_PREPARED_CSV_PATH)

# Short aliases for the config sections.
cfg_train, cfg_data, cfg_model = cfg.train, cfg.data, cfg.model

# Write train_prepared.csv and test_prepared.csv.
prepare_data(COMPETITION_DATA_DIR, "train", cfg_data.n_split, cfg.seed)
prepare_data(COMPETITION_DATA_DIR, "test", cfg_data.n_split, cfg.seed)

In [None]:
# Predict on the test set with the chosen checkpoint, then write the result
# as submission.csv into OUTPUT_DIR.
checkpoint_path = 'tmp/checkpoint/UnetPlusPlus_with_ASPP_FPN_efficientnet-b5_2.ckpt'
test_pred_df = infer(
    checkpoint_path,
    cfg,
    train_csv_path=cfg_data.train_csv_path,
    test_csv_path=cfg_data.test_csv_path,
)

submission_path = os.path.join(OUTPUT_DIR, "submission.csv")
test_pred_df.to_csv(submission_path, index=False)