## setup

In [None]:
#!pip install kaggle

In [None]:
#!mkdir -p ~/.kaggle
#!cp /content/drive/MyDrive/Kaggle/kaggle.json ~/.kaggle/
#!chmod 600 ~/.kaggle/kaggle.json

In [None]:
#!kaggle competitions download -c blood-vessel-segmentation -p "/content/drive/MyDrive/Kaggle/SenNet + HOA - Hacking the Human Vasculature in 3D/DataSources"

In [None]:
import os
# Unpack the competition archive only once: the presence of '/content/train' is used as the
# marker that a previous run already extracted it.
# NOTE(review): unzip is called without a -d target, so files land in the current working
# directory — presumably /content on Colab; confirm.
if not os.path.isdir('/content/train'):
    !unzip "/content/drive/MyDrive/Kaggle/SenNet + HOA - Hacking the Human Vasculature in 3D/DataSources/blood-vessel-segmentation.zip"

## library

In [None]:
# Third-party packages imported by the next cell.
# NOTE(review): versions are unpinned and albumentations is force-upgraded (-U), so the
# library APIs used further down (e.g. A.Cutout) depend on whatever release is current —
# consider pinning the versions this notebook was validated with.
!pip install timm
!pip install transformers
!pip install -U albumentations
!pip install segmentation_models_pytorch



In [None]:
import numpy as np
import pandas as pd

import gc
import os
import cv2
import glob
import random
import matplotlib.pyplot as plt
from tqdm import tqdm

import tifffile

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision

from sklearn.model_selection import KFold


import timm

from transformers.optimization import get_cosine_schedule_with_warmup

import albumentations as A

import segmentation_models_pytorch as smp
from segmentation_models_pytorch.decoders.unet.model import (
    UnetDecoder,
    SegmentationHead,
)

## utils

