# Spinal chord segmentation using UNET
## This notebook has the follwing topics

1. Dataset preprocessing
2. Create dataloader¶
3. UNET architecture
4. DICE loss
5. Training the model 


In [1]:
from PIL import Image
from matplotlib import pyplot as plt
from glob import glob as glob
import numpy as np
import cv2
import random
import os
import time
import copy
from collections import defaultdict
from tqdm import tqdm as tqdm

import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F
from torchvision import transforms
from torchvision import models
import torch.nn as nn
from torchsummary import summary

# 1. Dataset preprocessing 

In [2]:
## GLOBAL SETTINGS

random.seed(0)
im_w_training = 256
im_h_training = 256
b_size = 8
EPOCHS = 40
num_class = 6
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_train_split = .95

In [3]:
## HELPER FUNCTIONS


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp',
    '.BMP', '.tiff'
]

In [4]:
## TRAIN TEST SPLIT

org_imgs = glob(
    '../experiments/dataset/05_Final_Ground_Truth_Data/Composite_Images/*')
lable_imgs = glob(
    '../experiments/dataset/05_Final_Ground_Truth_Data/Label_Images/*')

temp = list(zip(org_imgs, lable_imgs))
random.shuffle(temp)

org_imgs, lable_imgs = zip(*temp)

split_size = int(test_train_split * len(org_imgs))

train_org_imgs, val_org_imgs = org_imgs[0:split_size], org_imgs[split_size:-1]
train_lable_imgs, val_lable_imgs = lable_imgs[0:split_size], lable_imgs[
    split_size:-1]

In [5]:
print('Total training samples --> {}'.format(len(train_org_imgs)))
print('Total validation samples --> {}'.format(len(val_org_imgs)))

Total training samples --> 1467
Total validation samples --> 77


# 2. Create dataloader

In [6]:
class RegularDataset(Dataset):
    def __init__(self, original_img_paths, original_seg_paths, augment):

        self.transforms = augment

        self.img_width = im_w_training
        self.img_height = im_h_training

        self.A_paths = original_img_paths
        self.B_paths = original_seg_paths

        self.dataset_size = len(self.A_paths)

    def __getitem__(self, index):

        # input A (RGB source)
        A_path = self.A_paths[index]
        A = Image.open(A_path)
        A = A.resize((self.img_width, self.img_height), Image.BICUBIC)
        A_tensor = self.transforms(A)

        # input B (seg map target)
        B_path = self.B_paths[index]
        B = self.parsing_embedding(B_path)  # channel(20), H, W
        B_tensor = torch.from_numpy(B)

        input_dict = {
            'source_img': A_tensor,
            'target_img': B_tensor,
            'source_img_path': A_path,
            'target_img_path': B_path
        }

        return input_dict

    def parsing_embedding(self, img_path):
        test_img = Image.open(img_path)
        test_img = test_img.resize((self.img_width, self.img_height),
                                   Image.BICUBIC)
        test_img = np.array(test_img)
        unique_channel_values = [0, 50, 100, 150, 200, 250]
        parse_emb = []

        for each_channel_val in unique_channel_values:
            parse_emb.append(
                (test_img == each_channel_val).astype(np.float32).tolist())

        parse = np.array(parse_emb).astype(np.float32)
        return parse

    def __len__(self):
        return len(self.A_paths)

    def name(self):
        return 'RegularDataset'

In [7]:
# DEFINE AUGMENTATION

augment = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ColorJitter(hue=.05, saturation=.05),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20, resample=Image.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, ))
])

train_dataset = RegularDataset(train_org_imgs, train_lable_imgs, augment)
val_dataset = RegularDataset(val_org_imgs, val_lable_imgs, augment)

# CREATE DATALOADER
train_dataloader = DataLoader(train_dataset,
                              shuffle=True,
                              drop_last=False,
                              num_workers=6,
                              batch_size=b_size,
                              pin_memory=True)
val_dataloader = DataLoader(val_dataset,
                            shuffle=True,
                            drop_last=False,
                            num_workers=6,
                            batch_size=b_size,
                            pin_memory=True)

# FOR DEBUGGING
print("Checking  the dimension and type of data")
for key in train_dataset[0].keys():
    try:
        x = train_dataset[0][key]
        print("name of the input and shape -- > ", key, x.shape)
        print("type,dtype,and min max -- >", type(x), x.dtype, torch.min(x),
              torch.max(x))
    except Exception as e:
        print("name of the input -- > ", key, train_dataset[0][key])
    print('----------------')

