In [1]:
import numpy as np
import os

import albumentations as A
from albumentations.pytorch import ToTensorV2
import wandb
from dotenv import load_dotenv

from src.model.model import save_model, load_model
from src.dataset.df import df_dataset, check_dataset
from src.utils.common import set_seed
from src.utils.optim_thresh import calc_optim_thresh
from src.experiment.experiment import train, valid
from src.experiment.initialize import init_dataset, init_model, init_exp

In [2]:
class cfg:
    """Experiment configuration, read as plain class attributes (``cfg.lr`` ...).

    The class object itself is handed to the project helpers (df_dataset,
    init_dataset, init_model, init_exp, train, valid, calc_optim_thresh, ...),
    so attribute names are part of their interface and must not change.
    """

    # --- run switches ------------------------------------------------------
    debug = False  # when True, the notebook subsamples df and shortens training
    check_dataset = False  # when True, the notebook calls check_dataset(df, cfg)

    # --- data --------------------------------------------------------------
    dataset_path = "/kaggle/working/dataset/cropped_xy_256_128_z_6_6/"
    dataset = "base2d"

    # --- experiment tracking -----------------------------------------------
    project = "SenNet"
    # The experiment name is the name of the directory the notebook runs from.
    exp_name = os.path.basename(os.getcwd())
    notes = "aug_blurs"

    # --- model -------------------------------------------------------------
    model_arch = "Unet"
    backbone = "efficientnet-b0"
    in_chans = 6
    target_size = 6

    # --- training ----------------------------------------------------------
    epochs = 20

    train_batch_size = 128
    valid_batch_size = train_batch_size  # validation reuses the training batch size

    loss = "DiceLoss"
    metrics = "Dice"
    lr = 1e-4
    # Candidate thresholds for calc_optim_thresh: 2, 6, 10, ..., 98 (25 values).
    thresholds_to_test = range(2, 101, 4)
    num_workers = 24

    # --- augmentation ------------------------------------------------------
    # NOTE(review): not referenced by the transforms below; presumably read by
    # a project helper — confirm before removing.
    image_size = 256
    # With probability 0.4, apply exactly one of: Gaussian noise, Gaussian blur,
    # motion blur. Then convert to a tensor (mask transposed as well).
    train_aug = [
        A.OneOf(
            transforms=[
                A.GaussNoise(var_limit=[10, 50]),
                A.GaussianBlur(),
                A.MotionBlur(),
            ],
            p=0.4,
        ),
        ToTensorV2(transpose_mask=True),
    ]

    # Validation: tensor conversion only, no augmentation.
    valid_aug = [ToTensorV2(transpose_mask=True)]


load_dotenv("/kaggle/key.env")
set_seed()

In [3]:
def filter_dataset(df, keep_empty_ratio=0.1, rng=None):
    """Drop most samples whose label mask is empty.

    Rows with ``df["sum"] > 0`` are always kept. Every other row is kept
    independently with probability ``keep_empty_ratio``; with the default 0.1,
    about 90% of label-free rows are discarded.

    Args:
        df: Sample table with a numeric ``"sum"`` column. It is NOT modified.
        keep_empty_ratio: Probability of keeping a row whose ``"sum"`` is not > 0.
        rng: Optional ``numpy.random.Generator``. When omitted, the legacy global
            ``np.random`` state is used (same draws as before, so results still
            follow the global seed).

    Returns:
        A new DataFrame containing the kept rows in their original order, with
        the same columns and the index reset to 0..n-1.
    """
    n_rows = len(df)
    draws = rng.random(n_rows) if rng is not None else np.random.rand(n_rows)
    # Use a local mask instead of a temporary column. This leaves the caller's
    # frame untouched and never clobbers a pre-existing "random" column.
    keep = (df["sum"] > 0).to_numpy() | (draws < keep_empty_ratio)
    return df[keep].reset_index(drop=True)


# Build the sample table, then drop most label-free tiles (see filter_dataset).
df = df_dataset(cfg)
df = filter_dataset(df)
if cfg.debug:
    # Subsample for a quick run. Cap n at the number of rows, because
    # DataFrame.sample raises ValueError when n exceeds len(df) without replacement.
    df = df.sample(min(10000, len(df))).reset_index(drop=True)
display(df)

if cfg.check_dataset:
    check_dataset(df, cfg)

