In [1]:
import gc
import os
import sys
import warnings
from glob import glob
import matplotlib.pyplot as plt

import albumentations as A
from albumentations.pytorch import ToTensorV2
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
from torch.cuda import amp
from tqdm import tqdm

In [2]:
sys.path.append('../')
from script.metrics import *
from script.dataset import *
from script.helper import *
from script.scheduler import *
from script.loss import *

### Config

In [4]:
class CFG:
    debug = False
    # ============== comp exp name =============
    comp_name = 'contrail'
    comp_dir_path = '/kaggle/input/'
    comp_folder_name = 'google-research-identify-contrails-reduce-global-warming'

    dataset_path = "/kaggle/working/dataset_train/ash_color/"

    exp_name = os.getcwd().split('/')[-1]

    # ============== model cfg =============
    model_arch = 'Unet'
    backbone = 'timm-resnest200e'
    in_chans = 3
    target_size = 1

    # ============== training cfg =============
    train_batch_size = 32
    valid_batch_size = train_batch_size

    epochs = 90
    epochs_patience = 12

    lr = 1e-4
    loss = "DiceLoss"
    smooth = 300
    
    # ============== fixed =============
    num_workers = 4
    seed = 42

    # ============== augmentation =============
    train_aug_list = [
        A.RandomRotate90(),
        A.RandomBrightnessContrast(),
        A.ShiftScaleRotate(p=0.5,rotate_limit=(10,10)),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        A.CoarseDropout(max_holes=10, max_width=int(256 * 0.05), max_height=int(256 * 0.05),
                        mask_fill_value=0, p=0.5),
        A.RandomGridShuffle(),
        ToTensorV2(transpose_mask=True),
    ]

    valid_aug_list = [
        ToTensorV2(transpose_mask=True),
    ]


warnings.filterwarnings("ignore")
torch.backends.cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
set_seed(CFG.seed)
os.makedirs(f'./{CFG.exp_name}/', exist_ok=True)

# Dataset

In [4]:
train_df = pd.read_csv(f"{CFG.dataset_path}/train_df.csv")
valid_df = pd.read_csv(f"{CFG.dataset_path}/validation_df.csv")

if CFG.debug:
    train_df=train_df[:2000]
    valid_df=valid_df[:2000]
train_df.shape, valid_df.shape

((20529, 2), (1856, 2))

In [5]:
dataset_train = ContrailsDataset(train_df, CFG.train_aug_list)
dataset_valid = ContrailsDataset(valid_df, CFG.valid_aug_list)

dataloader_train = DataLoader(dataset_train, batch_size=CFG.train_batch_size , shuffle=True, num_workers = CFG.num_workers)
dataloader_valid = DataLoader(dataset_valid, batch_size=CFG.valid_batch_size, num_workers = CFG.num_workers)

print(f"""
{len(dataset_train) = }
train_image_shape : {dataset_train[0][0].shape}
train_mask_shape  : {dataset_train[0][1].shape}
train_image_dtype : {dataset_train[0][0].dtype}
train_mask_dtype : {dataset_train[0][1].dtype}

{len(dataset_valid) = }
valid_image_shape : {dataset_valid[0][0].shape}
valid_mask_shape  : {dataset_valid[0][1].shape}
valid_image_dtype : {dataset_valid[0][0].dtype}
valid_mask_dtype : {dataset_valid[0][1].dtype}
""")

# show_dataset(112, dataset_train)


len(dataset_train) = 20529
train_image_shape : torch.Size([3, 256, 256])
train_mask_shape  : torch.Size([1, 256, 256])
train_image_dtype : torch.float32
train_mask_dtype : torch.float32

len(dataset_valid) = 1856
valid_image_shape : torch.Size([3, 256, 256])
valid_mask_shape  : torch.Size([1, 256, 256])
valid_image_dtype : torch.float32
valid_mask_dtype : torch.float32



# Model

In [6]:
class CustomModel(nn.Module):
    def __init__(self,model_arch, backbone, in_chans, target_size, weight):
        super().__init__()

        self.model = smp.create_model(
            model_arch,
            encoder_name=backbone,
            encoder_weights=weight,
            in_channels=in_chans,
            classes=target_size,
            activation=None,
        )
        
    def forward(self, image):
        output = self.model(image)
        return output


def build_model(model_arch, backbone, in_chans, target_size, weight="imagenet"):
    print('model_arch: ', model_arch)
    print('backbone: ', backbone)
    model = CustomModel(model_arch, backbone, in_chans, target_size, weight)
    return model


num_gpus = torch.cuda.device_count()
device_ids = list(range(num_gpus))

model = build_model(CFG.model_arch, CFG.backbone, CFG.in_chans, CFG.target_size)
model = nn.DataParallel(model, device_ids=device_ids)
model.to(device);

model_arch:  Unet
backbone:  timm-resnest200e


In [7]:
scaler = amp.GradScaler()
criterion = get_lossfn(CFG, CFG.smooth)
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.epochs, eta_min=1e-6)