Checking  the dimension and type of data
name of the input and shape -- >  source_img torch.Size([1, 256, 256])
type,dtype,and min max -- > <class 'torch.Tensor'> torch.float32 tensor(-1.) tensor(0.6471)
----------------
name of the input and shape -- >  target_img torch.Size([6, 256, 256])
type,dtype,and min max -- > <class 'torch.Tensor'> torch.float32 tensor(0.) tensor(1.)
----------------
name of the input -- >  source_img_path ../experiments/dataset/05_Final_Ground_Truth_Data/Composite_Images/C1_0433_D5.png
----------------
name of the input -- >  target_img_path ../experiments/dataset/05_Final_Ground_Truth_Data/Label_Images/L1_0359_D4.png
----------------


# 3. UNET architecture

In [8]:


def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )


base_model = models.resnet18(pretrained=True)
base_model.conv1 = torch.nn.Conv2d(1,
                                   64,
                                   kernel_size=(7, 7),
                                   stride=(2, 2),
                                   padding=(3, 3),
                                   bias=False)


class ResNetUNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()

#         self.base_model = models.resnet18(pretrained=False)
#         self.base_model.conv1 = torch.nn.Conv2d(1,
#                                                 64,
#                                                 kernel_size=(7, 7),
#                                                 stride=(2, 2),
#                                                 padding=(3, 3),
#                                                 bias=False)

        self.base_layers = list(base_model.children())

        self.layer0 = nn.Sequential(
            *self.base_layers[:3])  # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(
            *self.base_layers[3:5])  # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)

        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        self.conv_original_size0 = convrelu(1, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out

In [9]:
model = ResNetUNet(num_class)
model = model.to(device)

summary(model, input_size=(1, im_w_training, im_h_training))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
              ReLU-2         [-1, 64, 256, 256]               0
            Conv2d-3         [-1, 64, 256, 256]          36,928
              ReLU-4         [-1, 64, 256, 256]               0
            Conv2d-5         [-1, 64, 128, 128]           3,136
       BatchNorm2d-6         [-1, 64, 128, 128]             128
              ReLU-7         [-1, 64, 128, 128]               0
         MaxPool2d-8           [-1, 64, 64, 64]               0
            Conv2d-9           [-1, 64, 64, 64]          36,864
      BatchNorm2d-10           [-1, 64, 64, 64]             128
             ReLU-11           [-1, 64, 64, 64]               0
           Conv2d-12           [-1, 64, 64, 64]          36,864
      BatchNorm2d-13           [-1, 64, 64, 64]             128
             ReLU-14           [-1, 64,

# 4. DICE loss

In [10]:
def dice_loss(pred, target, smooth=1.):
    pred = pred.contiguous()
    target = target.contiguous()

    intersection = (pred * target).sum(2).sum(2)

    loss = (1 - ((2. * intersection + smooth) /
                 (pred.sum(2).sum(2) + target.sum(2).sum(2) + smooth)))

    return loss.mean()


def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)

    pred = F.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)

    return loss


def print_metrics(metrics, epoch_samples, phase):
    outputs = []
    for k in metrics.keys():
        outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples))

    print("{}: {}".format(phase, ", ".join(outputs)))

# 5. Train the model

In [11]:
def train_model(model, optimizer, scheduler, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])

                model.train()  # Set model to training mode
                dataloader = train_dataloader
                dataloader_length = len(train_dataset) // b_size + 1
            else:
                model.eval()  # Set model to evaluate mode
                dataloader = val_dataloader
                dataloader_length = len(val_dataset) // b_size + 1
            metrics = defaultdict(float)
            epoch_samples = 0

            for i, results in tqdm(enumerate(dataloader),
                                   total=dataloader_length):
                source_img = results['source_img'].float().cuda()
                target_parse = results['target_img'].float().cuda()

                inputs = source_img
                labels = target_parse

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:

                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('total time taken to run one epoch -- {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))

        print(
            '---------------------------------------------------------------------------------'
        )

    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [13]:
optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=0.0002,
                          betas=[0.5, 0.999])

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.1)
model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=40)

Epoch 0/39
----------
LR 0.0002