Unnamed: 0,image_path,label_path,fname,kidney,x,y,z,std,sum,fold0,fold1
0,/kaggle/working/dataset/cropped_xy_256_128_z_6...,/kaggle/working/dataset/cropped_xy_256_128_z_6...,x0_y0_z12_std0040_sum0,kidney_1_dense,0,0,12,40,0,train,valid
1,/kaggle/working/dataset/cropped_xy_256_128_z_6...,/kaggle/working/dataset/cropped_xy_256_128_z_6...,x0_y0_z150_std0033_sum0,kidney_1_dense,0,0,150,33,0,train,valid
2,/kaggle/working/dataset/cropped_xy_256_128_z_6...,/kaggle/working/dataset/cropped_xy_256_128_z_6...,x0_y0_z252_std0036_sum0,kidney_1_dense,0,0,252,36,0,train,valid
3,/kaggle/working/dataset/cropped_xy_256_128_z_6...,/kaggle/working/dataset/cropped_xy_256_128_z_6...,x0_y0_z270_std0036_sum0,kidney_1_dense,0,0,270,36,0,train,valid
4,/kaggle/working/dataset/cropped_xy_256_128_z_6...,/kaggle/working/dataset/cropped_xy_256_128_z_6...,x0_y0_z300_std0037_sum0,kidney_1_dense,0,0,300,37,0,train,valid
...,...,...,...,...,...,...,...,...,...,...,...
37926,/kaggle/working/dataset/cropped_xy_256_128_z_6...,/kaggle/working/dataset/cropped_xy_256_128_z_6...,x256_y896_z972_std0228_sum3897,kidney_3_sparse,256,896,972,228,3897,,train
37927,/kaggle/working/dataset/cropped_xy_256_128_z_6...,/kaggle/working/dataset/cropped_xy_256_128_z_6...,x256_y896_z978_std0223_sum5084,kidney_3_sparse,256,896,978,223,5084,,train
37928,/kaggle/working/dataset/cropped_xy_256_128_z_6...,/kaggle/working/dataset/cropped_xy_256_128_z_6...,x256_y896_z984_std0223_sum6730,kidney_3_sparse,256,896,984,223,6730,,train
37929,/kaggle/working/dataset/cropped_xy_256_128_z_6...,/kaggle/working/dataset/cropped_xy_256_128_z_6...,x256_y896_z990_std0227_sum7998,kidney_3_sparse,256,896,990,227,7998,,train


In [4]:
# Train one model per CV fold. For each fold, keep the checkpoint with the
# lowest validation loss, then tune the threshold for both the last and the
# best weights and record the results.
if cfg.debug:
    print("!!!Debug mode!!!\n")
    # NOTE: this overwrites the class attribute, so the shortened schedule
    # persists in the kernel until the cfg cell is re-run.
    cfg.epochs = 5

# Everything after the epoch loop needs at least one validation pass;
# otherwise loss / pred_list / true_list / epoch would be undefined.
if cfg.epochs < 1:
    raise ValueError(f"cfg.epochs must be >= 1, got {cfg.epochs}")