thresholds_to_test = [round(x * 0.01, 2) for x in range(1, 101, 2)]

# Training

In [8]:
if CFG.debug:
    print("!!!Debug mode!!!\n")

dice_score=0
for epoch in range(100):
    model.train()
    
    pbar_train = enumerate(dataloader_train)
    pbar_train = tqdm(pbar_train, total=len(dataloader_train), bar_format="{l_bar}{bar:10}{r_bar}{bar:-0b}")
    loss_train, loss_val= 0.0, 0.0
    for i, (images, masks) in pbar_train:
        images, masks = images.cuda(), masks.cuda()
        optimizer.zero_grad()
        with amp.autocast():
            preds = model(images)
            loss = criterion(preds, masks)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            loss_train += loss.detach().item()
        
        lr = f"LR : {scheduler.get_last_lr()[0]:.2E}"
        gpu_mem = f"Mem : {torch.cuda.memory_reserved() / 1E9:.3g}GB"
        pbar_train.set_description(("%10s  " * 3 + "%10s") % (f"Epoch {epoch}/{CFG.epochs}", gpu_mem, lr,
                                                                f"Loss: {loss_train / (i + 1):.4f}"))

    scheduler.step()
    model.eval()
    
    cum_pred = []
    cum_true = []
    pbar_val = enumerate(dataloader_valid)
    pbar_val = tqdm(pbar_val, total=len(dataloader_valid), bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}")
    for i, (images, masks) in pbar_val:
        images, masks = images.cuda(), masks.cuda()
        with torch.no_grad():
            preds = model(images)
            loss_val += criterion(preds, masks).item()
            preds = torch.sigmoid(preds)
            cum_pred.append(preds.cpu().detach().numpy())
            cum_true.append(masks.cpu().detach().numpy())

        pbar_val.set_description(("%10s") % (f"Val Loss: {loss_val / (i+1):.4f}"))
    
    cum_pred = torch.flatten(torch.from_numpy(np.concatenate(cum_pred, axis=0)))
    cum_true = torch.flatten(torch.from_numpy(np.concatenate(cum_true, axis=0)))
    
    dice_score_, thresh = calc_optim_thresh(cum_pred, cum_true, thresholds_to_test)
    
    if dice_score_ > dice_score:
        print(f"FBeta : {dice_score_:.4f}\tthresh : {thresh}\tSAVED MODEL\n")
        epoch_best=epoch
        dice_score =dice_score_
        torch.save({'model': model.module.state_dict(), 'dice_score': dice_score, 'thresh': thresh,
                    "model_arch":CFG.model_arch, "backbone":CFG.backbone,"in_chans":CFG.in_chans,"target_size":CFG.target_size,},
                    f'./{CFG.exp_name}/{CFG.exp_name}.pth')
    else:
        print(f"FBeta : {dice_score_:.4f}\tthresh : {thresh}\n")
    
!curl -X POST -H 'Authorization: Bearer '$LINE -F 'message=セルの実行が終わりました！{dice_score}' https://notify-api.line.me/api/notify


Epoch 0/90  Mem : 15GB  LR : 1.00E-04  Loss: 0.8031: 100%|██████████| 642/642 [03:13<00:00,  3.33it/s]
Val Loss: 0.5913: 100%|██████████| 58/58 [00:06<00:00,  8.48it/s]


FBeta : 0.5598	thresh : 0.36	SAVED MODEL



Epoch 1/90  Mem : 15GB  LR : 1.00E-04  Loss: 0.4739: 100%|██████████| 642/642 [03:11<00:00,  3.34it/s]
Val Loss: 0.4432: 100%|██████████| 58/58 [00:06<00:00,  8.34it/s]


FBeta : 0.5910	thresh : 0.46	SAVED MODEL



Epoch 2/90  Mem : 15GB  LR : 9.99E-05  Loss: 0.4419: 100%|██████████| 642/642 [03:11<00:00,  3.35it/s]
Val Loss: 0.4195: 100%|██████████| 58/58 [00:06<00:00,  8.41it/s]


FBeta : 0.5956	thresh : 0.56	SAVED MODEL



Epoch 3/90  Mem : 15GB  LR : 9.97E-05  Loss: 0.4315: 100%|██████████| 642/642 [03:11<00:00,  3.35it/s]
Val Loss: 0.4207: 100%|██████████| 58/58 [00:06<00:00,  8.38it/s]


FBeta : 0.5908	thresh : 0.31



Epoch 4/90  Mem : 15GB  LR : 9.95E-05  Loss: 0.4193: 100%|██████████| 642/642 [03:11<00:00,  3.36it/s]
Val Loss: 0.3971: 100%|██████████| 58/58 [00:06<00:00,  8.36it/s]


FBeta : 0.6102	thresh : 0.71	SAVED MODEL



Epoch 5/90  Mem : 15GB  LR : 9.92E-05  Loss: 0.4183: 100%|██████████| 642/642 [03:13<00:00,  3.32it/s]
Val Loss: 0.3909: 100%|██████████| 58/58 [00:07<00:00,  8.14it/s]