In [None]:
# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background
    Returns the runs as a space-separated string of 1-indexed "start length"
    pairs over the flattened mask; a mask with no foreground gives ''.
    '''
    flat = img.flatten()
    # Zero sentinels on both ends guarantee every run has a rising and a falling edge.
    padded = np.concatenate([[0], flat, [0]])
    # 1-indexed positions where the value changes: run starts sit in the even slots,
    # run ends (exclusive) in the odd slots.
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn each end position into a run length.
    edges[1::2] = edges[1::2] - edges[::2]
    return ' '.join(map(str, edges))

def rle_decode(mask_rle, shape):
    '''
    Decode a run-length string back into a binary mask.

    mask_rle: space-separated "start length" pairs, starts are 1-indexed
    shape: (height,width) of array to return
    Returns numpy uint8 array, 1 - mask, 0 - background
    '''
    tokens = mask_rle.split()
    # Even tokens are run starts (shifted to 0-indexed), odd tokens are run lengths.
    starts = np.asarray(tokens[0::2], dtype=int) - 1
    lengths = np.asarray(tokens[1::2], dtype=int)
    ends = starts + lengths

    flat = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for begin, stop in zip(starts, ends):
        flat[begin:stop] = 1
    return flat.reshape(shape)

In [None]:
# Executes metric.py inside this notebook's namespace.
# NOTE(review): presumably this is what defines create_table_neighbour_code_to_surface_area,
# which compute_surface_dice_score calls below — confirm against metric.py.
%run '/content/drive/MyDrive/Kaggle/SenNet + HOA - Hacking the Human Vasculature in 3D/metric.py'

In [None]:
# ref.: https://www.kaggle.com/code/iafoss/unet34-dice-0-87/notebook

def dice_score(pred, targs):
    """
    Smoothed Dice coefficient between a prediction and a target tensor.

    pred is binarised with `pred > 0`; the +1.0 in the denominator keeps the
    ratio finite when both tensors are empty.
    """
    hard = (pred > 0).float()
    overlap = (hard * targs).sum()
    total = (hard + targs).sum()
    return 2.0 * overlap / (total + 1.0)

def IoU_score(pred, targs):
    """
    Smoothed intersection-over-union between a prediction and a target tensor.

    pred is binarised with `pred > 0`; the +1.0 in the denominator keeps the
    ratio finite when both tensors are empty.
    """
    hard = (pred > 0).float()
    intersection = (hard * targs).sum()
    union = (hard + targs).sum() - intersection
    return intersection / (union + 1.0)

## score function

In [None]:
# PyTorch version dependence on index data type:
# int32 indices work from torch 2.0.0 on, but raise IndexError in torch 1.13.1.
torch_ver_major = int(torch.__version__.split('.')[0])
dtype_index = torch.int32 if torch_ver_major >= 2 else torch.long

def compute_area(y: list, unfold: nn.Unfold, area: torch.Tensor) -> torch.Tensor:
    """
    Look up the surface area of every 2x2x2 cube spanned by two consecutive slices.

    Args:
      y (list[Tensor]): A pair of consecutive slices of mask, each (H, W) with 0/1 values
      unfold: nn.Unfold(kernel_size=(2, 2), padding=1)
      area (Tensor): surface area for 256 patterns (256, ); must live on the
        same device as the slices in `y`

    Returns:
      Surface area of surface in each 2x2x2 cube, shape (n_cubes, )
    """
    # Two layers of segmentation masks -> (batch_size=1, nch=2, H, W)
    yy = torch.stack(y, dim=0).unsqueeze(0)

    # Values are bits (0/1) but unfold requires a float input. Half precision is
    # kept on CUDA; float32 is used elsewhere so the function also runs on CPU.
    # Both represent 0 and 1 exactly, so the result is the same.
    float_dtype = torch.float16 if yy.is_cuda else torch.float32

    # unfold slides through the volume like a convolution
    # 2x2 kernel returns 8 values (2 channels * 2x2)
    cubes_float = unfold(yy.to(float_dtype)).squeeze(0)  # (8, n_cubes)

    # Each of the 8 values is either 0 or 1: pack those 8 bits into one integer.
    # The accumulator is allocated on the input's own device instead of relying
    # on a notebook-level `device` global.
    # Indices must be int32 or long for area[cubes_byte] below, not uint8.
    cubes_byte = torch.zeros(cubes_float.size(1), dtype=dtype_index, device=cubes_float.device)

    for k in range(8):
        cubes_byte += cubes_float[k, :].to(dtype_index) << k

    # Use area lookup table: pattern index -> area [float]
    cubes_area = area[cubes_byte]

    return cubes_area


def compute_surface_dice_score(submit: pd.DataFrame, label: pd.DataFrame) -> float:
    """
    Compute surface Dice score for one 3D volume

    The volume is walked slice by slice (with one empty padding slice before the
    first and after the last one); for each pair of consecutive slices the
    surface area of every 2x2x2 cube is looked up for prediction and ground
    truth, and the areas are accumulated into the Dice numerator/denominator.

    Args:
      submit (pd.DataFrame): submission file with id and rle
      label (pd.DataFrame): ground truth id, rle, and also image height, width

    Returns:
      Surface Dice score as a Python float.

    Raises:
      AssertionError: if ids differ, `label` is empty, or slices do not all
        share one height and one width.
    """
    # submit and label must contain exact same id in same order
    assert (submit['id'] == label['id']).all()
    assert len(label) > 0

    # All height, width must be the same (these checks were previously bare
    # comparisons whose result was discarded, so they never fired)
    assert len(label['height'].unique()) == 1
    assert len(label['width'].unique()) == 1

    # Surface area lookup table: Tensor[float32] (256, )
    area = create_table_neighbour_code_to_surface_area((1, 1, 1))
    area = torch.from_numpy(area).to(device)  # torch.float32

    # Slide through the volume like a convolution
    unfold = torch.nn.Unfold(kernel_size=(2, 2), padding=1)

    r = label.iloc[0]
    h, w = int(r['height']), int(r['width'])
    n_slices = len(label)

    # Padding before first slice
    y0 = y0_pred = torch.zeros((h, w), dtype=torch.uint8, device=device)

    num = 0     # numerator of surface Dice
    denom = 0   # denominator
    for i in range(n_slices + 1):
        # Load one slice
        if i < n_slices:
            r = label.iloc[i]
            y1 = rle_decode(r['rle'], (h, w))
            y1 = torch.from_numpy(y1).to(device)

            r = submit.iloc[i]
            y1_pred = rle_decode(r['rle'], (h, w))
            y1_pred = torch.from_numpy(y1_pred).to(device)
        else:
            # Padding after the last slice
            y1 = y1_pred = torch.zeros((h, w), dtype=torch.uint8, device=device)

        # Compute the surface area between two slices (n_cubes,)
        area_pred = compute_area([y0_pred, y1_pred], unfold, area)
        area_true = compute_area([y0, y1], unfold, area)

        # True positive cube indices
        idx = torch.logical_and(area_pred > 0, area_true > 0)

        # Surface dice numerator and denominator
        num += area_pred[idx].sum() + area_true[idx].sum()
        denom += area_pred.sum() + area_true.sum()

        # Next slice
        y0 = y1
        y0_pred = y1_pred

    # The loop runs at least twice, so num/denom are tensors here; the clamp
    # avoids a division by zero when both volumes are empty.
    dice = num / denom.clamp(min=1e-8)
    return dice.item()

In [None]:
def test_function(df,
                  model,
                  device,
                  threshold,
                  axes=(0, 1, 2)):
    """
    Predict a whole volume slice-by-slice along the given axes and build the
    solution / submission tables used for scoring.

    Args:
      df: rows of one volume; reads the index (used as ids) and the columns
        'slice', 'group' and 'rle', and is passed to `preload`.
      model: network put into eval mode and applied to every slice.
      device: device the input slices are moved to.
      threshold: cut-off applied to the axis-averaged prediction.
      axes: volume axes (0, 1 and/or 2) to predict along; predictions from all
        axes are averaged.

    Returns:
      (solution, submission): `solution` holds id, width, height, group, slice
      and the ground-truth rle; `submission` holds id and the predicted rle.

    Raises:
      NotImplementedError: if an axis other than 0, 1 or 2 is requested.
    """
    for axis in axes:
        if axis not in (0, 1, 2):
            raise NotImplementedError()

    model.eval()

    ids = list(df.index)
    slices = list(df['slice'])
    groups = list(df['group'])

    df = df.reset_index(drop=True)
    imgs, masks = preload(df)

    # Accumulator for the per-axis predictions, same shape as the volume.
    pred = np.zeros(imgs.shape, dtype=np.float16)
    for axis in axes:
        gc.collect()

        for index in tqdm(range(imgs.shape[axis])):
            # NOTE(review): no module-level `process_img` is defined in the visible
            # notebook (only the CustomDatasetV2.process_img method) — confirm where
            # this helper comes from.
            img, mask, hw = process_img(imgs, masks, index=index, axis=axis)
            img = img[None].to(device)

            with torch.no_grad():
                # NOTE(review): CustomModel.forward in this notebook takes a batch dict
                # and returns a dict, whereas a tensor in / tensor out call is used
                # here — confirm which model this function targets.
                logit = model(img)

            # Resize back to the slice's own height/width before accumulating.
            logit = torchvision.transforms.Resize(hw, antialias=True)(logit).cpu().numpy()[0]
            if axis==0:
                pred[index, :, :] += logit
            elif axis==1:
                pred[:, index, :] += logit
            else:
                pred[:, :, index] += logit

    # Average over the axes, then binarise.
    pred = pred/len(axes)
    pred = (pred>threshold).astype(np.uint8)

    preds = []
    for i in range(pred.shape[0]):
        rle = rle_encode(pred[i])
        if len(rle)==0:
            # Placeholder run for slices with no predicted foreground.
            rle='1 1'
        preds.append(rle)

    # Every slice in the stacked volume shares one (height, width), so the sizes
    # come from the preloaded array instead of re-reading each image file
    # (and no longer from an `img_size` name that is not defined here).
    h, w = imgs.shape[1:3]

    solution = pd.DataFrame({
        'id': ids,
        'width': [w] * len(df),
        'height': [h] * len(df),
        'group': groups,
        'slice': slices,
        })
    solution['rle'] = list(df['rle'])

    submission = pd.DataFrame({
        'id': ids,
        'rle': preds,
        })

    return solution, submission

## seed

In [None]:
def seed_everything(seed):
    """
    Seed every random number generator the notebook relies on.

    Seeds Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices),
    exports PYTHONHASHSEED, and configures cuDNN for reproducible runs.

    Args:
      seed (int): seed value applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU; a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes the convolution algorithm per run, which can pick
    # different kernels between runs; it has to be off for reproducible results.
    torch.backends.cudnn.benchmark = False

if __name__ == "__main__":
  seed_everything(42)

## preprocess

In [None]:
# 3fold : per donor

def preprocess():
    """
    Load the training RLE table and derive per-row metadata plus three folds.

    Each id is split on '_': the last piece becomes 'slice', the rest becomes
    'group'. 'dataset' equals the group, except that 'kidney_3_dense' is mapped
    to 'kidney_3_sparse'.

    Returns:
      (train_rles, folds): the annotated table, and a list of three
      [train_df, val_df] pairs where fold n validates on every row whose
      'dataset' contains the digit n (1, 2, 3) and trains on the rest.
    """
    train_rles = pd.read_csv('/content/train_rles.csv')

    datasets, slices, groups = [], [], []
    for sample_id, _rle in train_rles.itertuples(index=False, name=None):
        *prefix, slice_id = sample_id.split('_')
        group_id = '_'.join(prefix)
        dataset_id = 'kidney_3_sparse' if group_id == 'kidney_3_dense' else group_id

        datasets.append(dataset_id)
        slices.append(slice_id)
        groups.append(group_id)

    train_rles['dataset'] = datasets
    train_rles['slice'] = slices
    train_rles['group'] = groups

    folds = []
    for donor in ('1', '2', '3'):
        held_out = train_rles['dataset'].str.contains(donor)
        folds.append([
            train_rles[~held_out].reset_index(drop=True),
            train_rles[held_out].reset_index(drop=True),
        ])

    return train_rles, folds

if __name__ == "__main__":
    # Build the annotated RLE table and the three folds once at notebook scope;
    # later cells read `folds` directly.
    train_rles, folds = preprocess()

## dataset

In [None]:
def preload(df):
    """
    Read every slice image listed in `df` and decode its mask into memory.

    Args:
      df: table with the columns 'rle', 'dataset' and 'slice' (as produced by
        `preprocess`). Columns are read by name, so extra or reordered columns
        and a non-default index are fine.

    Returns:
      (imgs, masks): two arrays stacked along a new first axis, one entry per
      row of `df`, in row order.

    Raises:
      ValueError: from np.stack if `df` is empty or the slices differ in shape.
    """
    imgs = []
    masks = []
    rows = zip(df['rle'], df['dataset'], df['slice'])
    for rle, dataset_id, slice_id in tqdm(rows, total=len(df)):
        img = tifffile.imread(f'/content/train/{dataset_id}/images/{slice_id}.tif')
        # The mask is decoded at the image's own resolution.
        mask = rle_decode(rle, img.shape)

        imgs.append(img)
        masks.append(mask)
    return np.stack(imgs), np.stack(masks)

if __name__ == "__main__":

    # Third fold (index 2) of the split produced by preprocess().
    train_df, val_df = folds[2]

    # train:kidney_1_dense
    train_df = train_df[train_df['group']=='kidney_1_dense'].reset_index(drop=True)
    # Keeps the whole kidney_1_dense volume in memory as notebook globals; the demo
    # cells further down reuse `imgs` and `masks`.
    imgs, masks = preload(train_df)

In [None]:
# Use the GPU when one is present, otherwise fall back to CPU so the notebook's
# helper cells still run on a CPU-only runtime.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# 2d

# The string literal below is an earlier, disabled version of the training pipeline
# (resize + crop based). It is never executed and references `img_size`, which is not
# defined in the visible notebook.
'''train_transform = A.Compose(
    [
        A.Resize(int(img_size*1.125), int(img_size*1.125)),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=25, p=0.5),
        A.RandomCrop(
            always_apply=False, p=1.0, height=img_size, width=img_size
        ),
        A.RandomBrightnessContrast(
            brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=0.5
        ),
        A.Cutout(num_holes=8, max_h_size=int((img_size/512)*36), max_w_size=int((img_size/512)*36), p=0.8),
    ]
)'''

# Active training pipeline: random flips plus cutout holes, applied at the slice's
# original resolution (the resize step is commented out).
# NOTE(review): A.Cutout may be unavailable in recent albumentations releases, and the
# install cell upgrades albumentations without a pin — confirm against the installed version.
train_transform = A.Compose(
    [
        #A.Resize(img_size, img_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Cutout(num_holes=8, max_h_size=128, max_w_size=128, p=0.8),
    ]
)

# Validation pipeline: an empty Compose (the resize step is commented out).
val_transform = A.Compose(
    [
        #A.Resize(img_size, img_size),
    ]
)

def blur_augmentation(x):
    """
    Resize an (H, W, C) image by a random factor and then back to (H, W).

    The factor is drawn uniformly from [0.5, 1.5); the output has the same
    height and width as the input.
    """
    height, width, _ = x.shape
    factor = np.random.uniform(0.5, 1.5)

    rescaled = A.Resize(int(height*factor), int(width*factor))(image=x)['image']
    restored = A.Resize(height, width)(image=rescaled)['image']
    return restored

def channel_augmentation(x, prob=0.5, n_channel=3):
    """
    Reverse the order of the channels (axis 2) with probability `prob`.

    x must have exactly `n_channel` entries along axis 2, otherwise an
    AssertionError is raised. The input is returned unchanged when no flip
    is drawn.
    """
    assert x.shape[2]==n_channel
    reverse = np.random.rand() < prob
    return np.flip(x, axis=2) if reverse else x


class CustomDatasetV2(torch.utils.data.Dataset):
    """
    2.5D slice dataset over a preloaded 3D image volume and its mask volume.

    Item `index` is the slice at position `index` along `axis`, stacked with its
    neighbouring slices as channels (one neighbour on each side with the default
    n_slice=1, n_stride=1). Neighbours outside the volume are zero-filled.

    Args:
      imgs: 3D image volume (numpy array).
      masks: 3D mask volume, same layout as `imgs`.
      axis: volume axis to slice along (0, 1 or 2).
      transform: callable invoked as transform(image=..., mask=...) that returns
        a dict with 'image' and 'mask'.
      training: when True, blur_augmentation is applied before `transform`.
    """
    def __init__(self, imgs, masks, axis, transform, training=False):
        self.imgs = imgs
        self.masks = masks
        self.axis = axis
        self.transform = transform
        self.training = training

        # Number of neighbouring slices on each side, and the step between them.
        self.n_slice = 1
        self.n_stride = 1

    def __len__(self):
        return self.imgs.shape[self.axis]

    def normalize_img(self, img):
        """Min-max scale one slice to the uint8 range [0, 255]."""
        img = img - np.min(img)
        peak = np.max(img)
        if peak == 0:
            # Constant slice: there is nothing to scale, and dividing by zero
            # would produce NaNs before the uint8 cast.
            return np.zeros(img.shape, dtype=np.uint8)
        img = img / peak
        img = (img*255).astype(np.uint8)
        return img

    def _take_slice(self, volume, i):
        """Return slice `i` of `volume` along self.axis."""
        if self.axis==0:
            return volume[i]
        if self.axis==1:
            return volume[:, i]
        if self.axis==2:
            return volume[:, :, i]
        raise NotImplementedError()

    def process_img(self, index):
        """Build the (H, W, C) uint8 stack of normalized slices centred on `index`."""
        if self.axis not in (0, 1, 2):
            raise NotImplementedError()

        n_total = self.imgs.shape[self.axis]
        reach = self.n_slice*self.n_stride

        img = []
        for offset in range(-reach, reach+1, self.n_stride):
            i = index + offset
            if 0 <= i < n_total:
                x = self.normalize_img(self._take_slice(self.imgs, i))
            else:
                # Neighbour falls outside the volume: pad with an empty slice.
                # An explicit bounds check is used so that genuine errors are not
                # swallowed by a catch-all handler.
                x = np.zeros(self._take_slice(self.imgs, 0).shape, dtype=np.uint8)
            img.append(x)

        img = np.stack(img, axis=-1)
        return img

    def process_mask(self, index):
        """Return the mask slice at `index` along self.axis."""
        return self._take_slice(self.masks, index)

    def __getitem__(self, index):
        """
        Returns:
          (img, mask, hw): img is a float tensor (C, H, W) scaled to [0, 1],
          mask is a float tensor, and hw holds the slice's height and width
          taken before any augmentation.
        """
        img = self.process_img(index)
        mask = self.process_mask(index)

        hw = torch.tensor(img.shape[:2])

        if self.training:
            img = blur_augmentation(img)
            #img = channel_augmentation(img)

        transforms = self.transform(image=img, mask=mask)
        img, mask = transforms['image'], transforms['mask']

        img = torch.tensor(img, dtype = torch.float)
        img = img.permute(2, 0, 1) / 255.0
        mask = torch.tensor(mask, dtype = torch.float)

        return img, mask, hw

if __name__ == "__main__":

    # Visual sanity check of the dataset: one random training sample along axis 1.
    train_df, val_df = folds[2]

    # train:kidney_1_dense
    train_df = train_df[train_df['group']=='kidney_1_dense'].reset_index(drop=True)
    # Relies on `imgs` / `masks` already loaded by the earlier preload cell.
    #imgs, masks = preload(train_df)

    ds = CustomDatasetV2(imgs, masks, 1, train_transform, training=True)
    # randint's upper bound is exclusive, so the last two indices are never drawn.
    index = np.random.randint(0, len(ds)-1)

    img, mask, hw = ds[index]
    # Panels: the 3-channel stack, each of its three channels, then the mask.
    fig, axs = plt.subplots(1, 5, figsize=(20, 4))
    axs[0].imshow(img.permute(1, 2, 0), cmap='gray')
    axs[1].imshow(img[0], cmap='gray')
    axs[2].imshow(img[1], cmap='gray')
    axs[3].imshow(img[2], cmap='gray')
    axs[4].imshow(mask, cmap='gray')
    plt.show()
    print(hw)

## unet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from segmentation_models_pytorch.base import modules as md


class DecoderBlock(nn.Module):
    """
    One U-Net decoder stage: upsample, optionally merge a skip connection, then
    two 3x3 Conv2dReLU layers, with an attention module after the merge and
    after the second convolution.

    The input is resized to the skip tensor's spatial size when a skip is given,
    otherwise to the explicit `size`, rather than by a fixed scale factor.
    """
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        merged_channels = in_channels + skip_channels
        conv_kwargs = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)

        self.conv1 = md.Conv2dReLU(merged_channels, out_channels, **conv_kwargs)
        self.attention1 = md.Attention(attention_type, in_channels=merged_channels)
        self.conv2 = md.Conv2dReLU(out_channels, out_channels, **conv_kwargs)
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None, size=None):
        has_skip = skip is not None
        # Target spatial size: follow the skip connection if there is one.
        target_size = skip.shape[2:] if has_skip else size
        x = F.interpolate(x, size=target_size, mode="nearest")

        if has_skip:
            x = self.attention1(torch.cat([x, skip], dim=1))

        x = self.conv2(self.conv1(x))
        return self.attention2(x)


class CenterBlock(nn.Sequential):
    """Two stacked 3x3 Conv2dReLU layers: in_channels -> out_channels -> out_channels."""
    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        layers = [
            md.Conv2dReLU(
                channels_in,
                out_channels,
                kernel_size=3,
                padding=1,
                use_batchnorm=use_batchnorm,
            )
            for channels_in in (in_channels, out_channels)
        ]
        super().__init__(*layers)


class UnetDecoder(nn.Module):
    """
    U-Net decoder built from DecoderBlock stages.

    Each block upsamples to the size of its skip feature; a block without a skip
    upsamples to the spatial size of the first feature passed to forward().
    Defining this class here rebinds the `UnetDecoder` name imported from
    segmentation_models_pytorch earlier in the notebook.

    Raises:
      ValueError: if `n_blocks` does not equal len(decoder_channels).
    """
    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if len(decoder_channels) != n_blocks:
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # Drop the entry for the input-resolution feature, then order the rest
        # from the deepest encoder stage to the shallowest.
        encoder_channels = encoder_channels[1:][::-1]

        # Per-block channel bookkeeping: the last block has no skip (0 channels).
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]

        self.center = (
            CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
            if center
            else nn.Identity()
        )

        self.blocks = nn.ModuleList([
            DecoderBlock(
                in_ch,
                skip_ch,
                out_ch,
                use_batchnorm=use_batchnorm,
                attention_type=attention_type,
            )
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, decoder_channels)
        ])

    def forward(self, *features):
        # Spatial size of the first feature; used by blocks that have no skip.
        original_size = features[0].shape[2:]

        # Skip the first feature and walk the remaining ones deepest-first.
        deepest_first = features[1:][::-1]
        head = deepest_first[0]
        skips = deepest_first[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip, original_size)

        return x

tensor([1303,  912])


## model

In [None]:
class CustomModel(nn.Module):
    """
    U-Net style segmentation model: timm 'regnety_016' feature encoder, the
    notebook's UnetDecoder, and a single-class segmentation head.

    The training loss is a Tversky loss (alpha=0.1, beta=0.9); the evaluation
    loss is a Dice loss.
    """
    def __init__(self):
        super(CustomModel, self).__init__()

        self.n_classes = 1
        in_chans = 3

        self.encoder = timm.create_model(
            'regnety_016',
            pretrained=True,
            features_only=True,
            in_chans=in_chans,
        )
        # The input image itself counts as the first "feature" for the decoder.
        encoder_channels = tuple(
            [in_chans]
            + [
                self.encoder.feature_info[i]["num_chs"]
                for i in range(len(self.encoder.feature_info))
            ]
        )
        self.decoder = UnetDecoder(
            encoder_channels=encoder_channels,
            decoder_channels=(256, 128, 64, 32, 16),
            n_blocks=5,
            use_batchnorm=True,
            center=False,
            attention_type=None,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=16,
            out_channels=self.n_classes,
            activation=None,
            kernel_size=3,
        )

        self.train_loss = smp.losses.TverskyLoss(mode='binary', alpha=0.1, beta=0.9)
        self.test_loss = smp.losses.DiceLoss(mode='binary')

    def forward(self, batch, training=False):
        """
        Args:
          batch (dict): 'input' holds the image batch; 'mask' (optional) holds
            the target masks without a channel dimension.
          training (bool): selects the training loss instead of the evaluation loss.

        Returns:
          dict with 'logit' — the sigmoid of the head output with the channel
          dimension removed (i.e. probabilities, despite the key name) — and,
          when a mask was supplied, 'loss'.
        """
        x_in = batch["input"]

        enc_out = self.encoder(x_in)

        decoder_out = self.decoder(*[x_in] + enc_out)
        x_seg = self.segmentation_head(decoder_out)

        output = {}

        # The loss is only computed when a target is available, so the model can
        # also be used for label-free inference.
        mask = batch["mask"] if "mask" in batch else None
        if mask is not None:
            one_hot_mask = mask[:, None].float()
            criterion = self.train_loss if training else self.test_loss
            output["loss"] = criterion(x_seg, one_hot_mask)

        output['logit'] = torch.sigmoid(x_seg)[:, 0]

        return output

if __name__ == "__main__":
    # Smoke test: push one training sample (sliced along axis 2) through a fresh model
    # and print the resulting loss / prediction dict.
    train_df, val_df = folds[2]
    # Relies on `imgs` / `masks` loaded by the earlier preload cell.
    ds = CustomDatasetV2(imgs, masks, 2, train_transform, training=True)
    loader = torch.utils.data.DataLoader(ds, batch_size = 1, num_workers = 8, shuffle = True, drop_last = True)
    sample = next(iter(loader))
    sample = [x.to(device) for x in sample]

    # The dataset yields (img, mask, hw); only the first two go to the model.
    batch = {}
    batch['input'] = sample[0]
    batch['mask'] = sample[1]

    model = CustomModel().to(device)

    with torch.no_grad():
        output = model(batch, training=True)
        print(output)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


{'loss': tensor(1.8375, device='cuda:0'), 'logit': tensor([[[0.5411, 0.4949, 0.5671,  ..., 0.5274, 0.5286, 0.5220],
         [0.5723, 0.6003, 0.5729,  ..., 0.5545, 0.5126, 0.4545],
         [0.4763, 0.5992, 0.5767,  ..., 0.6058, 0.4819, 0.4250],
         ...,
         [0.4441, 0.5105, 0.5095,  ..., 0.4941, 0.5136, 0.4351],
         [0.5017, 0.4980, 0.5330,  ..., 0.5333, 0.4562, 0.4438],
         [0.5514, 0.5764, 0.5779,  ..., 0.5773, 0.6365, 0.5588]]],
       device='cuda:0')}


## train

In [None]:
def train_function(model,
                   optimizer,
                   scheduler,
                   scaler,
                   loader,
                   device,
                   iters_to_accumulate):
    """
    Run one pass over `loader` with mixed precision and gradient accumulation.

    The optimizer, grad scaler and scheduler advance once every
    `iters_to_accumulate` batches. Gradients of a trailing partial group are not
    applied here and are not cleared before returning.

    Returns:
      Mean (un-scaled) training loss over the batches of `loader`.
    """
    model.train()

    gc.collect()

    running_loss = 0.0
    for step, sample in enumerate(tqdm(loader), start=1):
        tensors = [t.to(device) for t in sample]
        batch = {'input': tensors[0], 'mask': tensors[1]}

        with torch.cuda.amp.autocast():
            # Scale the loss down so the accumulated gradient matches one large batch.
            loss = model(batch, training=True)['loss'] / iters_to_accumulate

        scaler.scale(loss).backward()

        accumulation_done = step % iters_to_accumulate == 0
        if accumulation_done:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            scheduler.step()

        # Undo the accumulation scaling for reporting.
        running_loss += loss.detach().cpu() * iters_to_accumulate

    return running_loss/len(loader)

def val_function(model,
                 scaler,
                 loader,
                 device,
                 log_path,
                 threshold=0.5):
    """
    Evaluate `model` on `loader` and append the averaged metrics to a log file.

    `scaler` is accepted but not used. Metrics are averaged per batch. The
    value stored under 'bce_loss' is whatever loss the model returns for
    non-training calls.

    Returns:
      dict with 'bce_loss', 'dice_score' and 'iou_score', each rounded to 4 decimals.
    """
    model.eval()

    gc.collect()

    loss_sum = 0.0
    dice_sum = 0.0
    iou_sum = 0.0
    for sample in tqdm(loader):
        tensors = [t.to(device) for t in sample]
        target = tensors[1]
        batch = {'input': tensors[0], 'mask': target}

        with torch.no_grad():
            output = model(batch)

        # Binarise the model's prediction once and score it with both metrics.
        hard_pred = output['logit'] > threshold
        dice = dice_score(hard_pred, target)
        iou = IoU_score(hard_pred, target)

        loss_sum += output['loss'].detach().cpu()
        dice_sum += dice.detach().cpu()
        iou_sum += iou.detach().cpu()

    n_batches = len(loader)
    message = {
        'bce_loss' : round(loss_sum.tolist()/n_batches, 4),
        'dice_score' : round(dice_sum.tolist()/n_batches, 4),
        'iou_score' : round(iou_sum.tolist()/n_batches, 4)
    }

    with open(log_path, 'a+') as logger:
        logger.write(f'{message}\n')

    return message

## run

In [None]:
# Release the demo objects created by earlier cells, if they exist, before the full
# volumes are loaded. Each name is dropped independently, so one missing name does not
# keep the others alive.
for _name in ('imgs', 'masks', 'ds', 'loader'):
    globals().pop(_name, None)

train, folds = preprocess()

for k in range(2, 3):

    # Hyper-parameters for this run.
    batch_size = 4
    epoch = 20
    early_stop = 20
    lr = 2e-4
    wd = 0.01
    warmup_ratio = 0.1
    num_workers = 8
    iters_to_accumulate = 4
    train_dir = 'model:unet-regnety016,loss:tversky19,imgsize:original,train:kidney_1_dense,aug:flip+cutout+blur+xyz,channel:3'
    seed = 42
    root = '/content/drive/MyDrive/Kaggle/SenNet + HOA - Hacking the Human Vasculature in 3D/'

    seed_everything(seed)

    # Select the fold inside the loop: `k` only exists once the loop has started.
    train_df, val_df = folds[k]

    # train:kidney_1_dense
    train_df = train_df[train_df['dataset']=='kidney_1_dense'].reset_index(drop=True)

    train_imgs, train_masks = preload(train_df)
    val_imgs, val_masks = preload(val_df)

    # One training dataset per slicing axis; validation slices along axis 0 only.
    train_dataset0 = CustomDatasetV2(train_imgs, train_masks, 2, train_transform, training=True)
    train_dataset1 = CustomDatasetV2(train_imgs, train_masks, 1, train_transform, training=True)
    train_dataset2 = CustomDatasetV2(train_imgs, train_masks, 0, train_transform, training=True)
    val_dataset = CustomDatasetV2(val_imgs, val_masks, 0, val_transform, training=False)

    train_loader0 = torch.utils.data.DataLoader(train_dataset0, batch_size = batch_size, num_workers = num_workers, shuffle = True, drop_last = True)
    train_loader1 = torch.utils.data.DataLoader(train_dataset1, batch_size = batch_size, num_workers = num_workers, shuffle = True, drop_last = True)
    train_loader2 = torch.utils.data.DataLoader(train_dataset2, batch_size = batch_size, num_workers = num_workers, shuffle = True, drop_last = True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, num_workers = num_workers, shuffle = False, drop_last = False)

    model = CustomModel().to(device)

    optimizer = torch.optim.AdamW(params = model.parameters(), lr = lr, weight_decay = wd)
    # Scheduler steps happen once per accumulated batch group across all three loaders.
    total_steps = int((len(train_dataset0)+len(train_dataset1)+len(train_dataset2)) * epoch/(batch_size * iters_to_accumulate))
    warmup_steps = int(total_steps * warmup_ratio)
    print('total_steps: ', total_steps)
    print('warmup_steps: ', warmup_steps)

    scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                num_warmup_steps = warmup_steps,
                                                num_training_steps = total_steps)
    scaler = torch.cuda.amp.GradScaler()

    # Creates the run directory and the per-fold directory in one call; a no-op when
    # they already exist.
    os.makedirs(root + f'{train_dir}/fold{k+1}/', exist_ok=True)

    for i in range(epoch):
        # train: one pass per slicing axis, in the order loader0, loader1, loader2
        train_losses = [
            train_function(model,
                           optimizer,
                           scheduler,
                           scaler,
                           train_loader,
                           device,
                           iters_to_accumulate)
            for train_loader in (train_loader0, train_loader1, train_loader2)
        ]
        train_loss = sum(train_losses)/len(train_losses)

        # val
        message = val_function(model,
                               scaler,
                               val_loader,
                               device,
                               root + f'{train_dir}/fold{k+1}/log.txt')

        val_loss, val_dice, val_iou = message['bce_loss'], message['dice_score'], message['iou_score']

        # save: a checkpoint is written after every epoch, with the metrics in its name
        save_path = root + f'{train_dir}/fold{k+1}/epoch' + f'{i+1}'.zfill(3) + \
                    f'-trainloss{round(train_loss.tolist(), 4)}' + \
                    f'-valloss{val_loss}' + \
                    f'-valdice{val_dice}' + \
                    f'-valiou{val_iou}' + '.bin'
        torch.save(model.state_dict(), save_path)

        _lr = optimizer.param_groups[0]['lr']
        print(f'epoch : {i+1}, lr : {_lr}, trainloss : {round(train_loss.tolist(), 4)}, valloss : {val_loss}, valdice : {val_dice}, valiou : {val_iou}')

        if i+1 == early_stop:
            break

In [None]:
#from google.colab import runtime
#runtime.unassign()

In [None]:
#[0]*10**10