for fold in range(2):
    train_dataloader, valid_dataloader = init_dataset(fold, df, cfg)
    model, scaler, criterion, optimizer, scheduler, metrics = init_model(cfg)
    slacknotify = init_exp(fold, cfg)

    path_best = f"./{cfg.exp_name}/{cfg.exp_name}_best_fold{fold}.pth"
    path_last = f"./{cfg.exp_name}/{cfg.exp_name}_last_fold{fold}.pth"

    try:
        best_loss = float("inf")
        for epoch in range(cfg.epochs):
            # NOTE(review): the logged LR jumps to 1.00E-03 at epoch 1 and falls
            # back to 1.00E-04 at epoch 2 in both folds. Check the warmup
            # scheduler built in init_model.
            train(model, train_dataloader, optimizer, criterion, scheduler, scaler, epoch, cfg)
            loss, pred_list, true_list = valid(model, valid_dataloader, criterion, epoch, cfg)

            # The Slack messages say "score", but the value is the validation
            # loss (lower is better).
            if loss < best_loss:
                print(f"loss : {loss:.4f}\tSAVED MODEL\n")
                slacknotify.send_reply(f"epoch : {epoch}\tscore : {loss:.4f}\tBEST")
                best_loss = loss
                save_model(model, cfg, path_best, loss=loss)
            else:
                print(f"loss : {loss:.4f}\n")
                slacknotify.send_reply(f"epoch : {epoch}\tscore : {loss:.4f}")

        if best_loss == float("inf"):
            # No finite validation loss was seen (e.g. NaN every epoch), so
            # nothing was written to path_best this run. Loading below would
            # otherwise pick up a stale file or fail. Use the last weights.
            print(f"WARNING: no finite validation loss in fold {fold}; using last weights as best\n")
            save_model(model, cfg, path_best, loss=loss)

        # Threshold tuning on the final epoch's predictions.
        last_score, last_thresh = calc_optim_thresh(pred_list, true_list, metrics, cfg)
        save_model(model, cfg, path_last, loss=loss, score=last_score, thresh=last_thresh)
        wandb.config.update({"last_score": last_score, "last_thresh": last_thresh})

        # Re-validate the best checkpoint (without logging) and tune its threshold.
        best_model = load_model(model, path_best)
        loss, pred_list, true_list = valid(best_model, valid_dataloader, criterion, epoch, cfg, log=False)

        best_score, best_thresh = calc_optim_thresh(pred_list, true_list, metrics, cfg)
        save_model(best_model, cfg, path_best, loss=loss, score=best_score, thresh=best_thresh)
        wandb.config.update({"best_score": best_score, "best_thresh": best_thresh})

        slacknotify.send_reply(
            f"{cfg.exp_name}_fold{fold} training finished\nbest score : {best_score:.4f} last score : {last_score:.4f}",
            True,
        )
    finally:
        # Close the wandb run (presumably opened by init_exp) even if this fold
        # raised, so a failed fold does not leave the run open.
        if wandb.run:
            wandb.finish()

model_arch:  Unet
backbone:  efficientnet-b0