FBeta : 0.6125	thresh : 0.21	SAVED MODEL



Epoch 6/90  Mem : 15GB  LR : 9.89E-05  Loss: 0.4128: 100%|██████████| 642/642 [03:18<00:00,  3.24it/s]
Val Loss: 0.3845: 100%|██████████| 58/58 [00:07<00:00,  8.04it/s]


FBeta : 0.6197	thresh : 0.26	SAVED MODEL



Epoch 7/90  Mem : 15GB  LR : 9.85E-05  Loss: 0.4091: 100%|██████████| 642/642 [03:11<00:00,  3.35it/s]
Val Loss: 0.3877: 100%|██████████| 58/58 [00:06<00:00,  8.47it/s]


FBeta : 0.6147	thresh : 0.31



Epoch 8/90  Mem : 15GB  LR : 9.81E-05  Loss: 0.4058: 100%|██████████| 642/642 [03:18<00:00,  3.23it/s]
Val Loss: 0.3908: 100%|██████████| 58/58 [00:07<00:00,  8.06it/s]


FBeta : 0.6152	thresh : 0.16



Epoch 9/90  Mem : 15GB  LR : 9.76E-05  Loss: 0.4059: 100%|██████████| 642/642 [03:17<00:00,  3.25it/s]
Val Loss: 0.3858: 100%|██████████| 58/58 [00:07<00:00,  8.13it/s]


FBeta : 0.6201	thresh : 0.96	SAVED MODEL



Epoch 10/90  Mem : 15GB  LR : 9.70E-05  Loss: 0.4058: 100%|██████████| 642/642 [03:19<00:00,  3.21it/s]
Val Loss: 0.3813: 100%|██████████| 58/58 [00:07<00:00,  8.14it/s]


FBeta : 0.6222	thresh : 0.71	SAVED MODEL



Epoch 11/90  Mem : 15GB  LR : 9.64E-05  Loss: 0.3967: 100%|██████████| 642/642 [03:11<00:00,  3.35it/s]
Val Loss: 0.3920: 100%|██████████| 58/58 [00:06<00:00,  8.40it/s]


FBeta : 0.6138	thresh : 0.36



Epoch 12/90  Mem : 15GB  LR : 9.57E-05  Loss: 0.3978: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3765: 100%|██████████| 58/58 [00:06<00:00,  8.43it/s]


FBeta : 0.6257	thresh : 0.21	SAVED MODEL



Epoch 13/90  Mem : 15GB  LR : 9.50E-05  Loss: 0.3979: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3757: 100%|██████████| 58/58 [00:06<00:00,  8.37it/s]


FBeta : 0.6258	thresh : 0.76	SAVED MODEL



Epoch 14/90  Mem : 15GB  LR : 9.42E-05  Loss: 0.3942: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3804: 100%|██████████| 58/58 [00:06<00:00,  8.33it/s]


FBeta : 0.6220	thresh : 0.96



Epoch 15/90  Mem : 15GB  LR : 9.34E-05  Loss: 0.3932: 100%|██████████| 642/642 [03:17<00:00,  3.25it/s]
Val Loss: 0.3807: 100%|██████████| 58/58 [00:07<00:00,  8.20it/s]


FBeta : 0.6254	thresh : 0.56



Epoch 16/90  Mem : 15GB  LR : 9.25E-05  Loss: 0.3889: 100%|██████████| 642/642 [03:17<00:00,  3.25it/s]
Val Loss: 0.3721: 100%|██████████| 58/58 [00:06<00:00,  8.42it/s]


FBeta : 0.6291	thresh : 0.71	SAVED MODEL



Epoch 17/90  Mem : 15GB  LR : 9.15E-05  Loss: 0.3918: 100%|██████████| 642/642 [03:19<00:00,  3.22it/s]
Val Loss: 0.3726: 100%|██████████| 58/58 [00:07<00:00,  7.97it/s]


FBeta : 0.6297	thresh : 0.51	SAVED MODEL



Epoch 18/90  Mem : 15GB  LR : 9.05E-05  Loss: 0.3890: 100%|██████████| 642/642 [03:14<00:00,  3.30it/s]
Val Loss: 0.3807: 100%|██████████| 58/58 [00:07<00:00,  8.01it/s]


FBeta : 0.6227	thresh : 0.61



Epoch 19/90  Mem : 15GB  LR : 8.95E-05  Loss: 0.3864: 100%|██████████| 642/642 [03:16<00:00,  3.27it/s]
Val Loss: 0.3712: 100%|██████████| 58/58 [00:07<00:00,  8.24it/s]


FBeta : 0.6303	thresh : 0.21	SAVED MODEL



Epoch 20/90  Mem : 15GB  LR : 8.84E-05  Loss: 0.3867: 100%|██████████| 642/642 [03:14<00:00,  3.30it/s]
Val Loss: 0.3774: 100%|██████████| 58/58 [00:07<00:00,  8.25it/s]


FBeta : 0.6271	thresh : 0.96



