In [1]:
!pip install volumentations-3D
!pip install segmentation_models_pytorch

Collecting volumentations-3D
  Obtaining dependency information for volumentations-3D from https://files.pythonhosted.org/packages/59/a6/af48582a42ad57eb7bb6517802c2fb07eaa3a715c6db9d9c356870795bca/volumentations_3D-1.0.4-py3-none-any.whl.metadata
  Downloading volumentations_3D-1.0.4-py3-none-any.whl.metadata (481 bytes)
Downloading volumentations_3D-1.0.4-py3-none-any.whl (31 kB)
Installing collected packages: volumentations-3D
Successfully installed volumentations-3D-1.0.4
Collecting segmentation_models_pytorch
  Obtaining dependency information for segmentation_models_pytorch from https://files.pythonhosted.org/packages/cb/70/4aac1b240b399b108ce58029ae54bc14497e1bbc275dfab8fd3c84c1e35d/segmentation_models_pytorch-0.3.3-py3-none-any.whl.metadata
  Downloading segmentation_models_pytorch-0.3.3-py3-none-any.whl.metadata (30 kB)
Collecting pretrainedmodels==0.7.4 (from segmentation_models_pytorch)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
from __future__ import annotations
# ====================================================
# Directory settings
# ====================================================
import os
from pathlib import Path

OUTPUT_DIR = '/kaggle/working'

INPUT_DIR='/kaggle/input/blood-vessel-segmentation'

In [3]:
# ====================================================
# CFG
# ====================================================
class CFG:
    apex=True
    wandb = False
    competition = "HOA3D"
    backbone = 'resnet18d'
    max_grad_norm = 1000
    debug = False
    debug_train_size = 200
    scheduler = "cosine"
    num_warmup_steps=0
    epochs = 2
    decoder_lr = 0.005
    betas = (0.9, 0.999)
    batch_size = 4
    infer_batch_size = 4
    ckpt_name = 'unet3d-baseline'
    weight_decay = 0.1
    seed = 42
    print_freq=50
    eval_freq =1000
    eval_step_save_start_epoch=0
    train = True

if CFG.debug:
    CFG.epochs = 1

In [4]:
# ====================================================
# Library
# ====================================================
import os
import gc
import re
import ast
import sys
import copy
import json
import time
import math
import string
import pickle
import random
import itertools
import glob
import warnings
warnings.filterwarnings("ignore")

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset
import timm
import segmentation_models_pytorch as smp

from PIL import Image
import cv2

from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed=42)

# Dataset

In [7]:
def extract_3d_voxels_for_patches(image_3d, patch_size=(128, 128, 128), stride=(64, 64, 64)):
    """
    Extracts patches in an image and returns a list of lowest voxels.
    """
    patches = []

    for start_x in range(0, image_3d.shape[0] - patch_size[0], stride[0]):
        for start_y in range(0, image_3d.shape[1] - patch_size[1], stride[1]):
            for start_z in range(0, image_3d.shape[2] - patch_size[2], stride[2]):
                # Lowest voxel needed to extract patch
                lowest_voxel = (start_x,start_y,start_z)
                patches.append(lowest_voxel)
                
    return patches

def extract_patch_from_voxel(image_3d, lowest_voxel, patch_size=(128, 128, 128)):
    """
    Extracts a 3d patch given the lowest voxel.
    """
    end_x = lowest_voxel[0] + patch_size[0] 
    end_y = lowest_voxel[1] + patch_size[1]
    end_z = lowest_voxel[2] + patch_size[2]

    patch = image_3d[lowest_voxel[0]:end_x,
                     lowest_voxel[1]:end_y,
                     lowest_voxel[2]:end_z]

    return patch



def filter_empty_patches_by_voxel(lowest_voxels, mask, patch_size, threshold = 0):
    """
    Removes the lowest voxels that represent patches with positive values <= threshold
    """
    positive_voxels = []

    for lowest_voxel in tqdm(lowest_voxels):
        # Extract the patch
        patch = mask[lowest_voxel[0]:lowest_voxel[0]  + patch_size[0],
                         lowest_voxel[1]:lowest_voxel[1]  + patch_size[1],
                         lowest_voxel[2]:lowest_voxel[2]  + patch_size[2]]

        if np.count_nonzero(patch) > threshold:
            positive_voxels.append(lowest_voxel)


    return positive_voxels