100%|██████████| 184/184 [00:47<00:00,  3.91it/s]

train: bce: 0.064168, dice: 0.579442, loss: 0.321805



100%|██████████| 10/10 [00:01<00:00,  5.13it/s]

val: bce: 0.042415, dice: 0.500260, loss: 0.271337
total time taken to run one epoch -- 0m 49s
---------------------------------------------------------------------------------
Epoch 1/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.00it/s]

train: bce: 0.043686, dice: 0.497959, loss: 0.270822



100%|██████████| 10/10 [00:01<00:00,  6.33it/s]

val: bce: 0.048580, dice: 0.504564, loss: 0.276572
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 2/39
----------
LR 0.0002



100%|██████████| 184/184 [00:47<00:00,  3.91it/s]

train: bce: 0.043626, dice: 0.493459, loss: 0.268542



100%|██████████| 10/10 [00:01<00:00,  6.47it/s]

val: bce: 0.042124, dice: 0.503413, loss: 0.272769
total time taken to run one epoch -- 0m 49s
---------------------------------------------------------------------------------
Epoch 3/39
----------
LR 0.0002



100%|██████████| 184/184 [00:46<00:00,  3.95it/s]

train: bce: 0.042952, dice: 0.491678, loss: 0.267315



100%|██████████| 10/10 [00:01<00:00,  6.80it/s]

val: bce: 0.041980, dice: 0.483188, loss: 0.262584
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 4/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.04it/s]

train: bce: 0.042799, dice: 0.489507, loss: 0.266153



100%|██████████| 10/10 [00:01<00:00,  6.22it/s]

val: bce: 0.043008, dice: 0.488173, loss: 0.265590
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 5/39
----------
LR 0.0002



100%|██████████| 184/184 [00:46<00:00,  3.99it/s]

train: bce: 0.042478, dice: 0.488620, loss: 0.265549



100%|██████████| 10/10 [00:01<00:00,  6.76it/s]

val: bce: 0.044109, dice: 0.483988, loss: 0.264048
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 6/39
----------
LR 0.0002



100%|██████████| 184/184 [00:46<00:00,  3.97it/s]

train: bce: 0.042181, dice: 0.488157, loss: 0.265169



100%|██████████| 10/10 [00:01<00:00,  6.37it/s]

val: bce: 0.040830, dice: 0.489684, loss: 0.265257
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 7/39
----------
LR 0.0002



100%|██████████| 184/184 [00:46<00:00,  3.98it/s]

train: bce: 0.042101, dice: 0.486948, loss: 0.264525



100%|██████████| 10/10 [00:01<00:00,  6.46it/s]

val: bce: 0.042289, dice: 0.482843, loss: 0.262566
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 8/39
----------
LR 0.0002



100%|██████████| 184/184 [00:46<00:00,  4.00it/s]

train: bce: 0.042159, dice: 0.487195, loss: 0.264677



100%|██████████| 10/10 [00:01<00:00,  6.84it/s]

val: bce: 0.038793, dice: 0.485113, loss: 0.261953
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 9/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.05it/s]

train: bce: 0.041903, dice: 0.485594, loss: 0.263748



100%|██████████| 10/10 [00:01<00:00,  6.31it/s]

val: bce: 0.041563, dice: 0.491065, loss: 0.266314
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 10/39
----------
LR 0.0002



100%|██████████| 184/184 [00:47<00:00,  3.88it/s]

train: bce: 0.041822, dice: 0.485399, loss: 0.263611



100%|██████████| 10/10 [00:01<00:00,  6.75it/s]

val: bce: 0.040665, dice: 0.483205, loss: 0.261935
total time taken to run one epoch -- 0m 49s
---------------------------------------------------------------------------------
Epoch 11/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.04it/s]

train: bce: 0.041801, dice: 0.484764, loss: 0.263283



100%|██████████| 10/10 [00:01<00:00,  7.32it/s]

val: bce: 0.044316, dice: 0.479972, loss: 0.262144
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 12/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.01it/s]

train: bce: 0.041424, dice: 0.483521, loss: 0.262472



100%|██████████| 10/10 [00:01<00:00,  6.87it/s]

val: bce: 0.046014, dice: 0.489089, loss: 0.267552
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 13/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.06it/s]

train: bce: 0.041678, dice: 0.482809, loss: 0.262244



100%|██████████| 10/10 [00:01<00:00,  6.85it/s]