Epoch 21/90  Mem : 15GB  LR : 8.73E-05  Loss: 0.3836: 100%|██████████| 642/642 [03:14<00:00,  3.31it/s]
Val Loss: 0.3725: 100%|██████████| 58/58 [00:07<00:00,  8.27it/s]


FBeta : 0.6310	thresh : 0.01	SAVED MODEL



Epoch 22/90  Mem : 15GB  LR : 8.61E-05  Loss: 0.3825: 100%|██████████| 642/642 [03:14<00:00,  3.30it/s]
Val Loss: 0.3694: 100%|██████████| 58/58 [00:06<00:00,  8.34it/s]


FBeta : 0.6331	thresh : 0.06	SAVED MODEL



Epoch 23/90  Mem : 15GB  LR : 8.49E-05  Loss: 0.3810: 100%|██████████| 642/642 [03:11<00:00,  3.36it/s]
Val Loss: 0.3664: 100%|██████████| 58/58 [00:06<00:00,  8.45it/s]


FBeta : 0.6369	thresh : 0.06	SAVED MODEL



Epoch 24/90  Mem : 15GB  LR : 8.36E-05  Loss: 0.3798: 100%|██████████| 642/642 [03:11<00:00,  3.35it/s]
Val Loss: 0.3708: 100%|██████████| 58/58 [00:06<00:00,  8.43it/s]


FBeta : 0.6322	thresh : 0.96



Epoch 25/90  Mem : 15GB  LR : 8.23E-05  Loss: 0.3762: 100%|██████████| 642/642 [03:10<00:00,  3.36it/s]
Val Loss: 0.3631: 100%|██████████| 58/58 [00:06<00:00,  8.40it/s]


FBeta : 0.6368	thresh : 0.36



Epoch 26/90  Mem : 15GB  LR : 8.10E-05  Loss: 0.3759: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3787: 100%|██████████| 58/58 [00:06<00:00,  8.41it/s]


FBeta : 0.6210	thresh : 0.16



Epoch 27/90  Mem : 15GB  LR : 7.96E-05  Loss: 0.3757: 100%|██████████| 642/642 [03:11<00:00,  3.35it/s]
Val Loss: 0.3677: 100%|██████████| 58/58 [00:06<00:00,  8.38it/s]


FBeta : 0.6339	thresh : 0.06



Epoch 28/90  Mem : 15GB  LR : 7.82E-05  Loss: 0.3736: 100%|██████████| 642/642 [03:11<00:00,  3.34it/s]
Val Loss: 0.3547: 100%|██████████| 58/58 [00:06<00:00,  8.40it/s]


FBeta : 0.6426	thresh : 0.21	SAVED MODEL



Epoch 29/90  Mem : 15GB  LR : 7.67E-05  Loss: 0.3739: 100%|██████████| 642/642 [03:15<00:00,  3.29it/s]
Val Loss: 0.3588: 100%|██████████| 58/58 [00:07<00:00,  8.21it/s]


FBeta : 0.6382	thresh : 0.81



Epoch 30/90  Mem : 15GB  LR : 7.53E-05  Loss: 0.3724: 100%|██████████| 642/642 [03:18<00:00,  3.23it/s]
Val Loss: 0.3601: 100%|██████████| 58/58 [00:07<00:00,  8.08it/s]


FBeta : 0.6401	thresh : 0.21



Epoch 31/90  Mem : 15GB  LR : 7.37E-05  Loss: 0.3708: 100%|██████████| 642/642 [03:18<00:00,  3.23it/s]
Val Loss: 0.3667: 100%|██████████| 58/58 [00:07<00:00,  7.99it/s]


FBeta : 0.6354	thresh : 0.21



Epoch 32/90  Mem : 15GB  LR : 7.22E-05  Loss: 0.3692: 100%|██████████| 642/642 [03:16<00:00,  3.27it/s]
Val Loss: 0.3595: 100%|██████████| 58/58 [00:06<00:00,  8.42it/s]


FBeta : 0.6411	thresh : 0.01



Epoch 33/90  Mem : 15GB  LR : 7.06E-05  Loss: 0.3706: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3562: 100%|██████████| 58/58 [00:06<00:00,  8.39it/s]


FBeta : 0.6426	thresh : 0.16	SAVED MODEL



Epoch 34/90  Mem : 15GB  LR : 6.90E-05  Loss: 0.3661: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3618: 100%|██████████| 58/58 [00:06<00:00,  8.39it/s]


FBeta : 0.6374	thresh : 0.56



Epoch 35/90  Mem : 15GB  LR : 6.74E-05  Loss: 0.3649: 100%|██████████| 642/642 [03:13<00:00,  3.33it/s]
Val Loss: 0.3615: 100%|██████████| 58/58 [00:06<00:00,  8.47it/s]


FBeta : 0.6412	thresh : 0.51



Epoch 36/90  Mem : 15GB  LR : 6.58E-05  Loss: 0.3637: 100%|██████████| 642/642 [03:19<00:00,  3.22it/s]
Val Loss: 0.3616: 100%|██████████| 58/58 [00:07<00:00,  8.13it/s]