In [8]:
class ValidKidney3DDataset(Dataset):
    def __init__(self, patches, masks,patch_size,stride_size, transformations=None):
        self.patches = patches
        self.masks = masks
        self.patch_size = patch_size
        self.lowest_voxels = extract_3d_voxels_for_patches(patches, patch_size=patch_size, stride=stride_size)
        self.lowest_voxels = filter_empty_patches_by_voxel(self.lowest_voxels, masks, patch_size,threshold=0)
        self.transformations = transformations
        
    def __len__(self):
        return len(self.lowest_voxels)

    def __getitem__(self, idx):
        image = extract_patch_from_voxel(self.patches,self.lowest_voxels[idx],patch_size=self.patch_size)
        mask = extract_patch_from_voxel(self.masks,self.lowest_voxels[idx],patch_size=self.patch_size)
        
        image = np.expand_dims(image,0)
        mask = np.expand_dims(mask,0)
        
        data = {'image':image,'mask': mask}
        if self.transformations:
            data = self.transformations(**data)
            
        data['image']=(data['image'] - data['image'].mean()) / (data['image'].std() + 0.0001)

        data['image'] = torch.tensor(data['image'], dtype=torch.float)
        data['mask'] = torch.tensor(data['mask'], dtype=torch.float)
        return data

In [10]:
class TrainKidney3DDataset(Dataset):
    def __init__(self, patches, masks, lowest_voxels, patch_size, transformations=None):
        self.patches = patches
        self.masks = masks
        self.lowest_voxels = lowest_voxels
        self.patch_size = patch_size
        self.transformations = transformations
        
    def __len__(self):
        return len(self.lowest_voxels)

    def __getitem__(self, idx):
        image = extract_patch_from_voxel(self.patches,self.lowest_voxels[idx],patch_size=self.patch_size)
        mask = extract_patch_from_voxel(self.masks,self.lowest_voxels[idx], patch_size=self.patch_size)
        
        image = np.expand_dims(image,0)
        mask = np.expand_dims(mask,0)
    
        data = {'image':image,'mask': mask}
        if self.transformations:
            data = self.transformations(**data)
        
        
        data['image']=(data['image'] - data['image'].mean()) / (data['image'].std() + 0.0001)
        data['image'] = torch.tensor(data['image'], dtype=torch.float)
        data['mask'] = torch.tensor(data['mask'], dtype=torch.float)
        return data

In [11]:
import volumentations as volumen

def get_augmentation(patch_size):
    return volumen.Compose([
        volumen.RandomGamma(gamma_limit=(80, 120), p=0.3),
        volumen.GaussianNoise(var_limit=(0, 5), p=0.3),
        volumen.Flip(1, p=0.3),
        volumen.Flip(2, p=0.3),
    ], p=1.0)

In [13]:
patch_size = (128,128,128)
stride = (32,32,32)

In [None]:

# npz version of data
kidney3 = np.load('/kaggle/input/sennet-hoa-all-kidneys-dense/kidney_3_dense.npz')
kidney3_volume = kidney3['volume'].astype(np.uint8)
kidney3_masks = kidney3['mask'].astype(np.uint8)

kidney1 = np.load('/kaggle/input/sennet-hoa-all-kidneys-dense/kidney_1_dense.npz')
kidney1_volume = kidney1['volume'].astype(np.uint8)
kidney1_masks = kidney1['mask'].astype(np.uint8)

In [None]:
my_voxels = extract_3d_voxels_for_patches(kidney1_volume,patch_size=patch_size,stride=(32,32,32))
my_voxels = filter_empty_patches_by_voxel(my_voxels,kidney1_masks,patch_size=patch_size,threshold=50)
np.save('/kaggle/working/kidney1_lowest_voxel_patches_stride32_thresh50.npy', my_voxels)

In [15]:
# patch size 128x128x128 stride 32x32x32, pos filter thres 50
loaded_lowest_voxels = np.load('/kaggle/working/kidney1_lowest_voxel_patches_stride32_thresh50.npy') 
loaded_lowest_voxels.shape

(36872, 3)

In [16]:
train_dataset = TrainKidney3DDataset(
    patches=kidney1_volume,
    masks=kidney1_masks,
    lowest_voxels=loaded_lowest_voxels,
    patch_size=patch_size,
    transformations=get_augmentation(patch_size),
)

In [17]:
len(train_dataset)

36872

# Model

Adapted From First Stage RSNA 2022 Cervical Spine Fracture Detection Qishen Hai

