In [6]:
import json
import pandas as pd
import numpy as np
import numexpr as ne
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
import wandb
from dotenv import load_dotenv
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.cuda.amp as amp

from src.model.model import save_model, load_inference_model
from src.dataset.df import check_dataset
from src.utils.common import set_seed

from src.experiment.inference import inference
from src.experiment.initialize import init_exp
from src.utils.metrics import compute_surface_dice_score_from_volume
from src.model.model import build_model
from src.model.loss import get_lossfn
from src.utils.metrics import get_metrics
from src.model.scheduler import get_scheduler
from src.dataset.common import get_train_dataset

In [8]:
class cfg:
    debug = False
    check_dataset = False

    # = data CFG ====================================================

    dataset_path = "/kaggle/working/dataset/train01_xy_256_128_z_1_2/"
    train_dataset = "Base2dDataset"
    negative_sample_rate = 0.2

    # = experiment CFG =================================================

    project = "SenNet"
    exp_name = os.path.basename(os.getcwd())
    notes = "preliminary model"

    # = model CFG ======================================================

    model_arch = "Unet"
    backbone = "se_resnext50_32x4d"
    in_chans = 1
    target_size = 1

    # = training CFG ===================================================

    epochs = 60

    train_batch_size = 64
    valid_batch_size = train_batch_size

    loss = "DiceLoss"
    metrics = "Dice"
    lr = 5e-4
    num_workers = 12

    # = augmentation ===================================================

    image_size = 256
    train_aug = [
        A.RandomRotate90(p=0.5),
        A.RandomGamma(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.ShiftScaleRotate(p=0.5),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        ToTensorV2(transpose_mask=True),
    ]

    valid_aug = [
        ToTensorV2(transpose_mask=True),
    ]

    # =============== inference ========================================

    test_dataset = "BaseInferenceDataset"
    stride = image_size // 2
    drop_egde_pixel = 32


load_dotenv("/kaggle/key.env")
set_seed()

In [9]:
def filter_dataset(df):
    # trainのうちlabelが全くないものは90%の確率で除外
    df["random"] = np.random.rand(len(df))
    df = df[(df["sum"] > 0) | (df["fold0"] == "valid") | (df["random"] < cfg.negative_sample_rate)]
    df = df.reset_index(drop=True)
    df = df.drop(["random"], axis=1)
    return df


df = pd.read_csv(f"{cfg.dataset_path}/dataset.csv")
df = filter_dataset(df)
# df = df.sample(1000).reset_index(drop=True)
if cfg.debug:
    df = df.sample(10000).reset_index(drop=True)
display(df)

if cfg.check_dataset:
    check_dataset(df, cfg)

  df = pd.read_csv(f"{cfg.dataset_path}/dataset.csv")


Unnamed: 0,image_path,label_path,fname,kidney,x,y,z,std,sum,fold0,fold1,fold2
0,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x0_y0_z0_std0039_sum0,kidney_1_dense,0,0,0,39,0,valid,train,train
1,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x0_y0_z100_std0037_sum0,kidney_1_dense,0,0,100,37,0,valid,train,train
2,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x0_y0_z102_std0036_sum0,kidney_1_dense,0,0,102,36,0,valid,train,train
3,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x0_y0_z104_std0036_sum0,kidney_1_dense,0,0,104,36,0,valid,train,train
4,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x0_y0_z106_std0036_sum0,kidney_1_dense,0,0,106,36,0,valid,train,train
...,...,...,...,...,...,...,...,...,...,...,...,...
481727,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x896_y256_z990_std0261_sum1352,kidney_3_sparse,896,256,990,261,1352,train,,
481728,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x896_y256_z992_std0262_sum1473,kidney_3_sparse,896,256,992,262,1473,train,,
481729,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x896_y256_z994_std0263_sum1554,kidney_3_sparse,896,256,994,263,1554,train,,
481730,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x896_y256_z996_std0264_sum1656,kidney_3_sparse,896,256,996,264,1656,train,,


In [4]:
if cfg.debug:
    print("!!!Debug mode!!!\n")
    cfg.epochs = 5

for fold in [2]:
    dataset = get_train_dataset(cfg)

    train_df = df[df[f"fold{fold}"] == "train"]
    valid_df = df[df[f"fold{fold}"] == "valid"]

    train_dataset = dataset(train_df, cfg, is_train=True)
    valid_dataset = dataset(valid_df, cfg, is_train=False)

    train_dataloader = DataLoader(train_dataset, batch_size=cfg.train_batch_size, num_workers=cfg.num_workers, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.valid_batch_size, num_workers=cfg.num_workers, shuffle=False)

    model = build_model(cfg.model_arch, cfg.backbone, cfg.in_chans, cfg.target_size)
    scaler = torch.cuda.amp.GradScaler()
    criterion = get_lossfn(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    scheduler = get_scheduler(cfg, optimizer)
    metrics = get_metrics(cfg)

    # model, scaler, criterion, optimizer, scheduler, metrics = init_model(cfg)
    slacknotify = init_exp(fold, cfg)

    path_best = f"./{cfg.exp_name}/{cfg.exp_name}_best_fold{fold}.pth"
    path_last = f"./{cfg.exp_name}/{cfg.exp_name}_last_fold{fold}.pth"

    best_loss = float("inf")
    for epoch in range(cfg.epochs):
        model.train()
        total_loss = 0.0
        pbar_train = tqdm(enumerate(train_dataloader), total=len(train_dataloader), bar_format="{l_bar}{bar:10}{r_bar}{bar:-0b}")

        for i, (images, masks) in pbar_train:
            images, masks = images.cuda(), masks.cuda()
            optimizer.zero_grad()

            with amp.autocast():
                preds = model(images)
                loss = criterion(preds, masks)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                total_loss += loss.detach().item()

            loss_ = total_loss / (i + 1)
            lr = f"LR : {scheduler.get_lr()[0]:.2E}"
            gpu_mem = f"Mem : {torch.cuda.memory_reserved() / 1E9:.3g}GB"
            pbar_train.set_description(("%10s  " * 3 + "%10s") % (f"Epoch {epoch}/{cfg.epochs}", gpu_mem, lr, f"Loss: {loss_:.4f}"))

        train_loss = loss_
        scheduler.step()
        wandb.log({"epoch": epoch, "train_loss": train_loss})

        model.eval()
        total_loss = 0.0
        pbar_val = tqdm(enumerate(valid_dataloader), total=len(valid_dataloader), bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}")

        for i, (images, masks) in pbar_val:
            images, masks = images.cuda(), masks.cuda()
            with torch.no_grad():
                preds = model(images)
                loss = criterion(preds, masks)
                total_loss += loss.item()

            loss_ = total_loss / (i + 1)
            pbar_val.set_description(("%10s") % (f"Val Loss: {loss_:.4f}"))
        valid_loss = loss_
        wandb.log({"epoch": epoch, "valid_loss": valid_loss})

        if valid_loss < best_loss:
            print(f"loss : {valid_loss:.4f}\tSAVED MODEL\n")
            slacknotify.send_reply(f"epoch : {epoch}\tscore : {valid_loss:.4f}\tBEST")
            best_loss = valid_loss
            save_model(model, cfg, path_best, loss=loss)
        else:
            print(f"loss : {valid_loss:.4f}\n")
            slacknotify.send_reply(f"epoch : {epoch}\tscore : {valid_loss:.4f}")

    save_model(model, cfg, path_last, loss=valid_loss)
    wandb.config.update({"last_loss": valid_loss, "best_loss": best_loss})

    slacknotify.send_reply(f"{cfg.exp_name}_fold{fold} training finished\nbest loss : {best_loss:.4f} last loss : {loss_:.4f}", True)

    if wandb.run:
        wandb.finish()

model_arch:  Unet
backbone:  se_resnext50_32x4d


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwelshonionman[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0/60  Mem : 8.95GB  LR : 5.00E-04  Loss: 0.1188: 100%|██████████| 4352/4352 [13:36<00:00,  5.33it/s]
Val Loss: 0.2281: 100%|██████████| 2503/2503 [04:37<00:00,  9.03it/s]


loss : 0.2281	SAVED MODEL



Epoch 1/60  Mem : 9.95GB  LR : 5.00E-03  Loss: 0.1245: 100%|██████████| 4352/4352 [13:13<00:00,  5.48it/s]
Val Loss: 0.2013: 100%|██████████| 2503/2503 [04:34<00:00,  9.11it/s]


loss : 0.2013	SAVED MODEL



Epoch 2/60  Mem : 9.95GB  LR : 5.00E-04  Loss: 0.0789: 100%|██████████| 4352/4352 [13:14<00:00,  5.48it/s]
Val Loss: 0.1652: 100%|██████████| 2503/2503 [04:35<00:00,  9.10it/s]


loss : 0.1652	SAVED MODEL



Epoch 3/60  Mem : 9.95GB  LR : 5.00E-04  Loss: 0.0735: 100%|██████████| 4352/4352 [13:15<00:00,  5.47it/s]
Val Loss: 0.1678: 100%|██████████| 2503/2503 [04:35<00:00,  9.08it/s]


loss : 0.1678



Epoch 4/60  Mem : 9.95GB  LR : 4.99E-04  Loss: 0.0706: 100%|██████████| 4352/4352 [13:17<00:00,  5.46it/s]
Val Loss: 0.1621: 100%|██████████| 2503/2503 [04:36<00:00,  9.04it/s]


loss : 0.1621	SAVED MODEL



Epoch 5/60  Mem : 9.95GB  LR : 4.97E-04  Loss: 0.0692: 100%|██████████| 4352/4352 [13:13<00:00,  5.49it/s]
Val Loss: 0.1737: 100%|██████████| 2503/2503 [04:36<00:00,  9.06it/s]


loss : 0.1737



Epoch 6/60  Mem : 9.95GB  LR : 4.95E-04  Loss: 0.0678: 100%|██████████| 4352/4352 [13:15<00:00,  5.47it/s]
Val Loss: 0.1574: 100%|██████████| 2503/2503 [04:34<00:00,  9.10it/s]


loss : 0.1574	SAVED MODEL



Epoch 7/60  Mem : 9.95GB  LR : 4.91E-04  Loss: 0.0673: 100%|██████████| 4352/4352 [13:16<00:00,  5.46it/s]
Val Loss: 0.1561: 100%|██████████| 2503/2503 [04:35<00:00,  9.10it/s]


loss : 0.1561	SAVED MODEL



Epoch 8/60  Mem : 9.95GB  LR : 4.88E-04  Loss: 0.0666: 100%|██████████| 4352/4352 [13:32<00:00,  5.36it/s]
Val Loss: 0.1633: 100%|██████████| 2503/2503 [04:33<00:00,  9.15it/s]


loss : 0.1633



Epoch 9/60  Mem : 9.95GB  LR : 4.83E-04  Loss: 0.0657: 100%|██████████| 4352/4352 [13:11<00:00,  5.50it/s]
Val Loss: 0.1672: 100%|██████████| 2503/2503 [04:33<00:00,  9.16it/s]


loss : 0.1672



Epoch 10/60  Mem : 9.95GB  LR : 4.78E-04  Loss: 0.0654: 100%|██████████| 4352/4352 [13:10<00:00,  5.50it/s]
Val Loss: 0.1621: 100%|██████████| 2503/2503 [04:50<00:00,  8.63it/s]


loss : 0.1621



Epoch 11/60  Mem : 9.95GB  LR : 4.73E-04  Loss: 0.0648: 100%|██████████| 4352/4352 [14:11<00:00,  5.11it/s]
Val Loss: 0.1610: 100%|██████████| 2503/2503 [05:08<00:00,  8.11it/s]


loss : 0.1610



Epoch 12/60  Mem : 9.95GB  LR : 4.67E-04  Loss: 0.0639: 100%|██████████| 4352/4352 [14:13<00:00,  5.10it/s]
Val Loss: 0.1710: 100%|██████████| 2503/2503 [04:29<00:00,  9.27it/s]


loss : 0.1710



Epoch 13/60  Mem : 9.95GB  LR : 4.60E-04  Loss: 0.0639: 100%|██████████| 4352/4352 [12:59<00:00,  5.58it/s]
Val Loss: 0.1710: 100%|██████████| 2503/2503 [04:26<00:00,  9.40it/s]


loss : 0.1710



Epoch 14/60  Mem : 9.95GB  LR : 4.52E-04  Loss: 0.0631: 100%|██████████| 4352/4352 [12:15<00:00,  5.92it/s]
Val Loss: 0.1508: 100%|██████████| 2503/2503 [04:09<00:00, 10.05it/s]


loss : 0.1508	SAVED MODEL



Epoch 15/60  Mem : 9.95GB  LR : 4.44E-04  Loss: 0.0628: 100%|██████████| 4352/4352 [12:11<00:00,  5.95it/s]
Val Loss: 0.1557: 100%|██████████| 2503/2503 [04:09<00:00, 10.05it/s]


loss : 0.1557



Epoch 16/60  Mem : 9.95GB  LR : 4.36E-04  Loss: 0.0625: 100%|██████████| 4352/4352 [12:11<00:00,  5.95it/s]
Val Loss: 0.1652: 100%|██████████| 2503/2503 [04:09<00:00, 10.05it/s]


loss : 0.1652



Epoch 17/60  Mem : 9.95GB  LR : 4.27E-04  Loss: 0.0622: 100%|██████████| 4352/4352 [12:08<00:00,  5.97it/s]
Val Loss: 0.1563: 100%|██████████| 2503/2503 [04:08<00:00, 10.06it/s]


loss : 0.1563



Epoch 18/60  Mem : 9.95GB  LR : 4.17E-04  Loss: 0.0623: 100%|██████████| 4352/4352 [12:10<00:00,  5.96it/s]
Val Loss: 0.1656: 100%|██████████| 2503/2503 [04:08<00:00, 10.06it/s]


loss : 0.1656



Epoch 19/60  Mem : 9.95GB  LR : 4.07E-04  Loss: 0.0616: 100%|██████████| 4352/4352 [12:06<00:00,  5.99it/s]
Val Loss: 0.1688: 100%|██████████| 2503/2503 [04:09<00:00, 10.04it/s]


loss : 0.1688



Epoch 20/60  Mem : 9.95GB  LR : 3.97E-04  Loss: 0.0613: 100%|██████████| 4352/4352 [12:11<00:00,  5.95it/s]
Val Loss: 0.1793: 100%|██████████| 2503/2503 [04:09<00:00, 10.05it/s]


loss : 0.1793



Epoch 21/60  Mem : 9.95GB  LR : 3.86E-04  Loss: 0.0610: 100%|██████████| 4352/4352 [12:04<00:00,  6.00it/s]
Val Loss: 0.1796: 100%|██████████| 2503/2503 [04:09<00:00, 10.05it/s]


loss : 0.1796



Epoch 22/60  Mem : 9.95GB  LR : 3.75E-04  Loss: 0.0607: 100%|██████████| 4352/4352 [12:08<00:00,  5.97it/s]
Val Loss: 0.1539: 100%|██████████| 2503/2503 [04:08<00:00, 10.06it/s]


loss : 0.1539



Epoch 23/60  Mem : 9.95GB  LR : 3.64E-04  Loss: 0.0602: 100%|██████████| 4352/4352 [12:11<00:00,  5.95it/s]
Val Loss: 0.1630: 100%|██████████| 2503/2503 [04:08<00:00, 10.05it/s]


loss : 0.1630



Epoch 24/60  Mem : 9.95GB  LR : 3.52E-04  Loss: 0.0601: 100%|██████████| 4352/4352 [12:10<00:00,  5.96it/s]
Val Loss: 0.1663: 100%|██████████| 2503/2503 [04:09<00:00, 10.05it/s]


loss : 0.1663



Epoch 25/60  Mem : 9.95GB  LR : 3.40E-04  Loss: 0.0600: 100%|██████████| 4352/4352 [12:05<00:00,  6.00it/s]
Val Loss: 0.1625: 100%|██████████| 2503/2503 [04:09<00:00, 10.04it/s]


loss : 0.1625



Epoch 26/60  Mem : 9.95GB  LR : 3.27E-04  Loss: 0.0596: 100%|██████████| 4352/4352 [12:16<00:00,  5.91it/s]
Val Loss: 0.1675: 100%|██████████| 2503/2503 [04:08<00:00, 10.05it/s]


loss : 0.1675



Epoch 27/60  Mem : 9.95GB  LR : 3.15E-04  Loss: 0.0590: 100%|██████████| 4352/4352 [12:16<00:00,  5.91it/s]
Val Loss: 0.1773: 100%|██████████| 2503/2503 [04:09<00:00, 10.04it/s]


loss : 0.1773



Epoch 28/60  Mem : 9.95GB  LR : 3.02E-04  Loss: 0.0590: 100%|██████████| 4352/4352 [12:14<00:00,  5.92it/s]
Val Loss: 0.1924: 100%|██████████| 2503/2503 [04:08<00:00, 10.06it/s]


loss : 0.1924



Epoch 29/60  Mem : 9.95GB  LR : 2.89E-04  Loss: 0.0588: 100%|██████████| 4352/4352 [12:15<00:00,  5.92it/s]
Val Loss: 0.1580: 100%|██████████| 2503/2503 [04:08<00:00, 10.07it/s]


loss : 0.1580



Epoch 30/60  Mem : 9.95GB  LR : 2.76E-04  Loss: 0.0584: 100%|██████████| 4352/4352 [12:17<00:00,  5.90it/s]
Val Loss: 0.1633: 100%|██████████| 2503/2503 [04:09<00:00, 10.04it/s]


loss : 0.1633



Epoch 31/60  Mem : 9.95GB  LR : 2.63E-04  Loss: 0.0584: 100%|██████████| 4352/4352 [12:11<00:00,  5.95it/s]
Val Loss: 0.1743: 100%|██████████| 2503/2503 [04:09<00:00, 10.05it/s]


loss : 0.1743



Epoch 32/60  Mem : 9.95GB  LR : 2.50E-04  Loss: 0.0578: 100%|██████████| 4352/4352 [12:20<00:00,  5.88it/s]
Val Loss: 0.1593: 100%|██████████| 2503/2503 [04:09<00:00, 10.05it/s]


loss : 0.1593



Epoch 33/60  Mem : 9.95GB  LR : 2.37E-04  Loss: 0.0579: 100%|██████████| 4352/4352 [12:11<00:00,  5.95it/s]
Val Loss: 0.1516: 100%|██████████| 2503/2503 [04:09<00:00, 10.04it/s]


loss : 0.1516



Epoch 34/60  Mem : 9.95GB  LR : 2.24E-04  Loss: 0.0573: 100%|██████████| 4352/4352 [12:20<00:00,  5.87it/s]
Val Loss: 0.1656: 100%|██████████| 2503/2503 [04:05<00:00, 10.18it/s]


loss : 0.1656



Epoch 35/60  Mem : 9.95GB  LR : 2.11E-04  Loss: 0.0570: 100%|██████████| 4352/4352 [12:20<00:00,  5.88it/s]
Val Loss: 0.1721: 100%|██████████| 2503/2503 [04:06<00:00, 10.17it/s]


loss : 0.1721



Epoch 36/60  Mem : 9.95GB  LR : 1.98E-04  Loss: 0.0566: 100%|██████████| 4352/4352 [12:23<00:00,  5.85it/s]
Val Loss: 0.1545: 100%|██████████| 2503/2503 [04:06<00:00, 10.16it/s]


loss : 0.1545



Epoch 37/60  Mem : 9.95GB  LR : 1.85E-04  Loss: 0.0565: 100%|██████████| 4352/4352 [12:20<00:00,  5.88it/s]
Val Loss: 0.1575: 100%|██████████| 2503/2503 [04:06<00:00, 10.16it/s]


loss : 0.1575



Epoch 38/60  Mem : 9.95GB  LR : 1.73E-04  Loss: 0.0560: 100%|██████████| 4352/4352 [12:18<00:00,  5.89it/s]
Val Loss: 0.1733: 100%|██████████| 2503/2503 [04:06<00:00, 10.15it/s]


loss : 0.1733



Epoch 39/60  Mem : 9.95GB  LR : 1.60E-04  Loss: 0.0560: 100%|██████████| 4352/4352 [12:23<00:00,  5.86it/s]
Val Loss: 0.1682: 100%|██████████| 2503/2503 [04:06<00:00, 10.17it/s]


loss : 0.1682



Epoch 40/60  Mem : 9.95GB  LR : 1.48E-04  Loss: 0.0557: 100%|██████████| 4352/4352 [12:21<00:00,  5.87it/s]
Val Loss: 0.1723: 100%|██████████| 2503/2503 [04:05<00:00, 10.18it/s]


loss : 0.1723



Epoch 41/60  Mem : 9.95GB  LR : 1.37E-04  Loss: 0.0555: 100%|██████████| 4352/4352 [12:26<00:00,  5.83it/s]
Val Loss: 0.1784: 100%|██████████| 2503/2503 [04:06<00:00, 10.17it/s]


loss : 0.1784



Epoch 42/60  Mem : 9.95GB  LR : 1.25E-04  Loss: 0.0553: 100%|██████████| 4352/4352 [12:28<00:00,  5.81it/s]
Val Loss: 0.1732: 100%|██████████| 2503/2503 [04:06<00:00, 10.17it/s]


loss : 0.1732



Epoch 43/60  Mem : 9.95GB  LR : 1.14E-04  Loss: 0.0553: 100%|██████████| 4352/4352 [12:23<00:00,  5.85it/s]
Val Loss: 0.1657: 100%|██████████| 2503/2503 [04:14<00:00,  9.85it/s]


loss : 0.1657



Epoch 44/60  Mem : 9.95GB  LR : 1.03E-04  Loss: 0.0549: 100%|██████████| 4352/4352 [13:38<00:00,  5.32it/s]
Val Loss: 0.1685: 100%|██████████| 2503/2503 [04:14<00:00,  9.82it/s]


loss : 0.1685



Epoch 45/60  Mem : 9.95GB  LR : 9.28E-05  Loss: 0.0549: 100%|██████████| 4352/4352 [13:25<00:00,  5.40it/s]
Val Loss: 0.1617: 100%|██████████| 2503/2503 [04:14<00:00,  9.83it/s]


loss : 0.1617



Epoch 46/60  Mem : 9.95GB  LR : 8.28E-05  Loss: 0.0546: 100%|██████████| 4352/4352 [13:28<00:00,  5.38it/s]
Val Loss: 0.1668: 100%|██████████| 2503/2503 [04:17<00:00,  9.74it/s]


loss : 0.1668



Epoch 47/60  Mem : 9.95GB  LR : 7.33E-05  Loss: 0.0541: 100%|██████████| 4352/4352 [12:27<00:00,  5.83it/s]
Val Loss: 0.1788: 100%|██████████| 2503/2503 [04:07<00:00, 10.10it/s]


loss : 0.1788



Epoch 48/60  Mem : 9.95GB  LR : 6.43E-05  Loss: 0.0541: 100%|██████████| 4352/4352 [12:11<00:00,  5.95it/s]
Val Loss: 0.1718: 100%|██████████| 2503/2503 [04:07<00:00, 10.09it/s]


loss : 0.1718



Epoch 49/60  Mem : 9.95GB  LR : 5.58E-05  Loss: 0.0541: 100%|██████████| 4352/4352 [12:11<00:00,  5.95it/s]
Val Loss: 0.1732: 100%|██████████| 2503/2503 [04:07<00:00, 10.10it/s]


loss : 0.1732



Epoch 50/60  Mem : 9.95GB  LR : 4.78E-05  Loss: 0.0539: 100%|██████████| 4352/4352 [12:05<00:00,  6.00it/s]
Val Loss: 0.1755: 100%|██████████| 2503/2503 [04:08<00:00, 10.09it/s]


loss : 0.1755



Epoch 51/60  Mem : 9.95GB  LR : 4.04E-05  Loss: 0.0538: 100%|██████████| 4352/4352 [12:28<00:00,  5.81it/s]
Val Loss: 0.1776: 100%|██████████| 2503/2503 [04:08<00:00, 10.09it/s]


loss : 0.1776



Epoch 52/60  Mem : 9.95GB  LR : 3.36E-05  Loss: 0.0536: 100%|██████████| 4352/4352 [12:16<00:00,  5.91it/s]
Val Loss: 0.1832: 100%|██████████| 2503/2503 [04:20<00:00,  9.61it/s]


loss : 0.1832



Epoch 53/60  Mem : 9.95GB  LR : 2.73E-05  Loss: 0.0536: 100%|██████████| 4352/4352 [13:28<00:00,  5.38it/s]
Val Loss: 0.1747: 100%|██████████| 2503/2503 [04:13<00:00,  9.86it/s]


loss : 0.1747



Epoch 54/60  Mem : 9.95GB  LR : 2.17E-05  Loss: 0.0533: 100%|██████████| 4352/4352 [13:22<00:00,  5.43it/s]
Val Loss: 0.1799: 100%|██████████| 2503/2503 [04:14<00:00,  9.85it/s]


loss : 0.1799



Epoch 55/60  Mem : 9.95GB  LR : 1.67E-05  Loss: 0.0532: 100%|██████████| 4352/4352 [13:39<00:00,  5.31it/s]
Val Loss: 0.1793: 100%|██████████| 2503/2503 [04:26<00:00,  9.38it/s]


loss : 0.1793



Epoch 56/60  Mem : 9.95GB  LR : 1.23E-05  Loss: 0.0533: 100%|██████████| 4352/4352 [13:49<00:00,  5.25it/s]
Val Loss: 0.1764: 100%|██████████| 2503/2503 [04:26<00:00,  9.40it/s]


loss : 0.1764



Epoch 57/60  Mem : 9.95GB  LR : 8.62E-06  Loss: 0.0530: 100%|██████████| 4352/4352 [13:19<00:00,  5.45it/s]
Val Loss: 0.1815: 100%|██████████| 2503/2503 [04:06<00:00, 10.14it/s]


loss : 0.1815



Epoch 58/60  Mem : 9.95GB  LR : 5.56E-06  Loss: 0.0530: 100%|██████████| 4352/4352 [13:10<00:00,  5.50it/s]
Val Loss: 0.1788: 100%|██████████| 2503/2503 [04:06<00:00, 10.13it/s]


loss : 0.1788



Epoch 59/60  Mem : 9.95GB  LR : 3.18E-06  Loss: 0.0530: 100%|██████████| 4352/4352 [13:11<00:00,  5.50it/s]
Val Loss: 0.1761: 100%|██████████| 2503/2503 [04:07<00:00, 10.13it/s]


loss : 0.1761



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▇█▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss,█▆▂▂▂▁▂▂▃▃▁▂▂▃▄▁▂▂▃▅▂▃▁▂▁▂▃▃▃▂▂▂▃▃▃▄▄▄▄▃

0,1
epoch,59.0
train_loss,0.05297
valid_loss,0.17611


In [4]:
# valid
fold_dict = json.load(open("/kaggle/src/dataset/fold.json", "r"))
pth_type = "best"

for fold in [2]:
    kidney = fold_dict[f"fold{fold}"]["valid"][0]
    pth_path = f"./{cfg.exp_name}/{cfg.exp_name}_{pth_type}_fold{fold}.pth"
    stack_path = f"/kaggle/working/dataset/stack_train01/{kidney}_images.npy"
    label_path = f"/kaggle/working/dataset/stack_train01/{kidney}_labels.npy"
    preds_path = f"./preds/{kidney}_{pth_type}_preds.npy"

    model = load_inference_model(pth_path, cfg)

    print(kidney)
    print(stack_path)
    # inference(model, stack_path, preds_path, cfg)

    label = np.load(label_path)
    preds = np.load(preds_path)
    thresh_score_dict = {}

    for thresh in np.arange(0.1, 0.99, 0.1):
        thresh = round(thresh, 5)
        thresh_score_dict[thresh] = compute_surface_dice_score_from_volume(ne.evaluate("preds > thresh"), label)

    max_score_thresh = max(thresh_score_dict, key=thresh_score_dict.get)

    for thresh in np.arange(max_score_thresh - 0.1, max_score_thresh + 0.1, 0.01):
        thresh = round(thresh, 5)
        thresh_score_dict[thresh] = compute_surface_dice_score_from_volume(ne.evaluate("preds > thresh"), label)

    print(max(thresh_score_dict.items(), key=lambda x: x[1]))

model_name Unet
backbone se_resnext50_32x4d
fold_2
kidney_2_sparse
/kaggle/working/dataset/stack_clipped/kidney_2_sparse_images.npy
(0.67, 0.8495)


In [3]:
# infer
pth_type = "best"
fold = 2
kidneys = ["kidney_2_sparse", "kidney_3_sparse", "kidney_9_pseudo"]

pth_path = f"./{cfg.exp_name}/{cfg.exp_name}_{pth_type}_fold{fold}.pth"
for kidney in kidneys:
    stack_path = f"/kaggle/working/dataset/stack_train01/{kidney}_images.npy"
    preds_path = f"./preds/{kidney}_{pth_type}_preds.npy"

    model = load_inference_model(pth_path, cfg)

    print(stack_path)
    inference(model, stack_path, preds_path, cfg)

model_name Unet
backbone se_resnext50_32x4d
/kaggle/working/dataset/stack_clipped/kidney_2_sparse_images.npy
././preds/kidney_2_sparse_preds_0.npy


100%|██████████| 2216/2216 [10:13<00:00,  3.61it/s]


././preds/kidney_2_sparse_preds_1.npy


100%|██████████| 1040/1040 [09:36<00:00,  1.80it/s]


././preds/kidney_2_sparse_preds_2.npy


100%|██████████| 1510/1510 [10:44<00:00,  2.34it/s]


model_name Unet
backbone se_resnext50_32x4d
/kaggle/working/dataset/stack_clipped/kidney_3_sparse_images.npy
././preds/kidney_3_sparse_preds_0.npy


100%|██████████| 495/495 [03:35<00:00,  2.30it/s]


././preds/kidney_3_sparse_preds_1.npy


100%|██████████| 1705/1705 [03:35<00:00,  7.93it/s]


././preds/kidney_3_sparse_preds_2.npy


100%|██████████| 1509/1509 [03:39<00:00,  6.87it/s]


model_name Unet
backbone se_resnext50_32x4d
/kaggle/working/dataset/stack_clipped/kidney_9_pseudo_images.npy
././preds/kidney_9_pseudo_preds_0.npy


100%|██████████| 2140/2140 [10:50<00:00,  3.29it/s]


././preds/kidney_9_pseudo_preds_1.npy


100%|██████████| 1107/1107 [10:19<00:00,  1.79it/s]


././preds/kidney_9_pseudo_preds_2.npy


100%|██████████| 1643/1643 [10:39<00:00,  2.57it/s]


In [15]:
pth_type = "best"
kidneys = ["kidney_2_sparse", "kidney_3_sparse", "kidney_9_pseudo"]
for kidney in kidneys:
    stack_path = f"/kaggle/working/dataset/stack_train01/{kidney}_images.npy"
    preds_path = f"./preds/{kidney}_{pth_type}_preds.npy"
    preds_th_path = f"./preds/{kidney}_{pth_type}_preds_th.npy"
    preds = np.load(preds_path)
    preds_th = (preds > 0.67).astype(np.bool_)
    np.save(preds_th_path, preds > 0.67)