FBeta : 0.6395	thresh : 0.06



Epoch 37/90  Mem : 15GB  LR : 6.41E-05  Loss: 0.3622: 100%|██████████| 642/642 [03:17<00:00,  3.25it/s]
Val Loss: 0.3601: 100%|██████████| 58/58 [00:07<00:00,  8.12it/s]


FBeta : 0.6428	thresh : 0.46	SAVED MODEL



Epoch 38/90  Mem : 15GB  LR : 6.25E-05  Loss: 0.3614: 100%|██████████| 642/642 [03:17<00:00,  3.25it/s]
Val Loss: 0.3583: 100%|██████████| 58/58 [00:06<00:00,  8.36it/s]


FBeta : 0.6420	thresh : 0.06



Epoch 39/90  Mem : 15GB  LR : 6.08E-05  Loss: 0.3578: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3603: 100%|██████████| 58/58 [00:06<00:00,  8.40it/s]


FBeta : 0.6418	thresh : 0.31



Epoch 40/90  Mem : 15GB  LR : 5.91E-05  Loss: 0.3579: 100%|██████████| 642/642 [03:13<00:00,  3.32it/s]
Val Loss: 0.3587: 100%|██████████| 58/58 [00:07<00:00,  8.04it/s]


FBeta : 0.6404	thresh : 0.31



Epoch 41/90  Mem : 15GB  LR : 5.74E-05  Loss: 0.3576: 100%|██████████| 642/642 [03:16<00:00,  3.27it/s]
Val Loss: 0.3588: 100%|██████████| 58/58 [00:06<00:00,  8.30it/s]


FBeta : 0.6422	thresh : 0.06



Epoch 42/90  Mem : 15GB  LR : 5.57E-05  Loss: 0.3559: 100%|██████████| 642/642 [03:13<00:00,  3.31it/s]
Val Loss: 0.3556: 100%|██████████| 58/58 [00:06<00:00,  8.30it/s]


FBeta : 0.6425	thresh : 0.41



Epoch 43/90  Mem : 15GB  LR : 5.40E-05  Loss: 0.3568: 100%|██████████| 642/642 [03:15<00:00,  3.29it/s]
Val Loss: 0.3624: 100%|██████████| 58/58 [00:07<00:00,  8.26it/s]


FBeta : 0.6396	thresh : 0.06



Epoch 44/90  Mem : 15GB  LR : 5.22E-05  Loss: 0.3528: 100%|██████████| 642/642 [03:17<00:00,  3.25it/s]
Val Loss: 0.3636: 100%|██████████| 58/58 [00:07<00:00,  8.22it/s]


FBeta : 0.6383	thresh : 0.06



Epoch 45/90  Mem : 15GB  LR : 5.05E-05  Loss: 0.3529: 100%|██████████| 642/642 [03:18<00:00,  3.23it/s]
Val Loss: 0.3579: 100%|██████████| 58/58 [00:07<00:00,  8.18it/s]


FBeta : 0.6436	thresh : 0.01	SAVED MODEL



Epoch 46/90  Mem : 15GB  LR : 4.88E-05  Loss: 0.3528: 100%|██████████| 642/642 [03:19<00:00,  3.23it/s]
Val Loss: 0.3626: 100%|██████████| 58/58 [00:07<00:00,  8.07it/s]


FBeta : 0.6413	thresh : 0.01



Epoch 47/90  Mem : 15GB  LR : 4.70E-05  Loss: 0.3506: 100%|██████████| 642/642 [03:18<00:00,  3.23it/s]
Val Loss: 0.3619: 100%|██████████| 58/58 [00:07<00:00,  8.08it/s]


FBeta : 0.6405	thresh : 0.21



Epoch 48/90  Mem : 15GB  LR : 4.53E-05  Loss: 0.3496: 100%|██████████| 642/642 [03:18<00:00,  3.23it/s]
Val Loss: 0.3549: 100%|██████████| 58/58 [00:07<00:00,  8.12it/s]


FBeta : 0.6458	thresh : 0.56	SAVED MODEL



Epoch 49/90  Mem : 15GB  LR : 4.36E-05  Loss: 0.3492: 100%|██████████| 642/642 [03:18<00:00,  3.24it/s]
Val Loss: 0.3526: 100%|██████████| 58/58 [00:07<00:00,  8.21it/s]


FBeta : 0.6483	thresh : 0.26	SAVED MODEL



Epoch 50/90  Mem : 15GB  LR : 4.19E-05  Loss: 0.3488: 100%|██████████| 642/642 [03:17<00:00,  3.25it/s]
Val Loss: 0.3559: 100%|██████████| 58/58 [00:07<00:00,  8.07it/s]


FBeta : 0.6454	thresh : 0.11



Epoch 51/90  Mem : 15GB  LR : 4.02E-05  Loss: 0.3470: 100%|██████████| 642/642 [03:18<00:00,  3.23it/s]
Val Loss: 0.3515: 100%|██████████| 58/58 [00:07<00:00,  8.14it/s]