In [14]:
n_blocks = 4
out_dim = 1
class TimmSegModel(nn.Module):
    def __init__(self, backbone, segtype='unet', pretrained=False):
        super(TimmSegModel, self).__init__()

        self.encoder = timm.create_model(
            backbone,
            in_chans=1,
            features_only=True,
            drop_rate=0,
            drop_path_rate=0,
            pretrained=pretrained
        )
        
        
        g = self.encoder(torch.rand(1, 1, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        print(encoder_channels)
        decoder_channels = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = smp.decoders.unet.decoder.UnetDecoder(
                encoder_channels=encoder_channels[:n_blocks+1],
                decoder_channels=decoder_channels[:n_blocks],
                n_blocks=n_blocks,
            )

        self.segmentation_head = nn.Conv2d(decoder_channels[n_blocks-1], out_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self,x):
        global_features = [0] + self.encoder(x)[:n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features

In [15]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, List


# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding


# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)


# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0


# Dynamically pad input x with 'SAME' padding for conv with specified args
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1, 1), value: float = 0):
    ih, iw, iz = x.size()[-3:]
    pad_h = get_same_padding(ih, k[0], s[0], d[0])
    pad_w = get_same_padding(iw, k[1], s[1], d[1])
    pad_z = get_same_padding(iz, k[2], s[2], d[2])
    if pad_h > 0 or pad_w > 0 or pad_z > 0:
        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2, pad_z // 2, pad_z - pad_z // 2], value=value)
    return x


def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    dynamic = False
    if isinstance(padding, str):
        # for any string padding, the padding will be calculated for you, one of three ways
        padding = padding.lower()
        if padding == 'same':
            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
            if is_static_pad(kernel_size, **kwargs):
                # static case, no extra overhead
                padding = get_padding(kernel_size, **kwargs)
            else:
                # dynamic 'SAME' padding, has runtime/GPU memory overhead
                padding = 0
                dynamic = True
        elif padding == 'valid':
            # 'VALID' padding, same as padding=0
            padding = 0
        else:
            # Default to PyTorch style 'same'-ish symmetric padding
            padding = get_padding(kernel_size, **kwargs)
    return padding, dynamic


def conv3d_same(
        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int, int] = (1, 1, 1),
        padding: Tuple[int, int, int] = (0, 0, 0), dilation: Tuple[int, int, int] = (1, 1, 1), groups: int = 1):
    x = pad_same(x, weight.shape[-3:], stride, dilation)
    return F.conv3d(x, weight, bias, stride, (0, 0, 0), dilation, groups)


