In [1]:
import json
import pandas as pd
import numpy as np
import numexpr as ne
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
import wandb
from dotenv import load_dotenv
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.cuda.amp as amp

from src.model.model import save_model, load_inference_model
from src.dataset.df import df_dataset, check_dataset
from src.utils.common import set_seed

from src.experiment.inference import inference
from src.experiment.initialize import init_dataset, init_model, init_exp
from src.utils.metrics import compute_surface_dice_score_from_volume
from src.model.model import build_model
from src.model.loss import get_lossfn
from src.utils.metrics import get_metrics
from src.model.scheduler import get_scheduler
from src.dataset.common import get_train_dataset

In [2]:
class cfg:
    debug = False
    check_dataset = False

    # = data CFG ====================================================

    dataset_path = "/kaggle/working/dataset/train01_xy_256_128_z_1_2/"
    train_dataset = "Base2dDataset"
    negative_sample_rate = 0.2

    # = experiment CFG =================================================

    project = "SenNet"
    exp_name = os.path.basename(os.getcwd())
    notes = "baseline_0114 mini scse"

    # = model CFG ======================================================

    model_arch = "Unet"
    backbone = "se_resnext50_32x4d"
    in_chans = 1
    target_size = 1
    decoder_attention_type = "scse"
    # = training CFG ===================================================

    epochs = 20

    train_batch_size = 64
    valid_batch_size = train_batch_size

    loss = "DiceLoss"
    metrics = "Dice"
    lr = 5e-4
    num_workers = 12

    # = augmentation ===================================================

    image_size = 256
    train_aug = [
        A.RandomRotate90(p=0.5),
        A.RandomGamma(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.RandomResizedCrop(image_size, image_size, scale=(0.8, 1)),
        A.ShiftScaleRotate(p=0.5),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        ToTensorV2(transpose_mask=True),
    ]

    valid_aug = [
        ToTensorV2(transpose_mask=True),
    ]

    # =============== inference ========================================

    test_dataset = "BaseInferenceDataset"
    stride = image_size // 2
    drop_egde_pixel = 32


load_dotenv("/kaggle/key.env")
set_seed()

In [3]:
def filter_dataset(df):
    # trainのうちlabelが全くないものは90%の確率で除外
    df["random"] = np.random.rand(len(df))
    df = df[(df["sum"] > 0) | (df["fold0"] == "valid") | (df["random"] < cfg.negative_sample_rate)]
    df = df.reset_index(drop=True)
    df = df.drop(["random"], axis=1)
    return df


def mini_dataset(df):
    df["random"] = np.random.rand(len(df))
    df = df[df["random"] < 0.25]
    df = df.reset_index(drop=True)
    df = df.drop(["random"], axis=1)
    return df


df = pd.read_csv(f"{cfg.dataset_path}/dataset.csv")
df["axis"] = df["image_path"].str.split("/").str[-2].str.split("_").str[-1]
df = df[df["axis"] == "axis0"]
df = filter_dataset(df)
# df = mini_dataset(df)
if cfg.debug:
    df = df.sample(10000).reset_index(drop=True)
display(df)

if cfg.check_dataset:
    check_dataset(df, cfg)

  df = pd.read_csv(f"{cfg.dataset_path}/dataset.csv")


Unnamed: 0,image_path,label_path,fname,kidney,x,y,z,std,sum,fold0,fold1,axis
0,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x0_y0_z0_std0039_sum0,kidney_1_dense,0,0,0,39,0,valid,train,axis0
1,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x0_y0_z100_std0037_sum0,kidney_1_dense,0,0,100,37,0,valid,train,axis0
2,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x0_y0_z102_std0036_sum0,kidney_1_dense,0,0,102,36,0,valid,train,axis0
3,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x0_y0_z104_std0036_sum0,kidney_1_dense,0,0,104,36,0,valid,train,axis0
4,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x0_y0_z106_std0036_sum0,kidney_1_dense,0,0,106,36,0,valid,train,axis0
...,...,...,...,...,...,...,...,...,...,...,...,...
158709,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x256_y896_z990_std0224_sum1278,kidney_3_sparse,256,896,990,224,1278,train,,axis0
158710,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x256_y896_z992_std0226_sum1334,kidney_3_sparse,256,896,992,226,1334,train,,axis0
158711,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x256_y896_z994_std0228_sum1362,kidney_3_sparse,256,896,994,228,1362,train,,axis0
158712,/kaggle/working/dataset/cropped_xy_256_128_z_1...,/kaggle/working/dataset/cropped_xy_256_128_z_1...,x256_y896_z996_std0231_sum1339,kidney_3_sparse,256,896,996,231,1339,train,,axis0


In [4]:
if cfg.debug:
    print("!!!Debug mode!!!\n")
    cfg.epochs = 5

for fold in range(1):
    dataset = get_train_dataset(cfg)

    train_df = df[df[f"fold{fold}"] == "train"]
    valid_df = df[df[f"fold{fold}"] == "valid"]

    train_dataset = dataset(train_df, cfg, is_train=True)
    valid_dataset = dataset(valid_df, cfg, is_train=False)

    train_dataloader = DataLoader(train_dataset, batch_size=cfg.train_batch_size, num_workers=cfg.num_workers, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.valid_batch_size, num_workers=cfg.num_workers, shuffle=False)

    model = build_model(cfg.model_arch, cfg.backbone, cfg.in_chans, cfg.target_size, cfg, decoder_attention_type=cfg.decoder_attention_type)
    scaler = torch.cuda.amp.GradScaler()
    criterion = get_lossfn(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    scheduler = get_scheduler(cfg, optimizer)
    metrics = get_metrics(cfg)

    # model, scaler, criterion, optimizer, scheduler, metrics = init_model(cfg)
    slacknotify = init_exp(fold, cfg)

    path_best = f"./{cfg.exp_name}/{cfg.exp_name}_best_fold{fold}.pth"
    path_last = f"./{cfg.exp_name}/{cfg.exp_name}_last_fold{fold}.pth"

    best_loss = float("inf")
    for epoch in range(cfg.epochs):
        model.train()
        total_loss = 0.0
        pbar_train = tqdm(enumerate(train_dataloader), total=len(train_dataloader), bar_format="{l_bar}{bar:10}{r_bar}{bar:-0b}")

        for i, (images, masks) in pbar_train:
            images, masks = images.cuda(), masks.cuda()
            optimizer.zero_grad()

            with amp.autocast():
                preds = model(images)
                loss = criterion(preds, masks)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                total_loss += loss.detach().item()

            loss_ = total_loss / (i + 1)
            lr = f"LR : {scheduler.get_lr()[0]:.2E}"
            gpu_mem = f"Mem : {torch.cuda.memory_reserved() / 1E9:.3g}GB"
            pbar_train.set_description(("%10s  " * 3 + "%10s") % (f"Epoch {epoch}/{cfg.epochs}", gpu_mem, lr, f"Loss: {loss_:.4f}"))

        train_loss = loss_
        scheduler.step()
        wandb.log({"epoch": epoch, "train_loss": train_loss})

        model.eval()
        total_loss = 0.0
        pbar_val = tqdm(enumerate(valid_dataloader), total=len(valid_dataloader), bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}")

        for i, (images, masks) in pbar_val:
            images, masks = images.cuda(), masks.cuda()
            with torch.no_grad():
                preds = model(images)
                loss = criterion(preds, masks)
                total_loss += loss.item()

            loss_ = total_loss / (i + 1)
            pbar_val.set_description(("%10s") % (f"Val Loss: {loss_:.4f}"))
        valid_loss = loss_
        wandb.log({"epoch": epoch, "valid_loss": valid_loss})

        if valid_loss < best_loss:
            print(f"loss : {valid_loss:.4f}\tSAVED MODEL\n")
            slacknotify.send_reply(f"epoch : {epoch}\tscore : {valid_loss:.4f}\tBEST")
            best_loss = valid_loss
            save_model(model, cfg, path_best, loss=loss)
        else:
            print(f"loss : {valid_loss:.4f}\n")
            slacknotify.send_reply(f"epoch : {epoch}\tscore : {valid_loss:.4f}")

    save_model(model, cfg, path_last, loss=valid_loss)
    wandb.config.update({"last_loss": valid_loss, "best_loss": best_loss})

    slacknotify.send_reply(f"{cfg.exp_name}_fold{fold} training finished\nbest loss : {best_loss:.4f} last loss : {loss_:.4f}", True)

    if wandb.run:
        wandb.finish()

model_arch:  Unet
backbone:  se_resnext50_32x4d


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwelshonionman[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0/20  Mem : 9.88GB  LR : 5.00E-04  Loss: 0.1773: 100%|██████████| 1266/1266 [04:17<00:00,  4.91it/s]
Val Loss: 0.3738: 100%|██████████| 1214/1214 [02:14<00:00,  9.03it/s]


loss : 0.3738	SAVED MODEL



Epoch 1/20  Mem : 13.3GB  LR : 5.00E-03  Loss: 0.1404: 100%|██████████| 1266/1266 [04:14<00:00,  4.97it/s]
Val Loss: 0.3634: 100%|██████████| 1214/1214 [02:10<00:00,  9.29it/s]


loss : 0.3634	SAVED MODEL



Epoch 2/20  Mem : 13.3GB  LR : 5.00E-04  Loss: 0.0900: 100%|██████████| 1266/1266 [04:09<00:00,  5.07it/s]
Val Loss: 0.3597: 100%|██████████| 1214/1214 [02:14<00:00,  9.04it/s]


loss : 0.3597	SAVED MODEL



Epoch 3/20  Mem : 13.3GB  LR : 4.97E-04  Loss: 0.0837: 100%|██████████| 1266/1266 [04:14<00:00,  4.97it/s]
Val Loss: 0.3532: 100%|██████████| 1214/1214 [02:14<00:00,  9.04it/s]


loss : 0.3532	SAVED MODEL



Epoch 4/20  Mem : 13.3GB  LR : 4.88E-04  Loss: 0.0808: 100%|██████████| 1266/1266 [04:14<00:00,  4.98it/s]
Val Loss: 0.3242: 100%|██████████| 1214/1214 [02:13<00:00,  9.12it/s]


loss : 0.3242	SAVED MODEL



Epoch 5/20  Mem : 13.3GB  LR : 4.73E-04  Loss: 0.0792: 100%|██████████| 1266/1266 [04:14<00:00,  4.98it/s]
Val Loss: 0.3324: 100%|██████████| 1214/1214 [02:13<00:00,  9.12it/s]


loss : 0.3324



Epoch 6/20  Mem : 13.3GB  LR : 4.52E-04  Loss: 0.0783: 100%|██████████| 1266/1266 [04:11<00:00,  5.04it/s]
Val Loss: 0.3232: 100%|██████████| 1214/1214 [02:13<00:00,  9.13it/s]


loss : 0.3232	SAVED MODEL



Epoch 7/20  Mem : 13.3GB  LR : 4.27E-04  Loss: 0.0761: 100%|██████████| 1266/1266 [04:13<00:00,  5.00it/s]
Val Loss: 0.2988: 100%|██████████| 1214/1214 [02:12<00:00,  9.17it/s]


loss : 0.2988	SAVED MODEL



Epoch 8/20  Mem : 13.3GB  LR : 3.97E-04  Loss: 0.0753: 100%|██████████| 1266/1266 [04:12<00:00,  5.01it/s]
Val Loss: 0.3106: 100%|██████████| 1214/1214 [02:12<00:00,  9.16it/s]


loss : 0.3106



Epoch 9/20  Mem : 13.3GB  LR : 3.64E-04  Loss: 0.0740: 100%|██████████| 1266/1266 [04:12<00:00,  5.02it/s]
Val Loss: 0.3102: 100%|██████████| 1214/1214 [02:13<00:00,  9.09it/s]


loss : 0.3102



Epoch 10/20  Mem : 13.3GB  LR : 3.27E-04  Loss: 0.0737: 100%|██████████| 1266/1266 [04:14<00:00,  4.97it/s]
Val Loss: 0.2992: 100%|██████████| 1214/1214 [02:12<00:00,  9.17it/s]


loss : 0.2992



Epoch 11/20  Mem : 13.3GB  LR : 2.89E-04  Loss: 0.0715: 100%|██████████| 1266/1266 [04:10<00:00,  5.05it/s]
Val Loss: 0.3048: 100%|██████████| 1214/1214 [02:12<00:00,  9.19it/s]


loss : 0.3048



Epoch 12/20  Mem : 13.3GB  LR : 2.50E-04  Loss: 0.0720: 100%|██████████| 1266/1266 [04:11<00:00,  5.03it/s]
Val Loss: 0.2946: 100%|██████████| 1214/1214 [02:14<00:00,  9.05it/s]


loss : 0.2946	SAVED MODEL



Epoch 13/20  Mem : 13.3GB  LR : 2.11E-04  Loss: 0.0712: 100%|██████████| 1266/1266 [04:11<00:00,  5.04it/s]
Val Loss: 0.2867: 100%|██████████| 1214/1214 [02:14<00:00,  9.00it/s]


loss : 0.2867	SAVED MODEL



Epoch 14/20  Mem : 13.3GB  LR : 1.73E-04  Loss: 0.0711:  68%|██████▊   | 859/1266 [02:54<01:21,  4.99it/s]wandb: Network error (ConnectionError), entering retry loop.
Epoch 14/20  Mem : 13.3GB  LR : 1.73E-04  Loss: 0.0710: 100%|██████████| 1266/1266 [04:16<00:00,  4.94it/s]
Val Loss: 0.2959: 100%|██████████| 1214/1214 [02:14<00:00,  9.00it/s]


loss : 0.2959



URLError: <urlopen error [Errno -3] Temporary failure in name resolution>

In [4]:
fold_dict = json.load(open("/kaggle/src/dataset/fold.json", "r"))

for fold in range(1):
    print(f"fold_{fold}")
    valid_kidney = fold_dict[f"fold{fold}"]["valid"][0]
    path_best = f"./{cfg.exp_name}/{cfg.exp_name}_best_fold{fold}.pth"
    model = load_inference_model(path_best, cfg)

    stack_path = f"/kaggle/working/dataset/stack_train01/{valid_kidney}_images.npy"
    label_path = f"/kaggle/working/dataset/stack_train01/{valid_kidney}_labels.npy"
    save_path = "./preds"
    preds_path = f"{save_path}/{valid_kidney}_preds.npy"
    inference(model, stack_path, save_path, cfg)

    label = np.load(label_path)
    preds = np.load(preds_path)
    thresh_score_dict = {}

    for thresh in np.arange(0.1, 0.99, 0.1):
        thresh = round(thresh, 5)
        thresh_score_dict[thresh] = compute_surface_dice_score_from_volume(ne.evaluate("preds > thresh"), label)

    max_score_thresh = max(thresh_score_dict, key=thresh_score_dict.get)

    for thresh in np.arange(max_score_thresh - 0.1, max_score_thresh + 0.1, 0.01):
        thresh = round(thresh, 5)
        thresh_score_dict[thresh] = compute_surface_dice_score_from_volume(ne.evaluate("preds > thresh"), label)

    print(max(thresh_score_dict.items(), key=lambda x: x[1]))

fold_0
model_name Unet
backbone se_resnext50_32x4d


RuntimeError: Error(s) in loading state_dict for CustomInferenceModel:
	Unexpected key(s) in state_dict: "model.decoder.blocks.0.attention1.attention.cSE.1.weight", "model.decoder.blocks.0.attention1.attention.cSE.1.bias", "model.decoder.blocks.0.attention1.attention.cSE.3.weight", "model.decoder.blocks.0.attention1.attention.cSE.3.bias", "model.decoder.blocks.0.attention1.attention.sSE.0.weight", "model.decoder.blocks.0.attention1.attention.sSE.0.bias", "model.decoder.blocks.0.attention2.attention.cSE.1.weight", "model.decoder.blocks.0.attention2.attention.cSE.1.bias", "model.decoder.blocks.0.attention2.attention.cSE.3.weight", "model.decoder.blocks.0.attention2.attention.cSE.3.bias", "model.decoder.blocks.0.attention2.attention.sSE.0.weight", "model.decoder.blocks.0.attention2.attention.sSE.0.bias", "model.decoder.blocks.1.attention1.attention.cSE.1.weight", "model.decoder.blocks.1.attention1.attention.cSE.1.bias", "model.decoder.blocks.1.attention1.attention.cSE.3.weight", "model.decoder.blocks.1.attention1.attention.cSE.3.bias", "model.decoder.blocks.1.attention1.attention.sSE.0.weight", "model.decoder.blocks.1.attention1.attention.sSE.0.bias", "model.decoder.blocks.1.attention2.attention.cSE.1.weight", "model.decoder.blocks.1.attention2.attention.cSE.1.bias", "model.decoder.blocks.1.attention2.attention.cSE.3.weight", "model.decoder.blocks.1.attention2.attention.cSE.3.bias", "model.decoder.blocks.1.attention2.attention.sSE.0.weight", "model.decoder.blocks.1.attention2.attention.sSE.0.bias", "model.decoder.blocks.2.attention1.attention.cSE.1.weight", "model.decoder.blocks.2.attention1.attention.cSE.1.bias", "model.decoder.blocks.2.attention1.attention.cSE.3.weight", "model.decoder.blocks.2.attention1.attention.cSE.3.bias", "model.decoder.blocks.2.attention1.attention.sSE.0.weight", "model.decoder.blocks.2.attention1.attention.sSE.0.bias", "model.decoder.blocks.2.attention2.attention.cSE.1.weight", "model.decoder.blocks.2.attention2.attention.cSE.1.bias", "model.decoder.blocks.2.attention2.attention.cSE.3.weight", "model.decoder.blocks.2.attention2.attention.cSE.3.bias", "model.decoder.blocks.2.attention2.attention.sSE.0.weight", "model.decoder.blocks.2.attention2.attention.sSE.0.bias", "model.decoder.blocks.3.attention1.attention.cSE.1.weight", "model.decoder.blocks.3.attention1.attention.cSE.1.bias", "model.decoder.blocks.3.attention1.attention.cSE.3.weight", "model.decoder.blocks.3.attention1.attention.cSE.3.bias", "model.decoder.blocks.3.attention1.attention.sSE.0.weight", "model.decoder.blocks.3.attention1.attention.sSE.0.bias", "model.decoder.blocks.3.attention2.attention.cSE.1.weight", "model.decoder.blocks.3.attention2.attention.cSE.1.bias", "model.decoder.blocks.3.attention2.attention.cSE.3.weight", "model.decoder.blocks.3.attention2.attention.cSE.3.bias", "model.decoder.blocks.3.attention2.attention.sSE.0.weight", "model.decoder.blocks.3.attention2.attention.sSE.0.bias", "model.decoder.blocks.4.attention1.attention.cSE.1.weight", "model.decoder.blocks.4.attention1.attention.cSE.1.bias", "model.decoder.blocks.4.attention1.attention.cSE.3.weight", "model.decoder.blocks.4.attention1.attention.cSE.3.bias", "model.decoder.blocks.4.attention1.attention.sSE.0.weight", "model.decoder.blocks.4.attention1.attention.sSE.0.bias", "model.decoder.blocks.4.attention2.attention.cSE.1.weight", "model.decoder.blocks.4.attention2.attention.cSE.1.bias", "model.decoder.blocks.4.attention2.attention.cSE.3.weight", "model.decoder.blocks.4.attention2.attention.cSE.3.bias", "model.decoder.blocks.4.attention2.attention.sSE.0.weight", "model.decoder.blocks.4.attention2.attention.sSE.0.bias". 