FBeta : 0.6504	thresh : 0.06	SAVED MODEL



Epoch 52/90  Mem : 15GB  LR : 3.85E-05  Loss: 0.3462: 100%|██████████| 642/642 [03:17<00:00,  3.25it/s]
Val Loss: 0.3515: 100%|██████████| 58/58 [00:07<00:00,  8.13it/s]


FBeta : 0.6501	thresh : 0.61



Epoch 53/90  Mem : 15GB  LR : 3.69E-05  Loss: 0.3471: 100%|██████████| 642/642 [03:16<00:00,  3.26it/s]
Val Loss: 0.3611: 100%|██████████| 58/58 [00:06<00:00,  8.41it/s]


FBeta : 0.6407	thresh : 0.16



Epoch 54/90  Mem : 15GB  LR : 3.52E-05  Loss: 0.3453: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3628: 100%|██████████| 58/58 [00:07<00:00,  8.26it/s]


FBeta : 0.6394	thresh : 0.01



Epoch 55/90  Mem : 15GB  LR : 3.36E-05  Loss: 0.3464: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3543: 100%|██████████| 58/58 [00:06<00:00,  8.42it/s]


FBeta : 0.6477	thresh : 0.26



Epoch 56/90  Mem : 15GB  LR : 3.20E-05  Loss: 0.3412: 100%|██████████| 642/642 [03:11<00:00,  3.35it/s]
Val Loss: 0.3513: 100%|██████████| 58/58 [00:06<00:00,  8.42it/s]


FBeta : 0.6489	thresh : 0.11



Epoch 57/90  Mem : 15GB  LR : 3.04E-05  Loss: 0.3413: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3521: 100%|██████████| 58/58 [00:06<00:00,  8.43it/s]


FBeta : 0.6505	thresh : 0.11	SAVED MODEL



Epoch 58/90  Mem : 15GB  LR : 2.88E-05  Loss: 0.3437: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3482: 100%|██████████| 58/58 [00:06<00:00,  8.39it/s]


FBeta : 0.6533	thresh : 0.01	SAVED MODEL



Epoch 59/90  Mem : 15GB  LR : 2.73E-05  Loss: 0.3399: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3506: 100%|██████████| 58/58 [00:06<00:00,  8.40it/s]


FBeta : 0.6512	thresh : 0.11



Epoch 60/90  Mem : 15GB  LR : 2.58E-05  Loss: 0.3381: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3556: 100%|██████████| 58/58 [00:06<00:00,  8.43it/s]


FBeta : 0.6479	thresh : 0.01



Epoch 61/90  Mem : 15GB  LR : 2.43E-05  Loss: 0.3385: 100%|██████████| 642/642 [03:11<00:00,  3.34it/s]
Val Loss: 0.3514: 100%|██████████| 58/58 [00:06<00:00,  8.39it/s]


FBeta : 0.6501	thresh : 0.11



Epoch 62/90  Mem : 15GB  LR : 2.28E-05  Loss: 0.3353: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3506: 100%|██████████| 58/58 [00:06<00:00,  8.40it/s]


FBeta : 0.6510	thresh : 0.21



Epoch 63/90  Mem : 15GB  LR : 2.14E-05  Loss: 0.3365: 100%|██████████| 642/642 [03:11<00:00,  3.35it/s]
Val Loss: 0.3520: 100%|██████████| 58/58 [00:06<00:00,  8.39it/s]


FBeta : 0.6495	thresh : 0.26



Epoch 64/90  Mem : 15GB  LR : 2.00E-05  Loss: 0.3363: 100%|██████████| 642/642 [03:13<00:00,  3.32it/s]
Val Loss: 0.3519: 100%|██████████| 58/58 [00:06<00:00,  8.40it/s]


FBeta : 0.6501	thresh : 0.01



Epoch 65/90  Mem : 15GB  LR : 1.87E-05  Loss: 0.3350: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3491: 100%|██████████| 58/58 [00:06<00:00,  8.39it/s]


FBeta : 0.6505	thresh : 0.21



Epoch 66/90  Mem : 15GB  LR : 1.74E-05  Loss: 0.3349: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3497: 100%|██████████| 58/58 [00:06<00:00,  8.38it/s]


FBeta : 0.6517	thresh : 0.16



Epoch 67/90  Mem : 15GB  LR : 1.61E-05  Loss: 0.3342: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3511: 100%|██████████| 58/58 [00:06<00:00,  8.35it/s]


FBeta : 0.6508	thresh : 0.11



Epoch 68/90  Mem : 15GB  LR : 1.49E-05  Loss: 0.3345: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3488: 100%|██████████| 58/58 [00:06<00:00,  8.41it/s]


FBeta : 0.6532	thresh : 0.21



Epoch 69/90  Mem : 15GB  LR : 1.37E-05  Loss: 0.3339: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3485: 100%|██████████| 58/58 [00:06<00:00,  8.34it/s]


FBeta : 0.6532	thresh : 0.16