[34m[1mwandb[0m: Currently logged in as: [33mwelshonionman[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.16.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/wandb/run-20240101_191858-17ayn0oo[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mexp004_fold0[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/welshonionman/SenNet[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/welshonionman/SenNet/runs/17ayn0oo[0m
Epoch 0/20  Mem : 14.7GB  LR : 1.00E-04  Loss: 0.9302: 100%|██████████| 237/237 [01:58<00:00,  2.00it/s]
Val Loss: 0.9102: 100%|██████████| 30/30 [00:19<00:00,  1.55it/s]                                                                                                                                                                                                                                      


loss : 0.9102	SAVED MODEL



Epoch 1/20  Mem : 19GB  LR : 1.00E-03  Loss: 0.2846: 100%|██████████| 237/237 [01:43<00:00,  2.29it/s]
Val Loss: 0.3333: 100%|██████████| 30/30 [00:18<00:00,  1.60it/s]                                                                                                                                                                                                                                      


loss : 0.3333	SAVED MODEL



Epoch 2/20  Mem : 19GB  LR : 1.00E-04  Loss: 0.1719: 100%|██████████| 237/237 [01:42<00:00,  2.31it/s]
Val Loss: 0.1607: 100%|██████████| 30/30 [00:18<00:00,  1.60it/s]                                                                                                                                                                                                                                      


loss : 0.1607	SAVED MODEL



Epoch 3/20  Mem : 19GB  LR : 9.94E-05  Loss: 0.1694: 100%|██████████| 237/237 [01:42<00:00,  2.32it/s]
Val Loss: 0.1503: 100%|██████████| 30/30 [00:18<00:00,  1.58it/s]                                                                                                                                                                                                                                      


loss : 0.1503	SAVED MODEL



Epoch 4/20  Mem : 19GB  LR : 9.76E-05  Loss: 0.1632: 100%|██████████| 237/237 [01:41<00:00,  2.32it/s]
Val Loss: 0.1466: 100%|██████████| 30/30 [00:18<00:00,  1.61it/s]                                                                                                                                                                                                                                      


loss : 0.1466	SAVED MODEL



Epoch 5/20  Mem : 19GB  LR : 9.46E-05  Loss: 0.1569: 100%|██████████| 237/237 [01:41<00:00,  2.33it/s]
Val Loss: 0.1397: 100%|██████████| 30/30 [00:18<00:00,  1.61it/s]                                                                                                                                                                                                                                      


loss : 0.1397	SAVED MODEL



Epoch 6/20  Mem : 19GB  LR : 9.05E-05  Loss: 0.1507: 100%|██████████| 237/237 [01:41<00:00,  2.33it/s]
Val Loss: 0.1369: 100%|██████████| 30/30 [00:19<00:00,  1.56it/s]                                                                                                                                                                                                                                      


loss : 0.1369	SAVED MODEL



Epoch 7/20  Mem : 19GB  LR : 8.54E-05  Loss: 0.1527: 100%|██████████| 237/237 [01:42<00:00,  2.31it/s]
Val Loss: 0.1407: 100%|██████████| 30/30 [00:18<00:00,  1.62it/s]                                                                                                                                                                                                                                      


loss : 0.1407



Epoch 8/20  Mem : 19GB  LR : 7.94E-05  Loss: 0.1472: 100%|██████████| 237/237 [01:42<00:00,  2.32it/s]
Val Loss: 0.1333: 100%|██████████| 30/30 [00:18<00:00,  1.60it/s]                                                                                                                                                                                                                                      


loss : 0.1333	SAVED MODEL



Epoch 9/20  Mem : 19GB  LR : 7.27E-05  Loss: 0.1498: 100%|██████████| 237/237 [01:41<00:00,  2.33it/s]
Val Loss: 0.1337: 100%|██████████| 30/30 [00:18<00:00,  1.61it/s]                                                                                                                                                                                                                                      


loss : 0.1337



Epoch 10/20  Mem : 19GB  LR : 6.55E-05  Loss: 0.1495: 100%|██████████| 237/237 [01:41<00:00,  2.33it/s]
Val Loss: 0.1299: 100%|██████████| 30/30 [00:18<00:00,  1.58it/s]                                                                                                                                                                                                                                      


loss : 0.1299	SAVED MODEL



Epoch 11/20  Mem : 19GB  LR : 5.79E-05  Loss: 0.1469: 100%|██████████| 237/237 [01:43<00:00,  2.30it/s]
Val Loss: 0.1347: 100%|██████████| 30/30 [00:18<00:00,  1.61it/s]                                                                                                                                                                                                                                      


loss : 0.1347



Epoch 12/20  Mem : 19GB  LR : 5.01E-05  Loss: 0.1454: 100%|██████████| 237/237 [01:42<00:00,  2.31it/s]
Val Loss: 0.1359: 100%|██████████| 30/30 [00:19<00:00,  1.55it/s]                                                                                                                                                                                                                                      


loss : 0.1359



Epoch 13/20  Mem : 19GB  LR : 4.22E-05  Loss: 0.1442: 100%|██████████| 237/237 [01:42<00:00,  2.31it/s]
Val Loss: 0.1273: 100%|██████████| 30/30 [00:18<00:00,  1.61it/s]                                                                                                                                                                                                                                      


loss : 0.1273	SAVED MODEL



Epoch 14/20  Mem : 19GB  LR : 3.46E-05  Loss: 0.1444: 100%|██████████| 237/237 [01:42<00:00,  2.32it/s]
Val Loss: 0.1296: 100%|██████████| 30/30 [00:18<00:00,  1.58it/s]                                                                                                                                                                                                                                      


loss : 0.1296



Epoch 15/20  Mem : 19GB  LR : 2.74E-05  Loss: 0.1440: 100%|██████████| 237/237 [01:41<00:00,  2.33it/s]
Val Loss: 0.1305: 100%|██████████| 30/30 [00:18<00:00,  1.59it/s]                                                                                                                                                                                                                                      


loss : 0.1305



Epoch 16/20  Mem : 19GB  LR : 2.07E-05  Loss: 0.1432: 100%|██████████| 237/237 [01:42<00:00,  2.31it/s]
Val Loss: 0.1325: 100%|██████████| 30/30 [00:18<00:00,  1.62it/s]                                                                                                                                                                                                                                      


loss : 0.1325



Epoch 17/20  Mem : 19GB  LR : 1.47E-05  Loss: 0.1417: 100%|██████████| 237/237 [01:41<00:00,  2.32it/s]
Val Loss: 0.1272: 100%|██████████| 30/30 [00:19<00:00,  1.57it/s]                                                                                                                                                                                                                                      


loss : 0.1272	SAVED MODEL



Epoch 18/20  Mem : 19GB  LR : 9.64E-06  Loss: 0.1382: 100%|██████████| 237/237 [01:42<00:00,  2.31it/s]
Val Loss: 0.1288: 100%|██████████| 30/30 [00:18<00:00,  1.62it/s]                                                                                                                                                                                                                                      


loss : 0.1288



Epoch 19/20  Mem : 19GB  LR : 5.54E-06  Loss: 0.1421: 100%|██████████| 237/237 [01:42<00:00,  2.32it/s]
Val Loss: 0.1329: 100%|██████████| 30/30 [00:18<00:00,  1.62it/s]                                                                                                                                                                                                                                      


loss : 0.1329



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [02:15<00:00,  5.42s/it]
Val Loss: 0.1272: 100%|██████████| 30/30 [00:19<00:00,  1.56it/s]                                                                                                                                                                                                                                      
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [02:17<00:00,  5.49s/it]
[34m[1mwandb[0m:                                                                                
[34m[1mwan

model_arch:  Unet
backbone:  efficientnet-b0


[34m[1mwandb[0m: Tracking run with wandb version 0.16.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/wandb/run-20240101_200603-4ez3u6jx[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mexp004_fold1[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/welshonionman/SenNet[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/welshonionman/SenNet/runs/4ez3u6jx[0m
Epoch 0/20  Mem : 19GB  LR : 1.00E-04  Loss: 0.9603: 100%|██████████| 189/189 [01:25<00:00,  2.22it/s]
Val Loss: 0.9308: 100%|██████████| 109/109 [00:56<00:00,  1.93it/s]                                                                                                                                                                                                                                    


loss : 0.9308	SAVED MODEL



Epoch 1/20  Mem : 19GB  LR : 1.00E-03  Loss: 0.3029: 100%|██████████| 189/189 [01:23<00:00,  2.26it/s]
Val Loss: 0.9094: 100%|██████████| 109/109 [01:05<00:00,  1.67it/s]                                                                                                                                                                                                                                    


loss : 0.9094	SAVED MODEL



Epoch 2/20  Mem : 19GB  LR : 1.00E-04  Loss: 0.1958: 100%|██████████| 189/189 [01:23<00:00,  2.26it/s]
Val Loss: 0.4022: 100%|██████████| 109/109 [00:59<00:00,  1.84it/s]                                                                                                                                                                                                                                    


loss : 0.4022	SAVED MODEL



Epoch 3/20  Mem : 19GB  LR : 9.94E-05  Loss: 0.1859: 100%|██████████| 189/189 [01:24<00:00,  2.22it/s]
Val Loss: 0.3621: 100%|██████████| 109/109 [00:58<00:00,  1.87it/s]                                                                                                                                                                                                                                    


loss : 0.3621	SAVED MODEL



Epoch 4/20  Mem : 19GB  LR : 9.76E-05  Loss: 0.1714: 100%|██████████| 189/189 [01:24<00:00,  2.25it/s]
Val Loss: 0.3610: 100%|██████████| 109/109 [00:58<00:00,  1.87it/s]                                                                                                                                                                                                                                    


loss : 0.3610	SAVED MODEL



Epoch 5/20  Mem : 19GB  LR : 9.46E-05  Loss: 0.1634: 100%|██████████| 189/189 [01:24<00:00,  2.23it/s]
Val Loss: 0.3528: 100%|██████████| 109/109 [00:57<00:00,  1.88it/s]                                                                                                                                                                                                                                    


loss : 0.3528	SAVED MODEL



Epoch 6/20  Mem : 19GB  LR : 9.05E-05  Loss: 0.1512: 100%|██████████| 189/189 [01:23<00:00,  2.27it/s]
Val Loss: 0.3266: 100%|██████████| 109/109 [00:57<00:00,  1.91it/s]                                                                                                                                                                                                                                    


loss : 0.3266	SAVED MODEL



Epoch 7/20  Mem : 19GB  LR : 8.54E-05  Loss: 0.1467: 100%|██████████| 189/189 [01:25<00:00,  2.21it/s]
Val Loss: 0.3179: 100%|██████████| 109/109 [00:57<00:00,  1.90it/s]                                                                                                                                                                                                                                    


loss : 0.3179	SAVED MODEL



Epoch 8/20  Mem : 19GB  LR : 7.94E-05  Loss: 0.1502: 100%|██████████| 189/189 [01:24<00:00,  2.25it/s]
Val Loss: 0.2887: 100%|██████████| 109/109 [00:58<00:00,  1.87it/s]                                                                                                                                                                                                                                    


loss : 0.2887	SAVED MODEL



Epoch 9/20  Mem : 19GB  LR : 7.27E-05  Loss: 0.1439: 100%|██████████| 189/189 [01:24<00:00,  2.24it/s]
Val Loss: 0.2887: 100%|██████████| 109/109 [00:58<00:00,  1.86it/s]                                                                                                                                                                                                                                    


loss : 0.2887



Epoch 10/20  Mem : 19GB  LR : 6.55E-05  Loss: 0.1426: 100%|██████████| 189/189 [01:25<00:00,  2.22it/s]
Val Loss: 0.2730: 100%|██████████| 109/109 [00:58<00:00,  1.86it/s]                                                                                                                                                                                                                                    


loss : 0.2730	SAVED MODEL



Epoch 11/20  Mem : 19GB  LR : 5.79E-05  Loss: 0.1421: 100%|██████████| 189/189 [01:25<00:00,  2.21it/s]
Val Loss: 0.2766: 100%|██████████| 109/109 [00:57<00:00,  1.90it/s]                                                                                                                                                                                                                                    


loss : 0.2766



Epoch 12/20  Mem : 19GB  LR : 5.01E-05  Loss: 0.1385: 100%|██████████| 189/189 [01:23<00:00,  2.26it/s]
Val Loss: 0.2784: 100%|██████████| 109/109 [00:59<00:00,  1.84it/s]                                                                                                                                                                                                                                    


loss : 0.2784



Epoch 13/20  Mem : 19GB  LR : 4.22E-05  Loss: 0.1384: 100%|██████████| 189/189 [01:23<00:00,  2.25it/s]
Val Loss: 0.2915: 100%|██████████| 109/109 [00:57<00:00,  1.89it/s]                                                                                                                                                                                                                                    


loss : 0.2915



Epoch 14/20  Mem : 19GB  LR : 3.46E-05  Loss: 0.1411: 100%|██████████| 189/189 [01:32<00:00,  2.04it/s]
Val Loss: 0.2763: 100%|██████████| 109/109 [00:58<00:00,  1.88it/s]                                                                                                                                                                                                                                    


loss : 0.2763



Epoch 15/20  Mem : 19GB  LR : 2.74E-05  Loss: 0.1431: 100%|██████████| 189/189 [01:24<00:00,  2.25it/s]
Val Loss: 0.2823: 100%|██████████| 109/109 [00:57<00:00,  1.91it/s]                                                                                                                                                                                                                                    


loss : 0.2823



Epoch 16/20  Mem : 19GB  LR : 2.07E-05  Loss: 0.1376: 100%|██████████| 189/189 [01:25<00:00,  2.22it/s]
Val Loss: 0.3007: 100%|██████████| 109/109 [00:57<00:00,  1.91it/s]                                                                                                                                                                                                                                    


loss : 0.3007



Epoch 17/20  Mem : 19GB  LR : 1.47E-05  Loss: 0.1370: 100%|██████████| 189/189 [01:24<00:00,  2.24it/s]
Val Loss: 0.2884: 100%|██████████| 109/109 [00:57<00:00,  1.88it/s]                                                                                                                                                                                                                                    


loss : 0.2884



Epoch 18/20  Mem : 19GB  LR : 9.64E-06  Loss: 0.1366: 100%|██████████| 189/189 [01:25<00:00,  2.22it/s]
Val Loss: 0.2798: 100%|██████████| 109/109 [00:57<00:00,  1.88it/s]                                                                                                                                                                                                                                    


loss : 0.2798



Epoch 19/20  Mem : 19GB  LR : 5.54E-06  Loss: 0.1332: 100%|██████████| 189/189 [01:24<00:00,  2.24it/s]
Val Loss: 0.2833: 100%|██████████| 109/109 [00:58<00:00,  1.87it/s]                                                                                                                                                                                                                                    


loss : 0.2833



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [08:12<00:00, 19.70s/it]
Val Loss: 0.2730: 100%|██████████| 109/109 [00:59<00:00,  1.82it/s]                                                                                                                                                                                                                                    
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [08:12<00:00, 19.70s/it]
[34m[1mwandb[0m:                                                                                
[34m[1mwan