class Conv3dSame(nn.Conv3d):
    """ Tensorflow like 'SAME' convolution wrapper for 3d convolutions
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv3dSame, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)

    def forward(self, x):
        return conv3d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


def create_conv3d_pad(in_chs, out_chs, kernel_size, **kwargs):
    padding = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
    if is_dynamic:
        return Conv3dSame(in_chs, out_chs, kernel_size, **kwargs)
    else:
        return nn.Conv3d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)

In [16]:
from timm.layers.conv2d_same import Conv2dSame

def convert_3d(module):

    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
            
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output

# Loss

In [17]:
class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        inputs = F.sigmoid(inputs)       
        
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice
    

class BCELoss(nn.Module):
    def __init__(self, smooth=1.0, pos_weight=1, device='cpu'):
        super(BCELoss, self).__init__()
        self.smooth = smooth
        self.pos_weight = pos_weight
        self.device = device

    def forward(self, inputs, targets):
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        pos_weight = torch.tensor([self.pos_weight]).to(self.device)
        loss = F.binary_cross_entropy_with_logits(inputs, targets,pos_weight=pos_weight)
        return loss
    

class FocalLoss(nn.modules.loss._WeightedLoss):

    def __init__(self, gamma=0, size_average=None, ignore_index=-100,
                 reduce=None, balance_param=1.0):
        super(FocalLoss, self).__init__(size_average)
        self.gamma = gamma
        self.size_average = size_average
        self.ignore_index = ignore_index
        self.balance_param = balance_param

    def forward(self, input, target):
        logpt = - F.binary_cross_entropy_with_logits(input, target)
        pt = torch.exp(logpt)

        focal_loss = -((1 - pt) ** self.gamma) * logpt
        balanced_focal_loss = self.balance_param * focal_loss
        return balanced_focal_loss
    
class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, pos_weight=1, device='cpu'):
        super(DiceBCELoss, self).__init__()
        self.pos_weight = pos_weight
        self.device = device
        
    def forward(self, inputs, targets, smooth=1):
        
        inputs = F.sigmoid(inputs)       
        
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        
        pos_weight = torch.tensor([self.pos_weight]).to(self.device)
        bce_loss = F.binary_cross_entropy(inputs, targets, pos_weight=pos_weight,reduction='mean')
        Dice_BCE = (bce_loss + dice_loss)/2
        
        return Dice_BCE

# Train

In [18]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
    
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

def train_fn(
    train_loader: DataLoader,
    valid_loader: DataLoader,
    model: nn.Module,
    criterion: nn.Module,
    optimizer,
    epoch: int,
    scheduler,
    device,
    best_score,
) -> torch.Tensor:


    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.apex)
    losses = AverageMeter()
    start = end = time.time()
    global_step = 0
    for step, (batch) in enumerate(train_loader):
        
        
        inputs = batch['image'].to(device)
        labels = batch['mask'].to(device)

        batch_size = labels.size(0)
        
        with torch.cuda.amp.autocast(enabled=CFG.apex):
            y_preds = model(inputs)
            loss = criterion(y_preds, labels)
            
        losses.update(loss.item(), batch_size)
        scaler.scale(loss).backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), CFG.max_grad_norm
        )

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        global_step += 1
        scheduler.step()
 
        end = time.time()
        
        if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
            print(
                "Epoch: [{0}][{1}/{2}] "
                "Elapsed {remain:s} "
                "Loss: {loss.val:.4f}({loss.avg:.4f}) "
                "Grad: {grad_norm:.4f}  "
                "LR: {lr:.8f}  ".format(
                    epoch + 1,
                    step,
                    len(train_loader),
                    remain=timeSince(start, float(step + 1) / len(train_loader)),
                    loss=losses,
                    grad_norm=grad_norm,
                    lr=scheduler.get_lr()[0],
                )
            )
            
            
                
        if CFG.eval_step_save_start_epoch <= epoch and (
            (step + 1) % CFG.eval_freq == 0
        ):
            val_loss = valid_fn(valid_loader, model, criterion, device)
            score = val_loss
            if score < best_score:
                best_score = score
                save_ckpt(model)
                print(f"Saving New Best Score Model")
            
    torch.cuda.empty_cache()
    gc.collect()

    return losses.avg, best_score


@torch.inference_mode()
def valid_fn(
    valid_loader: DataLoader, model: nn.Module, criterion: nn.Module, device
) -> tuple[torch.Tensor, np.ndarray]:
    
    losses = AverageMeter()
    model.eval()
    start = end = time.time()
    for step, (batch) in enumerate(valid_loader):
        inputs = batch['image'].to(device)
        labels = batch['mask'].to(device)
        batch_size = labels.size(0)
        
        y_preds = model(inputs)
            
        loss = criterion(y_preds, labels)
    
            
        losses.update(loss.item(), batch_size)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(valid_loader) - 1):
            print(
                "EVAL: [{0}/{1}] "
                "Elapsed {remain:s} "
                "Loss: {loss.val:.4f}({loss.avg:.4f}) ".format(
                    step,
                    len(valid_loader),
                    loss=losses,
                    remain=timeSince(start, float(step + 1) / len(valid_loader)),
                )
            )


    model.train()
    return losses.avg

def save_ckpt(
    model: torch.nn.Module,
) -> None:

    save_path = OUTPUT_DIR + f'/{CFG.ckpt_name}.pth'

    torch.save(
        {"model": model.state_dict()},
        save_path,
    )

In [19]:
model = TimmSegModel(CFG.backbone)
model = convert_3d(model)

model.to(device)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {total_params}")

[1, 64, 64, 128, 256, 512]
Trainable parameters: 39722049


In [None]:
aug = get_augmentation(patch_size)
train_dataset = TrainKidney3DDataset(
    patches=kidney1_volume,
    masks=kidney1_masks,
    lowest_voxels=loaded_lowest_voxels,
    patch_size=patch_size,
    transformations=get_augmentation(patch_size),
)

valid_dataset = ValidKidney3DDataset(
    patches=kidney3_volume,
    masks=kidney3_masks,
    patch_size=patch_size,
    stride_size=(64, 64, 64),
)


train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=CFG.infer_batch_size, shuffle=False)

# Loss function and optimizer
# criterion = BCELoss(pos_weight=120,device=device)
criterion = DiceLoss()
optimizer = AdamW(model.parameters(), lr=CFG.decoder_lr)
num_train_steps = int(len(train_dataset) / CFG.batch_size * CFG.epochs)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_train_steps//2)

In [27]:
num_train_steps

18436

In [28]:
best_score = np.inf
for epoch in range(CFG.epochs):
    start_time = time.time()
    avg_loss, best_score = train_fn(
    train_loader,
    valid_loader,
    model,
    criterion,
    optimizer,
    epoch,
    scheduler,
    device,
    best_score,
    )
            
            
    avg_val_loss = valid_fn(valid_loader, model, criterion, device)
    elapsed = time.time() - start_time
    save_ckpt(model)

Epoch: [1][0/9218] Elapsed 0m 8s (remain 1256m 57s) Loss: 0.9894(0.9894) Grad: 16459.8984  LR: 0.00500000  


KeyboardInterrupt: 

In [34]:
gc.collect()
torch.cuda.empty_cache()