Epoch 70/90  Mem : 15GB  LR : 1.26E-05  Loss: 0.3324: 100%|██████████| 642/642 [03:11<00:00,  3.35it/s]
Val Loss: 0.3494: 100%|██████████| 58/58 [00:06<00:00,  8.36it/s]


FBeta : 0.6506	thresh : 0.31



Epoch 71/90  Mem : 15GB  LR : 1.15E-05  Loss: 0.3333: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3471: 100%|██████████| 58/58 [00:06<00:00,  8.38it/s]


FBeta : 0.6536	thresh : 0.11	SAVED MODEL



Epoch 72/90  Mem : 15GB  LR : 1.05E-05  Loss: 0.3311: 100%|██████████| 642/642 [03:11<00:00,  3.34it/s]
Val Loss: 0.3503: 100%|██████████| 58/58 [00:06<00:00,  8.40it/s]


FBeta : 0.6507	thresh : 0.16



Epoch 73/90  Mem : 15GB  LR : 9.46E-06  Loss: 0.3313: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3477: 100%|██████████| 58/58 [00:06<00:00,  8.39it/s]


FBeta : 0.6534	thresh : 0.06



Epoch 74/90  Mem : 15GB  LR : 8.52E-06  Loss: 0.3310: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3482: 100%|██████████| 58/58 [00:06<00:00,  8.37it/s]


FBeta : 0.6523	thresh : 0.21



Epoch 75/90  Mem : 15GB  LR : 7.63E-06  Loss: 0.3292: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3476: 100%|██████████| 58/58 [00:06<00:00,  8.37it/s]


FBeta : 0.6534	thresh : 0.16



Epoch 76/90  Mem : 15GB  LR : 6.79E-06  Loss: 0.3309: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3475: 100%|██████████| 58/58 [00:06<00:00,  8.35it/s]


FBeta : 0.6532	thresh : 0.06



Epoch 77/90  Mem : 15GB  LR : 6.01E-06  Loss: 0.3297: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3468: 100%|██████████| 58/58 [00:07<00:00,  8.15it/s]


FBeta : 0.6537	thresh : 0.06	SAVED MODEL



Epoch 78/90  Mem : 15GB  LR : 5.28E-06  Loss: 0.3297: 100%|██████████| 642/642 [03:19<00:00,  3.22it/s]
Val Loss: 0.3477: 100%|██████████| 58/58 [00:07<00:00,  8.15it/s]


FBeta : 0.6537	thresh : 0.21



Epoch 79/90  Mem : 15GB  LR : 4.60E-06  Loss: 0.3317: 100%|██████████| 642/642 [03:15<00:00,  3.29it/s]
Val Loss: 0.3485: 100%|██████████| 58/58 [00:07<00:00,  8.21it/s]


FBeta : 0.6528	thresh : 0.21



Epoch 80/90  Mem : 15GB  LR : 3.99E-06  Loss: 0.3301: 100%|██████████| 642/642 [03:18<00:00,  3.23it/s]
Val Loss: 0.3491: 100%|██████████| 58/58 [00:07<00:00,  8.12it/s]


FBeta : 0.6529	thresh : 0.16



Epoch 81/90  Mem : 15GB  LR : 3.42E-06  Loss: 0.3282: 100%|██████████| 642/642 [03:18<00:00,  3.23it/s]
Val Loss: 0.3483: 100%|██████████| 58/58 [00:07<00:00,  7.99it/s]


FBeta : 0.6522	thresh : 0.21



Epoch 82/90  Mem : 15GB  LR : 2.92E-06  Loss: 0.3306: 100%|██████████| 642/642 [03:18<00:00,  3.23it/s]
Val Loss: 0.3492: 100%|██████████| 58/58 [00:07<00:00,  8.10it/s]


FBeta : 0.6520	thresh : 0.26



Epoch 83/90  Mem : 15GB  LR : 2.47E-06  Loss: 0.3280: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3489: 100%|██████████| 58/58 [00:06<00:00,  8.36it/s]


FBeta : 0.6524	thresh : 0.06



Epoch 84/90  Mem : 15GB  LR : 2.08E-06  Loss: 0.3284: 100%|██████████| 642/642 [03:11<00:00,  3.35it/s]
Val Loss: 0.3482: 100%|██████████| 58/58 [00:06<00:00,  8.32it/s]


FBeta : 0.6531	thresh : 0.31



Epoch 85/90  Mem : 15GB  LR : 1.75E-06  Loss: 0.3295: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3495: 100%|██████████| 58/58 [00:06<00:00,  8.35it/s]


FBeta : 0.6516	thresh : 0.11



Epoch 86/90  Mem : 15GB  LR : 1.48E-06  Loss: 0.3292: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3488: 100%|██████████| 58/58 [00:06<00:00,  8.37it/s]


FBeta : 0.6521	thresh : 0.06



Epoch 87/90  Mem : 15GB  LR : 1.27E-06  Loss: 0.3291: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3473: 100%|██████████| 58/58 [00:06<00:00,  8.40it/s]