val: bce: 0.042590, dice: 0.485673, loss: 0.264132
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 14/39
----------
LR 0.0002



100%|██████████| 184/184 [00:44<00:00,  4.10it/s]

train: bce: 0.041332, dice: 0.482886, loss: 0.262109



100%|██████████| 10/10 [00:01<00:00,  7.60it/s]

val: bce: 0.043361, dice: 0.486822, loss: 0.265092
total time taken to run one epoch -- 0m 46s
---------------------------------------------------------------------------------
Epoch 15/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.06it/s]

train: bce: 0.041364, dice: 0.480906, loss: 0.261135



100%|██████████| 10/10 [00:01<00:00,  6.70it/s]

val: bce: 0.043028, dice: 0.482040, loss: 0.262534
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 16/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.07it/s]

train: bce: 0.041179, dice: 0.479864, loss: 0.260522



100%|██████████| 10/10 [00:01<00:00,  6.87it/s]

val: bce: 0.041253, dice: 0.497833, loss: 0.269543
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 17/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.07it/s]

train: bce: 0.041087, dice: 0.477645, loss: 0.259366



100%|██████████| 10/10 [00:01<00:00,  7.32it/s]

val: bce: 0.044546, dice: 0.487687, loss: 0.266116
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 18/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.05it/s]

train: bce: 0.040818, dice: 0.475312, loss: 0.258065



100%|██████████| 10/10 [00:01<00:00,  6.68it/s]

val: bce: 0.042936, dice: 0.484807, loss: 0.263871
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 19/39
----------
LR 0.0002



100%|██████████| 184/184 [00:46<00:00,  3.99it/s]

train: bce: 0.040852, dice: 0.474734, loss: 0.257793



100%|██████████| 10/10 [00:01<00:00,  7.70it/s]

val: bce: 0.039295, dice: 0.484024, loss: 0.261659
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 20/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.00it/s]

train: bce: 0.040430, dice: 0.471224, loss: 0.255827



100%|██████████| 10/10 [00:01<00:00,  6.64it/s]

val: bce: 0.044911, dice: 0.484472, loss: 0.264692
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 21/39
----------
LR 0.0002



100%|██████████| 184/184 [00:46<00:00,  4.00it/s]

train: bce: 0.040108, dice: 0.467644, loss: 0.253876



100%|██████████| 10/10 [00:01<00:00,  6.90it/s]

val: bce: 0.044394, dice: 0.484565, loss: 0.264480
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 22/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.06it/s]

train: bce: 0.039785, dice: 0.463938, loss: 0.251861



100%|██████████| 10/10 [00:01<00:00,  6.95it/s]

val: bce: 0.045808, dice: 0.483658, loss: 0.264733
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 23/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.07it/s]

train: bce: 0.039694, dice: 0.461173, loss: 0.250433



100%|██████████| 10/10 [00:01<00:00,  7.19it/s]

val: bce: 0.041018, dice: 0.491438, loss: 0.266228
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 24/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.02it/s]

train: bce: 0.039097, dice: 0.458378, loss: 0.248738



100%|██████████| 10/10 [00:01<00:00,  6.60it/s]

val: bce: 0.044980, dice: 0.491622, loss: 0.268301
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 25/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.04it/s]

train: bce: 0.038636, dice: 0.451725, loss: 0.245180



100%|██████████| 10/10 [00:01<00:00,  6.30it/s]

val: bce: 0.043896, dice: 0.483572, loss: 0.263734
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 26/39
----------
LR 0.0002



100%|██████████| 184/184 [00:46<00:00,  4.00it/s]

train: bce: 0.038198, dice: 0.449251, loss: 0.243724



100%|██████████| 10/10 [00:01<00:00,  7.07it/s]

val: bce: 0.047548, dice: 0.484101, loss: 0.265825
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 27/39
----------
LR 0.0002



100%|██████████| 184/184 [00:45<00:00,  4.04it/s]

train: bce: 0.037712, dice: 0.444308, loss: 0.241010



100%|██████████| 10/10 [00:01<00:00,  6.58it/s]

val: bce: 0.045489, dice: 0.493160, loss: 0.269325
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 28/39
----------
LR 0.0002



100%|██████████| 184/184 [00:46<00:00,  3.97it/s]

train: bce: 0.037441, dice: 0.442584, loss: 0.240012