FBeta : 0.6533	thresh : 0.11



Epoch 88/90  Mem : 15GB  LR : 1.12E-06  Loss: 0.3281: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3475: 100%|██████████| 58/58 [00:06<00:00,  8.35it/s]


FBeta : 0.6530	thresh : 0.11



Epoch 89/90  Mem : 15GB  LR : 1.03E-06  Loss: 0.3286: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3473: 100%|██████████| 58/58 [00:06<00:00,  8.39it/s]


FBeta : 0.6530	thresh : 0.11



Epoch 90/90  Mem : 15GB  LR : 1.00E-06  Loss: 0.3282: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3476: 100%|██████████| 58/58 [00:06<00:00,  8.37it/s]


FBeta : 0.6532	thresh : 0.06



Epoch 91/90  Mem : 15GB  LR : 1.03E-06  Loss: 0.3269: 100%|██████████| 642/642 [03:11<00:00,  3.34it/s]
Val Loss: 0.3473: 100%|██████████| 58/58 [00:06<00:00,  8.38it/s]


FBeta : 0.6531	thresh : 0.06



Epoch 92/90  Mem : 15GB  LR : 1.12E-06  Loss: 0.3286: 100%|██████████| 642/642 [03:13<00:00,  3.32it/s]
Val Loss: 0.3487: 100%|██████████| 58/58 [00:06<00:00,  8.39it/s]


FBeta : 0.6524	thresh : 0.11



Epoch 93/90  Mem : 15GB  LR : 1.27E-06  Loss: 0.3284: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3473: 100%|██████████| 58/58 [00:06<00:00,  8.38it/s]


FBeta : 0.6533	thresh : 0.06



Epoch 94/90  Mem : 15GB  LR : 1.48E-06  Loss: 0.3277: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3477: 100%|██████████| 58/58 [00:06<00:00,  8.34it/s]


FBeta : 0.6530	thresh : 0.16



Epoch 95/90  Mem : 15GB  LR : 1.75E-06  Loss: 0.3284: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3486: 100%|██████████| 58/58 [00:07<00:00,  8.26it/s]


FBeta : 0.6522	thresh : 0.36



Epoch 96/90  Mem : 15GB  LR : 2.08E-06  Loss: 0.3295: 100%|██████████| 642/642 [03:12<00:00,  3.34it/s]
Val Loss: 0.3479: 100%|██████████| 58/58 [00:06<00:00,  8.35it/s]


FBeta : 0.6531	thresh : 0.16



Epoch 97/90  Mem : 15GB  LR : 2.47E-06  Loss: 0.3285: 100%|██████████| 642/642 [03:13<00:00,  3.32it/s]
Val Loss: 0.3490: 100%|██████████| 58/58 [00:06<00:00,  8.37it/s]


FBeta : 0.6525	thresh : 0.01



Epoch 98/90  Mem : 15GB  LR : 2.92E-06  Loss: 0.3270: 100%|██████████| 642/642 [03:11<00:00,  3.34it/s]
Val Loss: 0.3492: 100%|██████████| 58/58 [00:06<00:00,  8.34it/s]


FBeta : 0.6521	thresh : 0.16



Epoch 99/90  Mem : 15GB  LR : 3.42E-06  Loss: 0.3286: 100%|██████████| 642/642 [03:12<00:00,  3.33it/s]
Val Loss: 0.3477: 100%|██████████| 58/58 [00:06<00:00,  8.37it/s]


FBeta : 0.6533	thresh : 0.01

{"status":200,"message":"ok"}

In [8]:
pth = torch.load(f'./{CFG.exp_name}/{CFG.exp_name}.pth')

model = build_model(pth["model_arch"], pth["backbone"], pth["in_chans"], pth["target_size"], weight=None)
model.load_state_dict(pth['model'])
thresh = pth['thresh']
dice_score = pth['dice_score']
print(f"{dice_score:.4f}")

model.to(device)
model.eval();


model_arch:  Unet
backbone:  timm-resnest200e
0.6537


In [10]:
# for i, (images, masks) in enumerate(dataloader_valid):
#     if i!=1: continue
    
#     images, masks = images.cuda(), masks.cuda()
#     with torch.no_grad():
#         preds = model(images)
#         preds = torch.sigmoid(preds)
#     images, masks, preds = images.cpu(), masks.cpu(), preds.cpu()
    
#     for num in range(CFG.valid_batch_size):
#         fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(20,5))
#         axes = axes.flatten()
#         axes[0].imshow(images[num].permute(1,2,0))
#         axes[0].axis('off')
#         axes[0].set_title('Image')
#         axes[1].imshow(masks[num].permute(1,2,0))
#         axes[1].axis('off')
#         axes[1].set_title('Ground Truth')
#         axes[2].imshow(preds[num].permute(1,2,0))
#         axes[2].axis('off')
#         axes[2].set_title('pred')
#         axes[3].imshow((preds[num]>thresh).permute(1,2,0))
#         axes[3].axis('off')
#         axes[3].set_title('pred_thresh')
#     break