100%|██████████| 10/10 [00:01<00:00,  6.81it/s]

val: bce: 0.049907, dice: 0.511886, loss: 0.280897
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 29/39
----------
LR 2e-05



100%|██████████| 184/184 [00:45<00:00,  4.05it/s]

train: bce: 0.035489, dice: 0.429917, loss: 0.232703



100%|██████████| 10/10 [00:01<00:00,  6.63it/s]

val: bce: 0.047109, dice: 0.493528, loss: 0.270319
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 30/39
----------
LR 2e-05



100%|██████████| 184/184 [00:46<00:00,  3.95it/s]

train: bce: 0.035127, dice: 0.424075, loss: 0.229601



100%|██████████| 10/10 [00:01<00:00,  7.07it/s]

val: bce: 0.047283, dice: 0.494656, loss: 0.270970
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 31/39
----------
LR 2e-05



100%|██████████| 184/184 [00:46<00:00,  3.94it/s]

train: bce: 0.034834, dice: 0.421590, loss: 0.228212



100%|██████████| 10/10 [00:01<00:00,  6.99it/s]

val: bce: 0.046930, dice: 0.489291, loss: 0.268111
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 32/39
----------
LR 2e-05



100%|██████████| 184/184 [00:48<00:00,  3.79it/s]

train: bce: 0.034546, dice: 0.418491, loss: 0.226518



100%|██████████| 10/10 [00:01<00:00,  6.20it/s]

val: bce: 0.046820, dice: 0.498151, loss: 0.272485
total time taken to run one epoch -- 0m 50s
---------------------------------------------------------------------------------
Epoch 33/39
----------
LR 2e-05



100%|██████████| 184/184 [00:46<00:00,  3.95it/s]

train: bce: 0.034438, dice: 0.417927, loss: 0.226182



100%|██████████| 10/10 [00:01<00:00,  6.38it/s]

val: bce: 0.047326, dice: 0.490440, loss: 0.268883
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 34/39
----------
LR 2e-05



100%|██████████| 184/184 [00:46<00:00,  4.00it/s]

train: bce: 0.034007, dice: 0.414876, loss: 0.224442



100%|██████████| 10/10 [00:01<00:00,  6.28it/s]

val: bce: 0.049320, dice: 0.502644, loss: 0.275982
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 35/39
----------
LR 2e-05



100%|██████████| 184/184 [00:46<00:00,  3.97it/s]

train: bce: 0.033953, dice: 0.413252, loss: 0.223602



100%|██████████| 10/10 [00:01<00:00,  7.30it/s]

val: bce: 0.047726, dice: 0.494977, loss: 0.271352
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 36/39
----------
LR 2e-05



100%|██████████| 184/184 [00:45<00:00,  4.02it/s]

train: bce: 0.033409, dice: 0.410373, loss: 0.221891



100%|██████████| 10/10 [00:01<00:00,  6.69it/s]

val: bce: 0.048495, dice: 0.496157, loss: 0.272326
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 37/39
----------
LR 2e-05



100%|██████████| 184/184 [00:46<00:00,  3.96it/s]

train: bce: 0.033548, dice: 0.411343, loss: 0.222445



100%|██████████| 10/10 [00:01<00:00,  6.60it/s]

val: bce: 0.047792, dice: 0.485703, loss: 0.266747
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Epoch 38/39
----------
LR 2e-05



100%|██████████| 184/184 [00:45<00:00,  4.04it/s]

train: bce: 0.033493, dice: 0.408311, loss: 0.220902



100%|██████████| 10/10 [00:01<00:00,  6.93it/s]

val: bce: 0.049409, dice: 0.503025, loss: 0.276217
total time taken to run one epoch -- 0m 47s
---------------------------------------------------------------------------------
Epoch 39/39
----------
LR 2e-05



100%|██████████| 184/184 [00:46<00:00,  3.94it/s]

train: bce: 0.033189, dice: 0.408124, loss: 0.220657



100%|██████████| 10/10 [00:01<00:00,  6.62it/s]

val: bce: 0.049381, dice: 0.498474, loss: 0.273928
total time taken to run one epoch -- 0m 48s
---------------------------------------------------------------------------------
Best val loss: 0.261659





In [16]:
torch.save(model.state_dict(), './model_final_dict.pt')

In [17]:
torch.save(model,'model_final.pt')