In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
# Fetching necessary files
#
# Downloads the four ISIC 2018 Task 1 archives (training / validation inputs and
# their ground-truth masks) into the current working directory and unpacks them.
# NOTE(review): nothing here is guarded, so re-running the cell repeats every
# download and extraction — confirm whether this cell is meant to run only once.

!wget https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1-2_Training_Input.zip
!wget https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1_Training_GroundTruth.zip
!wget https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1-2_Validation_Input.zip
!wget https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1_Validation_GroundTruth.zip

# The archives are kept after extraction; the disabled `rm` lines would delete them.
!unzip ./ISIC2018_Task1-2_Training_Input.zip
# !rm ./ISIC2018_Task1-2_Training_Input.zip
!unzip ./ISIC2018_Task1_Training_GroundTruth.zip
# !rm ./ISIC2018_Task1_Training_GroundTruth.zip
!unzip ./ISIC2018_Task1-2_Validation_Input.zip
# !rm ./ISIC2018_Task1-2_Validation_Input.zip
!unzip ./ISIC2018_Task1_Validation_GroundTruth.zip
# !rm ./ISIC2018_Task1_Validation_GroundTruth.zip
# Scratch directory. NOTE(review): it is not referenced by the cells visible below — confirm it is still needed.
!mkdir ./Preproc

In [2]:
import os
from tqdm import tqdm
import torchinfo
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split
import torchvision
import cv2 as cv
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from glob import glob
from torchvision.io import read_image
# import scipy as sc

In [3]:
# Project root: every relative path below (and 'centers_new.csv') is resolved
# against it. Set the SEGMENTATION_ROOT environment variable to run the notebook
# on another machine; the default is the original hard-coded location.
os.chdir(os.environ.get('SEGMENTATION_ROOT', 'E:\\Segmentation'))

# Training images and their ground-truth masks.
TRAIN_INPUT_DIR = 'ISIC2018_Task1-2_Training_Input/'
TRAIN_GT_DIR = 'ISIC2018_Task1_Training_GroundTruth/'
# NOTE(review): named "ISIC2017" although the data above is ISIC 2018, and not
# used by the cells visible here — confirm against the rest of the notebook.
TRAIN_INTERM_DIR = 'ISIC2017-Training-Interm'

# Official validation images and masks.
VAL_INPUT_DIR = 'ISIC2018_Task1-2_Validation_Input/'
VAL_GT_DIR = 'ISIC2018_Task1_Validation_GroundTruth/'
VAL_INTERM_DIR = 'ISIC2017-Validation-Output'

# Hyper-parameters.
BATCH_SIZE = 4            # batch size of the training and validation loaders
LEARNING_RATE = 1e-4
EPOCHS = 100
IM_H, IM_W = 256, 256     # network input size (height, width)
TRAINING_NOISE = 0
DROPOUT = .5

# Prefer the GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
print(DEVICE)

cuda


In [5]:
def centroid(img, lcc=False):
  """Return the (x, y) centroid of an image, computed from its image moments.

  Args:
    img: array whose first two axes are (rows, cols). If it has more than two
      dimensions, only channel 1 of the last axis is used for the moments.
    lcc: when True the image is cast to uint8 and, if it contains more than
      one non-zero-label connected component (4-connectivity), everything
      except the largest such component is discarded before the centroid is
      computed.

  Returns:
    A tuple ``(cX, cY)`` of ints, i.e. (column, row). When the zero-order
    moment is 0 (nothing to average over) the geometric centre of the image
    is returned instead, in the same (x, y) order.
  """
  if lcc:
    img = img.astype(np.uint8)
    nb_components, output, stats, centroids = cv.connectedComponentsWithStats(img, connectivity=4)
    # Last stats column holds the component sizes.
    sizes = stats[:, -1]
    if len(sizes) > 2:
      # Label 0 is skipped on purpose: the search starts at label 1.
      max_label = 1
      max_size = sizes[1]

      for i in range(2, nb_components):
          if sizes[i] > max_size:
              max_label = i
              max_size = sizes[i]

      # Keep only the pixels of the largest component.
      img2 = np.zeros(output.shape)
      img2[output == max_label] = 255
      img = img2

  if len(img.shape) > 2:
    M = cv.moments(img[:,:,1])
  else:
    M = cv.moments(img)

  if M["m00"] == 0:
    # shape is (rows, cols) == (y, x); return (x, y) like the normal path below.
    return (img.shape[1] // 2, img.shape[0] // 2)

  cX = int(M["m10"] / M["m00"])
  cY = int(M["m01"] / M["m00"])
  return (cX, cY)

def to_polar(input_img, center):
  """Warp an image into linear-polar coordinates around ``center``.

  The image is converted to float32 and the maximum radius handed to OpenCV is
  half of the image diagonal. Outliers are filled (``WARP_FILL_OUTLIERS``).
  """
  as_float = input_img.astype(np.float32)
  half_rows = as_float.shape[0] / 2.0
  half_cols = as_float.shape[1] / 2.0
  max_radius = np.sqrt((half_rows ** 2.0) + (half_cols ** 2.0))
  return cv.linearPolar(as_float, center, max_radius, cv.WARP_FILL_OUTLIERS)

def to_cart(input_img, center):
  """Inverse of the linear-polar warp: map a polar image back to Cartesian.

  Uses the same half-diagonal radius as the forward warp and returns the
  result cast to uint8.
  """
  as_float = input_img.astype(np.float32)
  half_cols = as_float.shape[1] / 2.0
  half_rows = as_float.shape[0] / 2.0
  max_radius = np.sqrt((half_cols ** 2.0) + (half_rows ** 2.0))
  flags = cv.WARP_FILL_OUTLIERS + cv.WARP_INVERSE_MAP
  cartesian = cv.linearPolar(as_float, center, max_radius, flags)
  return cartesian.astype(np.uint8)

def calc_dice(input_img, target):
    """Dice coefficient between a predicted mask and a target mask.

    Both arguments are arrays with values in [0, 1]; the element-wise minimum
    is used as a soft intersection, so binary and soft masks are both accepted.

    Returns:
        ``2*TP / (2*TP + FP + FN)``. When the denominator is zero (both masks
        are empty) the masks agree perfectly and 1.0 is returned instead of
        the NaN that a 0/0 division would produce.
    """
    tp = np.sum(np.minimum(input_img, target))
    fp = np.sum(np.minimum(input_img, 1 - target))
    fn = np.sum(np.minimum(1 - input_img, target))
    denominator = 2 * tp + fp + fn
    if denominator == 0:
        return 1.0
    return 2 * tp / denominator

In [6]:
class ISICDataset(torch.utils.data.Dataset):
    """Pairs every ``*.jpg`` image of a folder with a ``*.png`` mask of another.

    Each item is ``(X, mask_l, mask_s, y, filename, center)``:

    * ``X``        - RGB image resized to (IM_H, IM_W), standardised per channel.
    * ``mask_l``   - all-ones tensor at the ground-truth resolution ``GT_SIZE``.
    * ``mask_s``   - all-ones tensor at the network resolution (IM_H, IM_W).
    * ``y``        - grayscale mask resized to ``GT_SIZE`` and scaled by its maximum.
    * ``filename`` - image file name without directory and extension.
    * ``center``   - the values stored for this image path in 'centers_new.csv'.

    ``transform`` and ``target_transform`` are accepted but not used.
    """

    # (height, width) the ground-truth masks are resized to.
    GT_SIZE = (384, 512)

    def __init__(self, input_folder, seg_folder, transform=None, target_transform=None):
        # glob() gives no ordering guarantee, and images and masks are paired
        # purely by position, so both lists are sorted to make the pairing
        # deterministic (assumes matching names sort identically - TODO confirm).
        self.in_files = sorted(glob(os.path.join(input_folder, '*.jpg')))
        self.gt_files = sorted(glob(os.path.join(seg_folder, '*.png')))

        # Map the 'file' column to the values of columns 1 and 2, keeping the
        # first row when a file is listed more than once. One pass over the
        # table instead of one boolean filter of the whole frame per row.
        df = pd.read_csv('centers_new.csv')
        value_columns = [df.iloc[:, j].tolist() for j in range(1, min(3, df.shape[1]))]
        self.centers = dict()
        for name, *values in zip(df['file'].tolist(), *value_columns):
            self.centers.setdefault(name, values)

        bicubic = torchvision.transforms.InterpolationMode.BICUBIC
        self.tfy = torchvision.transforms.Resize(self.GT_SIZE, interpolation=bicubic)
        # Resize takes (height, width).
        self.tfX = torchvision.transforms.Resize((IM_H, IM_W), interpolation=bicubic)
        print(len(self.in_files), len(self.gt_files))

    def __len__(self):
        return len(self.in_files)

    def __getitem__(self, idx):
        in_path = self.in_files[idx]

        in_img = self.tfX(torchvision.io.read_image(in_path, torchvision.io.image.ImageReadMode.RGB)).float()
        channel_mean = torch.mean(in_img, dim=(1, 2), keepdim=True)
        # A constant channel has zero std; clamp so it standardises to 0 instead of NaN.
        channel_std = torch.std(in_img, dim=(1, 2), keepdim=True).clamp_min(1e-8)
        X = (in_img - channel_mean) / channel_std

        gt_img = self.tfy(torchvision.io.read_image(self.gt_files[idx], torchvision.io.image.ImageReadMode.GRAY)).float()
        peak = torch.max(gt_img)
        if peak > 0:
            # An all-zero mask is left untouched rather than divided by zero.
            gt_img = gt_img / peak
        y = gt_img

        filename = os.path.split(in_path)[-1][:-4]
        center = self.centers[in_path]
        mask_l = torch.ones(1, *self.GT_SIZE, dtype=torch.float32)
        mask_s = torch.ones(1, IM_H, IM_W, dtype=torch.float32)
        return (X, mask_l, mask_s, y, filename, center)

In [7]:
class ISICDatasetTrain(torch.utils.data.Dataset):
    """Wraps a dataset so its target mask matches the network resolution.

    Items of ``src`` are 6-tuples ``(X, mask_l, mask_s, y, filename, center)``;
    only ``y`` is changed: it is resized to (IM_H, IM_W) and rounded to the
    nearest integer so the target is binary again after interpolation.
    """

    def __init__(self, src):
        self.src = src
        # Resize takes (height, width), matching the (1, IM_H, IM_W) small mask.
        self.tf = torchvision.transforms.Resize((IM_H, IM_W))

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        X, mask_l, mask_s, y, filename, center = self.src[idx]
        y = torch.round(self.tf(y))
        return (X, mask_l, mask_s, y, filename, center)

In [8]:
class ISICDatasetTest(torch.utils.data.Dataset):
    """Evaluation view of a dataset: the target keeps its size and is only rounded.

    Items of ``src`` are 6-tuples ``(X, mask_l, mask_s, y, filename, center)``;
    every field is passed through unchanged except ``y``, which is rounded.
    """

    def __init__(self, src):
        self.src = src

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        image, mask_large, mask_small, target, name, centre = self.src[idx]
        return (image, mask_large, mask_small, torch.round(target), name, centre)

In [9]:
full_ds = ISICDataset(TRAIN_INPUT_DIR, TRAIN_GT_DIR)
# 80 / 10 / 10 split: `nt` samples each for validation and test, the rest for training.
n, nt = len(full_ds), int(len(full_ds) / 10)
# A seeded generator keeps the same files in the same split on every run, so
# validation/test scores stay comparable across kernel restarts.
SPLIT_SEED = 42
train_ds_p, valid_ds_p, test_ds_p = random_split(
    full_ds, [n - 2 * nt, nt, nt],
    generator=torch.Generator().manual_seed(SPLIT_SEED))
# Training and validation targets are resized to the network resolution;
# the test targets are only rounded.
train_ds = ISICDatasetTrain(train_ds_p)
valid_ds = ISICDatasetTrain(valid_ds_p)
test_ds = ISICDatasetTest(test_ds_p)

2594 2594


In [10]:
# Training batches are drawn in shuffled order.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=False)
# NOTE(review): the validation loader is shuffled too; the averaged metrics do not
# depend on order, so shuffle=False would work equally well — confirm before changing.
valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=False)
# The test loader yields one sample at a time, in dataset order.
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=1, shuffle=False, pin_memory=False)

In [11]:
# with mask

def train(model):
    """Run one pass over ``train_dl`` and return ``(mean loss, mean Dice)``.

    ``model`` must provide ``train()`` and ``fit(X, mask, y)``; ``fit`` returns
    ``(loss, met)`` where ``met[0..3]`` are read as TP, FP, FN, TN. The small
    all-ones mask of each batch is the one passed to ``fit``. Running averages
    are shown in the progress bar after every batch.
    """
    def _mean(values):
        return torch.mean(torch.tensor(values))

    model.train()
    losses, accuracies, dices = [], [], []
    with tqdm(total=len(train_dl)) as pbar:
        pbar.set_description("Avg.Loss: 0.0000, Avg. Accuracy: 0.0000")
        for X, mask_l, mask_s, y, filename, center in train_dl:
            loss, met = model.fit(X, mask_s, y)
            tp, fp, fn, tn = met[0], met[1], met[2], met[3]
            losses.append(loss)
            accuracies.append((tp + tn) / (tp + fp + fn + tn))
            dices.append(tp * 2 / (tp * 2 + fp + fn))
            pbar.update(1)
            avg_loss, avg_acc, avg_dice = _mean(losses), _mean(accuracies), _mean(dices)
            pbar.set_description(f"Avg. Loss: {avg_loss:.4f}, Avg. Accuracy: {avg_acc:.4f}, Avg. Dice: {avg_dice:.4f}")
    return _mean(losses).item(), _mean(dices).item()

def test(model):
    """Evaluate ``model`` on ``valid_dl`` and return ``(mean loss, mean Dice)``.

    ``model`` must provide ``eval()`` and ``test(X, mask, y)``; ``test`` returns
    ``(loss, met)`` where ``met[0..3]`` are read as TP, FP, FN, TN. The whole
    loop runs with autograd disabled, and running averages are shown in the
    progress bar after every batch.
    """
    def _mean(values):
        return torch.mean(torch.tensor(values))

    model.eval()
    losses, accuracies, dices = [], [], []
    with tqdm(total=len(valid_dl)) as pbar:
        pbar.set_description("Avg.Loss: 0.0000, Avg. Accuracy: 0.0000")
        with torch.no_grad():
            for X, mask_l, mask_s, y, filename, center in valid_dl:
                loss, met = model.test(X, mask_s, y)
                tp, fp, fn, tn = met[0], met[1], met[2], met[3]
                losses.append(loss)
                accuracies.append((tp + tn) / (tp + fp + fn + tn))
                dices.append(tp * 2 / (tp * 2 + fp + fn))
                pbar.update(1)
                avg_loss, avg_acc, avg_dice = _mean(losses), _mean(accuracies), _mean(dices)
                pbar.set_description(f"Avg. Loss: {avg_loss:.4f}, Avg. Accuracy: {avg_acc:.4f}, Avg. Dice: {avg_dice:.4f}")
    return _mean(losses).item(), _mean(dices).item()

# Loss functions

In [32]:
def iou_loss(pred, mask):
    """Weighted soft-IoU loss on logits.

    Pixels whose value differs from the 31x31 local average of ``mask`` get a
    weight of up to 6 (``1 + 5 * |avg - mask|``). ``pred`` holds logits and
    ``mask`` the target, both laid out as (N, C, H, W). Returns the mean of
    ``1 - (inter + 1) / (union - inter + 1)`` over the first two axes.
    """
    local_mean = F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)
    weit = 1 + 5 * torch.abs(local_mean - mask)

    prob = torch.sigmoid(pred)
    spatial = (2, 3)
    inter = (prob * mask * weit).sum(dim=spatial)
    union = ((prob + mask) * weit).sum(dim=spatial)
    return (1 - (inter + 1) / (union - inter + 1)).mean()


def structure_loss(pred, mask):
    """Boundary-weighted BCE plus boundary-weighted soft-IoU loss on logits.

    Args:
        pred: logits, shape (N, C, H, W).
        mask: target of the same shape, values in [0, 1].

    Every pixel is weighted by ``1 + 5 * |avg_pool31(mask) - mask|``, so pixels
    that differ from their 31x31 neighbourhood count up to six times as much.

    Returns:
        Scalar tensor: mean over (N, C) of weighted BCE + weighted IoU loss.
    """
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    # `reduction='none'` keeps the per-pixel losses. The previous `reduce='none'`
    # hit the deprecated boolean argument, where a non-empty string is truthy and
    # selects the *mean* reduction, so the pixel weights never affected the BCE term.
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)

    return (wbce + wiou).mean()

def dice_loss(pred, mask, cover):
    """Soft Dice loss on logits with +1 smoothing.

    ``pred`` holds logits and ``mask`` the target, both (N, C, H, W).
    ``cover`` is accepted but not used. The pixel weight is the constant 1.
    Returns the mean of ``1 - (2*TP + 1) / (2*TP + FP + FN + 1)``.
    """
    weight = 1
    spatial = (2, 3)

    prob = torch.sigmoid(pred)
    true_pos = (prob * mask * weight).sum(dim=spatial)
    false_pos = (prob * (1 - mask) * weight).sum(dim=spatial)
    false_neg = ((1 - prob) * mask * weight).sum(dim=spatial)

    score = (2 * true_pos + 1) / (2 * true_pos + false_pos + false_neg + 1)
    return (1 - score).mean()

class FocalLoss():
    """Focal loss on logits: cross-entropy scaled by the error raised to ``gamma``."""

    def __init__(self, gamma):
        self.gamma = gamma

    def __call__(self, pred, mask):
        # |sigmoid(pred) - mask| is the per-pixel error.
        error = torch.abs(torch.sigmoid(pred) - mask)
        # 1e-15 keeps the logarithm finite when the error reaches 1.
        cross_entropy = -torch.log((1 - error) + 1e-15)
        return torch.mean(error**self.gamma * cross_entropy)

def focal_struct_loss(pred, mask):
    """Error-weighted BCE plus error-weighted soft-IoU loss on logits.

    Same structure as ``structure_loss`` but each pixel is weighted by the
    current prediction error ``|mask - sigmoid(pred)|`` instead of a
    boundary map, so badly predicted pixels dominate.

    Args:
        pred: logits, shape (N, C, H, W).
        mask: target of the same shape, values in [0, 1].

    Returns:
        Scalar tensor: mean over (N, C) of weighted BCE + weighted IoU loss.
    """
    # torch.sigmoid replaces the deprecated F.sigmoid alias.
    weit = torch.abs(mask - torch.sigmoid(pred))
    # `reduction='none'` keeps the per-pixel losses. The previous `reduce='none'`
    # hit the deprecated boolean argument, where a non-empty string is truthy and
    # selects the *mean* reduction, so the weights never affected the BCE term.
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)

    return (wbce + wiou).mean()

In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from lib.pvtv2 import pvt_v2_b2
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
# from mmcv.cnn import ConvModule
from torch.nn import Conv2d, UpsamplingBilinear2d
import warnings
import torch
# from mmcv.cnn import constant_init, kaiming_init
from torch import nn
from torchvision.transforms.functional import normalize
from PVTv2.pvtv2 import pvt_v2_b2
warnings.filterwarnings('ignore')

class SpatBlock(nn.Module):
    """Adds a learned polynomial of the horizontal position to a feature map.

    For a (B, C, H, W) input, the normalised column index ``p = j / W`` is
    raised to the powers 1..order; each power goes through its own 1x1
    convolution (1 -> dim channels, only the first one has a bias) and the
    results are summed and added to the input. ``dim`` must equal ``C`` for
    the sum to broadcast onto the input.

    Args:
        dim: number of channels of the feature maps this block is applied to.
        order: number of polynomial terms (and of 1x1 convolutions).
    """

    def __init__(self, dim=32, order=5):
        super().__init__()
        self.dim = dim
        self.order = order
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.blocks = nn.ModuleList([nn.Conv2d(1, dim, 1)] + [nn.Conv2d(1, dim, 1, bias=False) for _ in range(order-1)])

        self.to(self.device)

    def forward(self, x):
        B, C, H, W = x.shape

        # Build the positions on the device / dtype the weights have right now.
        # `self.device` is only the construction-time choice and goes stale as
        # soon as the module is moved or cast with .to() / .cpu() / .double().
        weight = self.blocks[0].weight
        pos = torch.arange(W, device=weight.device, dtype=weight.dtype) / W
        pos = torch.reshape(pos, (1, 1, 1, W))

        func = torch.zeros_like(x)

        # Each (1, dim, 1, W) term broadcasts over batch and height.
        for i, block in enumerate(self.blocks):
            func += block(pos**(i+1))

        return x + func
        

def structure_loss(pred, mask):
    """Boundary-weighted BCE plus boundary-weighted soft-IoU loss on logits.

    Args:
        pred: logits, shape (N, C, H, W).
        mask: target of the same shape, values in [0, 1].

    Every pixel is weighted by ``1 + 5 * |avg_pool31(mask) - mask|``, so pixels
    that differ from their 31x31 neighbourhood count up to six times as much.

    Returns:
        Scalar tensor: mean over (N, C) of weighted BCE + weighted IoU loss.
    """
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    # `reduction='none'` keeps the per-pixel losses. The previous `reduce='none'`
    # hit the deprecated boolean argument, where a non-empty string is truthy and
    # selects the *mean* reduction, so the pixel weights never affected the BCE term.
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)

    return (wbce + wiou).mean()

# def structure_loss(pred, mask):
#     pred = torch.sigmoid(pred)
#     loss1 = F.binary_cross_entropy(pred, mask)
#     tp = pred * mask
#     fp = pred * (1 - mask)
#     fn = (1 - pred) * mask
#     tp = torch.mean(tp, dim=0)
#     fp = torch.mean(fp, dim=0)
#     fn = torch.mean(fn, dim=0)
#     iou = tp / (tp + fp + fn)
#     iou = torch.mean(iou)
#     loss2 = 1 - iou
#     return loss1 + loss2

class BasicConv2d(nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation and an in-place ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()

        conv_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, dilation=dilation, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
        
class Block(nn.Sequential):
    """Pre-activation bottleneck: [BN] - ReLU - 1x1 conv - BN - ReLU - dilated 3x3 conv.

    The leading batch norm is only present when ``bn_start`` is true. When
    ``drop_out`` is positive, 2-D dropout with that rate follows the stack
    (active in training mode only).
    """

    def __init__(self, input_num, num1, num2, dilation_rate, drop_out, bn_start=True, norm_layer=nn.BatchNorm2d):
        super().__init__()
        layers = []
        if bn_start:
            layers.append(('norm1', norm_layer(input_num)))
        layers.extend([
            ('relu1', nn.ReLU(inplace=True)),
            ('conv1', nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1)),
            ('norm2', norm_layer(num1)),
            ('relu2', nn.ReLU(inplace=True)),
            # padding == dilation keeps the spatial size for a 3x3 kernel
            ('conv2', nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3,
                                dilation=dilation_rate, padding=dilation_rate)),
        ])
        for layer_name, layer in layers:
            self.add_module(layer_name, layer)
        self.drop_rate = drop_out

    def forward(self, _input):
        feature = super().forward(_input)
        if self.drop_rate > 0:
            return F.dropout2d(feature, p=self.drop_rate, training=self.training)
        return feature


def Upsample(x, size, align_corners = False):
    """Bilinearly resize ``x`` to the spatial ``size`` (thin wrapper around ``F.interpolate``)."""
    resize_options = dict(size=size, mode='bilinear', align_corners=align_corners)
    return F.interpolate(x, **resize_options)

    
# def last_zero_init(m):
#     if isinstance(m, nn.Sequential):
#         constant_init(m[-1], val=0)
#     else:
#         constant_init(m, val=0)


class ContextBlock(nn.Module):
    """Global-context block: pool the map to one vector, transform it, fuse it back.

    Args:
        inplanes: channels of the input feature map.
        ratio: bottleneck width as a fraction of ``inplanes``.
        pooling_type: ``'att'`` (softmax-weighted spatial average using a 1x1
            conv to score positions) or ``'avg'`` (plain global average).
        fusion_types: any non-empty combination of ``'channel_mul'``
            (``out + out * sigmoid(t(context))``) and ``'channel_add'``
            (``out + t(context)``); multiplication is applied first.
    """

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_mul', )):
        super().__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'

        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Unused fusion branches are stored as None so forward() can test for them.
        self.channel_add_conv = self._make_transform() if 'channel_add' in fusion_types else None
        self.channel_mul_conv = self._make_transform() if 'channel_mul' in fusion_types else None

    def _make_transform(self):
        """Bottleneck applied to the (N, C, 1, 1) context: 1x1 conv - LayerNorm - ReLU - 1x1 conv."""
        return nn.Sequential(
            nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.planes, self.inplanes, kernel_size=1))

    def spatial_pool(self, x):
        """Collapse the spatial axes of ``x`` into a (N, C, 1, 1) context vector."""
        batch, channel, height, width = x.size()
        if self.pooling_type != 'att':
            return self.avg_pool(x)

        positions = height * width
        # [N, 1, C, H * W]
        features = x.view(batch, channel, positions).unsqueeze(1)
        # [N, 1, H * W]: one score per position, normalised over positions
        scores = self.softmax(self.conv_mask(x).view(batch, 1, positions))
        # [N, 1, C, H * W] @ [N, 1, H * W, 1] -> [N, 1, C, 1]
        pooled = torch.matmul(features, scores.unsqueeze(-1))
        return pooled.view(batch, channel, 1, 1)

    def forward(self, x):
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            gate = torch.sigmoid(self.channel_mul_conv(context))
            out = out + out * gate
        if self.channel_add_conv is not None:
            out = out + self.channel_add_conv(context)
        return out



class ChannelAttention(nn.Module):
    """Channel attention: a shared two-layer 1x1-conv MLP on avg- and max-pooled features.

    Args:
        in_planes: channels of the input feature map.
        ratio: reduction factor of the hidden layer (``in_planes // ratio``
            channels). Previously this argument was ignored and 16 was always
            used; the default of 16 keeps existing callers unchanged.

    ``forward`` returns a (N, in_planes, 1, 1) gate with values in (0, 1).
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden_planes = in_planes // ratio
        self.fc1   = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # The same MLP scores both pooled descriptors; the scores are summed.
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    """Spatial attention: a bias-free conv over the channel-wise mean and max maps.

    ``kernel_size`` must be 3 or 7; the padding is chosen so the spatial size
    is preserved. ``forward`` returns a (N, 1, H, W) gate with values in (0, 1).
    """

    def __init__(self, kernel_size=7):
        super().__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        same_padding = {7: 3, 3: 1}[kernel_size]

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=same_padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True)[0]
        stacked = torch.cat([channel_mean, channel_max], dim=1)
        return self.sigmoid(self.conv1(stacked))


class ConvBranch(nn.Module):
    """Local branch: a stack of 1x1 / depthwise-3x3 convolutions that gates its input.

    Three (pointwise, depthwise) pairs are applied, each depthwise stage with a
    residual connection, followed by a final pointwise projection. The sigmoid
    of that projection is used as a mask: the output is ``x + x * mask``, so
    ``out_features`` has to broadcast against the input channels.

    ``hidden_features`` and ``out_features`` default to ``in_features``.
    ``ca`` and ``sa`` are created but not used in ``forward``.
    """

    def __init__(self, in_features, hidden_features = None, out_features = None):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features

        self.conv1 = self._pointwise(in_features, hidden_features, nn.ReLU(inplace=True))
        self.conv2 = self._depthwise(hidden_features)
        self.conv3 = self._pointwise(hidden_features, hidden_features, nn.ReLU(inplace=True))
        self.conv4 = self._depthwise(hidden_features)
        # The third pointwise stage uses SiLU rather than ReLU.
        self.conv5 = self._pointwise(hidden_features, hidden_features, nn.SiLU(inplace=True))
        self.conv6 = self._depthwise(hidden_features)
        # Final projection has no batch norm.
        self.conv7 = nn.Sequential(
            nn.Conv2d(hidden_features, out_features, 1, bias=False),
            nn.ReLU(inplace=True)
        )
        self.ca = ChannelAttention(64)
        self.sa = SpatialAttention()
        self.sigmoid_spatial = nn.Sigmoid()

    @staticmethod
    def _pointwise(channels_in, channels_out, activation):
        """1x1 bias-free conv - batch norm - the given activation."""
        return nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 1, bias=False),
            nn.BatchNorm2d(channels_out),
            activation
        )

    @staticmethod
    def _depthwise(channels):
        """Depthwise 3x3 bias-free conv - batch norm - ReLU, size preserving."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        identity = x
        feat = x
        stages = ((self.conv1, self.conv2), (self.conv3, self.conv4), (self.conv5, self.conv6))
        for pointwise, depthwise in stages:
            feat = pointwise(feat)
            feat = feat + depthwise(feat)
        gate = self.sigmoid_spatial(self.conv7(feat))
        return identity + identity * gate

              
class GLSA(nn.Module):
    """Global-local aggregation: half of the channels go to each branch.

    The input is split in two along the channel axis. The first half is
    projected to ``embed_dim`` channels and sent through ``ConvBranch`` (local),
    the second half is projected and sent through ``ContextBlock`` (global).
    Both results are concatenated and fused back to ``embed_dim`` channels.

    ``k_s`` is accepted but not used, and ``conv1_1_1`` is created but not used
    in ``forward``.
    """

    def __init__(self, input_dim=512, embed_dim=32, k_s=3):
        super().__init__()
        half_dim = input_dim // 2

        self.conv1_1 = BasicConv2d(embed_dim * 2, embed_dim, 1)
        self.conv1_1_1 = BasicConv2d(half_dim, embed_dim, 1)
        self.local_11conv = nn.Conv2d(half_dim, embed_dim, 1)
        self.global_11conv = nn.Conv2d(half_dim, embed_dim, 1)
        self.GlobelBlock = ContextBlock(inplanes=embed_dim, ratio=2)
        self.local = ConvBranch(in_features=embed_dim, hidden_features=embed_dim, out_features=embed_dim)

    def forward(self, x):
        # Unpacking enforces a 4-D input, as before; the values are not needed.
        _batch, _channels, _height, _width = x.size()
        local_half, global_half = x.chunk(2, dim=1)

        local_feat = self.local(self.local_11conv(local_half))
        global_feat = self.GlobelBlock(self.global_11conv(global_half))

        fused = torch.cat([local_feat, global_feat], dim=1)
        return self.conv1_1(fused)

class SBA(nn.Module):
    """Fuses a high-level and a low-level feature map into a 1-channel map.

    Both inputs are reduced to ``input_dim // 2`` channels and gated with their
    own sigmoid. Each stream then receives the other stream's gated features
    (resized to its own resolution) wherever its own gate is weak. The refined
    high-level map is resized to the low-level resolution, both are
    concatenated and projected to a single channel.
    """

    def __init__(self,input_dim = 64):
        super().__init__()

        self.input_dim = input_dim
        half_dim = input_dim // 2

        self.d_in1 = BasicConv2d(half_dim, half_dim, 1)
        self.d_in2 = BasicConv2d(half_dim, half_dim, 1)

        self.conv = nn.Sequential(BasicConv2d(input_dim, input_dim, 3, 1, 1),
                                  nn.Conv2d(input_dim, 1, kernel_size=1, bias=False))
        self.fc1 = nn.Conv2d(input_dim, half_dim, kernel_size=1, bias=False)
        self.fc2 = nn.Conv2d(input_dim, half_dim, kernel_size=1, bias=False)

        self.Sigmoid = nn.Sigmoid()

    def forward(self, H_feature, L_feature):
        low = self.fc1(L_feature)
        high = self.fc2(H_feature)

        # Gates come from the reduced features, before the conv-bn-relu refinement.
        gate_low = self.Sigmoid(low)
        gate_high = self.Sigmoid(high)

        low = self.d_in1(low)
        high = self.d_in2(high)

        high_at_low = Upsample(gate_high * high, size=low.size()[2:], align_corners=False)
        low = low + low * gate_low + (1 - gate_low) * high_at_low
        # The high-level update deliberately sees the already-updated low-level map.
        low_at_high = Upsample(gate_low * low, size=high.size()[2:], align_corners=False)
        high = high + high * gate_high + (1 - gate_high) * low_at_high

        high = Upsample(high, size=low.size()[2:])
        return self.conv(torch.cat([high, low], dim=1))
        
            
class DuAT(nn.Module):
    """Segmentation network on a frozen ResNet-50 encoder, bundled with its optimiser.

    The first four ResNet-50 stages (up to ``layer3``) give feature maps at
    1/2, 1/4, 1/8 and 1/16 of the input size. Each map gets a positional term
    (``SpatBlock``) and a batch norm; the three deepest maps go through
    ``GLSA`` and are fused into the first output, while ``SBA`` combines the
    shallowest map with the fused deep features into the second output.
    Both outputs are single-channel logits at the input resolution.

    Args:
        dim: channel width of the decoder.
        dims: channel counts of the four encoder stages (shallow to deep).
        learning_rate: optimiser learning rate, 1e-5 when None.
        loss_fn: callable ``(logits, target) -> scalar``; ``structure_loss`` when None.
        optimizer: optimiser class called as ``optimizer(params, lr=..., weight_decay=...)``;
            ``torch.optim.AdamW`` when None.
        device: device the model is moved to; the global ``DEVICE`` when None.
        weight_decay: optimiser weight decay, 1e-5 when None.

    NOTE(review): the encoder parameters are frozen, but ``model.train()`` still
    puts the encoder's BatchNorm layers in training mode, so their running
    statistics keep changing during training — confirm this is intended.
    """

    def __init__(self, dim=32, dims= [64, 256, 512, 1024], learning_rate=None,  loss_fn=None, optimizer=None, device=None, weight_decay=None):
        super(DuAT, self).__init__()

        self.device = DEVICE if device is None else device
        self.learning_rate = 1e-5 if learning_rate is None else learning_rate
        self.weight_decay = 1e-5 if weight_decay is None else weight_decay
        self.dims = dims

        # Encoder: ResNet-50 up to and including layer3.
        resnet50 = torchvision.models.resnet50(weights="DEFAULT")
        self.e1 = nn.Sequential(resnet50._modules['conv1'], resnet50._modules['bn1'], resnet50._modules['relu'])
        self.e2 = nn.Sequential(resnet50._modules['maxpool'], resnet50._modules['layer1'])
        self.e3 = nn.Sequential(resnet50._modules['layer2'])
        self.e4 = nn.Sequential(resnet50._modules['layer3'])

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = dims[0], dims[1], dims[2], dims[3]

        # Not used in forward(); kept so parameter names and order stay
        # compatible with previously saved state dicts.
        self.shuffle4 = nn.Conv2d(c4_in_channels, c4_in_channels, 1)
        self.shuffle3 = nn.Conv2d(c3_in_channels, c3_in_channels, 1)
        self.shuffle2 = nn.Conv2d(c2_in_channels, c2_in_channels, 1)
        self.shuffle1 = nn.Conv2d(c1_in_channels, c1_in_channels, 1)

        # Defined once (the former second, identical definition only replaced
        # these modules with fresh copies).
        self.bn4 = nn.BatchNorm2d(c4_in_channels)
        self.bn3 = nn.BatchNorm2d(c3_in_channels)
        self.bn2 = nn.BatchNorm2d(c2_in_channels)
        self.bn1 = nn.BatchNorm2d(c1_in_channels)

        self.embed4 = SpatBlock(dim=c4_in_channels)
        self.embed3 = SpatBlock(dim=c3_in_channels)
        self.embed2 = SpatBlock(dim=c2_in_channels)
        self.embed1 = SpatBlock(dim=c1_in_channels)

        self.GLSA_c4 = GLSA(input_dim=c4_in_channels, embed_dim=dim)
        self.GLSA_c3 = GLSA(input_dim=c3_in_channels, embed_dim=dim)
        self.GLSA_c2 = GLSA(input_dim=c2_in_channels, embed_dim=dim)
        self.L_feature = BasicConv2d(c1_in_channels, dim, 3,1,1)

        self.SBA = SBA(input_dim = dim)
        self.fuse = BasicConv2d(dim * 2, dim, 1)
        self.fuse2 = nn.Sequential(BasicConv2d(dim*3, dim, 1,1),nn.BatchNorm2d(dim),nn.Conv2d(dim, 1, kernel_size=1, bias=False))

        # Freeze the pretrained encoder.
        for stage in (self.e1, self.e2, self.e3, self.e4):
            for param in stage.parameters():
                param.requires_grad_(False)

        self.loss_fn = structure_loss if loss_fn is None else loss_fn

        optimizer_cls = torch.optim.AdamW if optimizer is None else optimizer
        self.optimizer = optimizer_cls(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

        self.to(self.device)

    def forward(self, x, mask):
        """Return two (N, 1, H, W) logit maps for input ``x``; ``mask`` is not used."""
        # backbone
        x = x.to(self.device)
        c1 = self.e1(x)
        c2 = self.e2(c1)
        c3 = self.e3(c2)
        c4 = self.e4(c3)

        # positional term + normalisation on every scale
        c1 = self.bn1(self.embed1(c1))
        c2 = self.bn2(self.embed2(c2))
        c3 = self.bn3(self.embed3(c3))
        c4 = self.bn4(self.embed4(c4))

        # global-local aggregation; each result has `dim` channels
        _c4 = self.GLSA_c4(c4)
        _c4 = Upsample(_c4, c3.size()[2:])
        _c3 = self.GLSA_c3(c3)
        _c2 = self.GLSA_c2(c2)

        # first head: the three deep scales concatenated at c2 resolution
        output = torch.cat([Upsample(_c4, c2.size()[2:]), Upsample(_c3, c2.size()[2:]), _c2], dim=1)

        # second head: shallow features aligned with the fused deep features
        L_feature = self.L_feature(c1)
        H_feature = self.fuse(torch.cat([_c4, _c3], dim=1))
        H_feature = Upsample(H_feature,c2.size()[2:])
        output2 = self.SBA(H_feature,L_feature)

        # bring both heads to the input resolution (c2 is 1/4, c1 is 1/2)
        output = F.interpolate(output, scale_factor=4, mode='bicubic')
        output = self.fuse2(output)
        output2 = F.interpolate(output2, scale_factor=2, mode='bicubic')

        return output, output2

    @staticmethod
    def _confusion_rates(logits, mask, y):
        """Return ``tensor([TP, FP, FN, TN])``, each divided by ``mask.sum()``.

        Predictions are ``logits > 0``. Only the TN count is multiplied by
        ``mask``; the other three are summed over all pixels.
        """
        pred = (logits > 0).int()
        numt = torch.sum(mask)
        TP = torch.sum(torch.minimum(y, pred)).item() / numt
        TN = torch.sum(torch.minimum(1-y, 1-pred) * mask).item() / numt
        FN = torch.sum(torch.minimum(y, 1-pred)).item() / numt
        FP = torch.sum(torch.minimum(1-y, pred)).item() / numt
        return torch.tensor([TP, FP, FN, TN])

    def fit(self, X, mask, y):
        """One optimisation step on a batch; returns ``(loss, tensor([TP, FP, FN, TN]))``.

        The loss is the sum of ``loss_fn`` over both heads; the metrics are
        computed from the second head.
        """
        X, mask, y = X.to(self.device), mask.to(self.device), y.to(self.device)
        h1, h2 = self(X, mask)
        self.optimizer.zero_grad()
        loss = self.loss_fn(h1, y) + self.loss_fn(h2, y)
        loss.backward()
        self.optimizer.step()
        return (loss.item(), self._confusion_rates(h2, mask, y))

    def test(self, X, mask, y):
        """Evaluate a batch without updating weights; same return value as ``fit``."""
        X, mask, y = X.to(self.device), mask.to(self.device), y.to(self.device)
        # No graph is needed here, whether or not the caller disabled autograd.
        with torch.no_grad():
            h1, h2 = self(X, mask)
            loss = (self.loss_fn(h1, y) + self.loss_fn(h2, y)).item()
        return (loss, self._confusion_rates(h2, mask, y))


# PVT DuAT

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from lib.pvtv2 import pvt_v2_b2
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
# from mmcv.cnn import ConvModule
from torch.nn import Conv2d, UpsamplingBilinear2d
import warnings
import torch
# from mmcv.cnn import constant_init, kaiming_init
from torch import nn
from torchvision.transforms.functional import normalize
from BaseModel import BaseModel
from PVT.segmentation.pvt import pvt_large, pvt_medium, pvt_small, pvt_tiny
from PVTv2.pvtv2 import *
warnings.filterwarnings('ignore')

def structure_loss(pred, mask):
    """Boundary-weighted BCE plus boundary-weighted soft-IoU loss on logits.

    Args:
        pred: logits, shape (N, C, H, W).
        mask: target of the same shape, values in [0, 1].

    Every pixel is weighted by ``1 + 5 * |avg_pool31(mask) - mask|``, so pixels
    that differ from their 31x31 neighbourhood count up to six times as much.

    Returns:
        Scalar tensor: mean over (N, C) of weighted BCE + weighted IoU loss.
    """
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    # `reduction='none'` keeps the per-pixel losses. The previous `reduce='none'`
    # hit the deprecated boolean argument, where a non-empty string is truthy and
    # selects the *mean* reduction, so the pixel weights never affected the BCE term.
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)

    return (wbce + wiou).mean()

# def structure_loss(pred, mask):
#     pred = torch.sigmoid(pred)
#     loss1 = F.binary_cross_entropy(pred, mask)
#     tp = pred * mask
#     fp = pred * (1 - mask)
#     fn = (1 - pred) * mask
#     tp = torch.mean(tp, dim=0)
#     fp = torch.mean(fp, dim=0)
#     fn = torch.mean(fn, dim=0)
#     iou = tp / (tp + fp + fn)
#     iou = torch.mean(iou)
#     loss2 = 1 - iou
#     return loss1 + loss2

class BasicConv2d(nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation and ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, dilation=dilation, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU and return the activated tensor."""
        return self.relu(self.bn(self.conv(x)))
        
class Block(nn.Sequential):
    """(norm) -> ReLU -> 1x1 conv -> norm -> ReLU -> dilated 3x3 conv, then optional 2-D dropout."""

    def __init__(self, input_num, num1, num2, dilation_rate, drop_out, bn_start=True, norm_layer=nn.BatchNorm2d):
        super().__init__()
        stages = []
        if bn_start:
            # The leading normalisation is optional.
            stages.append(('norm1', norm_layer(input_num)))
        stages.append(('relu1', nn.ReLU(inplace=True)))
        stages.append(('conv1', nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1)))
        stages.append(('norm2', norm_layer(num1)))
        stages.append(('relu2', nn.ReLU(inplace=True)))
        # padding == dilation keeps the spatial size for the 3x3 kernel.
        stages.append(('conv2', nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3,
                                          dilation=dilation_rate, padding=dilation_rate)))
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)
        self.drop_rate = drop_out

    def forward(self, _input):
        """Run the sequential stages, then dropout2d when ``drop_rate`` is positive."""
        feature = super().forward(_input)
        if self.drop_rate > 0:
            return F.dropout2d(feature, p=self.drop_rate, training=self.training)
        return feature


def Upsample(x, size, align_corners = False):
    """Resize ``x`` to the spatial ``size`` with bilinear interpolation.

    Thin wrapper around ``nn.functional.interpolate``; ``align_corners`` is
    forwarded unchanged.
    """
    resize_options = {'size': size, 'mode': 'bilinear', 'align_corners': align_corners}
    return nn.functional.interpolate(x, **resize_options)

    
# def last_zero_init(m):
#     if isinstance(m, nn.Sequential):
#         constant_init(m[-1], val=0)
#     else:
#         constant_init(m, val=0)


class ContextBlock(nn.Module):
    """Pool the input into one context value per channel and fuse it back into the input.

    Pooling is either attention-weighted (``'att'``) or plain average
    (``'avg'``).  Fusion is a sigmoid-gated residual scaling (``'channel_mul'``)
    and/or an additive shift (``'channel_add'``).
    """

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_mul', )):
        super().__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'

        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)  # hidden width of the fusion transforms
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Built in the order add -> mul so parameter creation order is stable.
        self.channel_add_conv = self._make_transform() if 'channel_add' in fusion_types else None
        self.channel_mul_conv = self._make_transform() if 'channel_mul' in fusion_types else None

    def _make_transform(self):
        """1x1 conv -> LayerNorm -> ReLU -> 1x1 conv acting on a [N, C, 1, 1] context."""
        return nn.Sequential(
            nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.planes, self.inplanes, kernel_size=1))

    def spatial_pool(self, x):
        """Reduce ``x`` of shape [N, C, H, W] to a context tensor of shape [N, C, 1, 1]."""
        batch, channel, height, width = x.size()
        if self.pooling_type != 'att':
            return self.avg_pool(x)

        positions = height * width
        # [N, 1, C, H*W]
        flat_input = x.view(batch, channel, positions).unsqueeze(1)
        # Softmax over all spatial positions -> [N, 1, H*W, 1]
        attention = self.softmax(self.conv_mask(x).view(batch, 1, positions)).unsqueeze(-1)
        # [N, 1, C, 1] -> [N, C, 1, 1]
        return torch.matmul(flat_input, attention).view(batch, channel, 1, 1)

    def forward(self, x):
        """Return ``x`` after the configured multiplicative and additive fusions."""
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            gate = torch.sigmoid(self.channel_mul_conv(context))
            out = out + out * gate
        if self.channel_add_conv is not None:
            out = out + self.channel_add_conv(context)
        return out



class ChannelAttention(nn.Module):
    """Per-channel sigmoid attention built from global average- and max-pooled descriptors.

    Args:
        in_planes: number of input channels.
        ratio: channel-reduction factor of the shared bottleneck; the hidden
            width is ``in_planes // ratio``.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Honour `ratio` (it used to be ignored in favour of a hard-coded 16;
        # the default keeps the previous layer sizes).
        hidden_planes = in_planes // ratio
        self.fc1   = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a [N, C, 1, 1] attention map with values in (0, 1)."""
        # The same fc1 -> ReLU -> fc2 bottleneck is applied to both descriptors.
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """Single-channel spatial attention from the channel-wise mean and max of the input."""

    def __init__(self, kernel_size=7):
        super().__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "Same" padding for the two supported odd kernel sizes (7 -> 3, 3 -> 1).
        padding = kernel_size // 2

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a [N, 1, H, W] map in (0, 1) computed from ``x`` of shape [N, C, H, W]."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat([channel_mean, channel_max], dim=1)
        return self.sigmoid(self.conv1(stacked))


class ConvBranch(nn.Module):
    """Local branch: a stack of 1x1 and depthwise 3x3 convolutions that gates its own input.

    ``forward`` returns ``x + x * sigmoid(stack(x))``, so ``out_features`` must
    broadcast against the input channels.
    """

    def __init__(self, in_features, hidden_features = None, out_features = None):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features

        def pointwise(activation):
            # 1x1 conv -> BN -> activation, hidden -> hidden channels.
            return nn.Sequential(
                nn.Conv2d(hidden_features, hidden_features, 1, bias=False),
                nn.BatchNorm2d(hidden_features),
                activation
            )

        def depthwise():
            # Depthwise 3x3 conv (one group per channel) -> BN -> ReLU.
            return nn.Sequential(
                nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False),
                nn.BatchNorm2d(hidden_features),
                nn.ReLU(inplace=True)
            )

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_features, hidden_features, 1, bias=False),
            nn.BatchNorm2d(hidden_features),
            nn.ReLU(inplace=True)
        )
        self.conv2 = depthwise()
        self.conv3 = pointwise(nn.ReLU(inplace=True))
        self.conv4 = depthwise()
        self.conv5 = pointwise(nn.SiLU(inplace=True))
        self.conv6 = depthwise()
        self.conv7 = nn.Sequential(
            nn.Conv2d(hidden_features, out_features, 1, bias=False),
            nn.ReLU(inplace=True)
        )
        # `ca` and `sa` are constructed but never called in forward(); they still
        # contribute parameters / state_dict entries.
        self.ca = ChannelAttention(64)
        self.sa = SpatialAttention()
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        """Return the input plus the input scaled by the learned sigmoid gate."""
        feat = self.conv1(x)
        feat = feat + self.conv2(feat)
        # Two further (pointwise, residual depthwise) pairs.
        for mix, spatial in ((self.conv3, self.conv4), (self.conv5, self.conv6)):
            feat = mix(feat)
            feat = feat + spatial(feat)
        gate = self.sigmoid_spatial(self.conv7(feat))
        return x + x * gate

              
class GLSA(nn.Module):
    """Split the channels in two halves, run a local and a global branch, then fuse them.

    ``k_s`` and the ``conv1_1_1`` layer are not used by ``forward``.
    """

    def __init__(self, input_dim=512, embed_dim=32, k_s=3):
        super().__init__()

        half_dim = input_dim // 2
        self.conv1_1 = BasicConv2d(embed_dim*2, embed_dim, 1)
        self.conv1_1_1 = BasicConv2d(half_dim, embed_dim, 1)
        self.local_11conv = nn.Conv2d(half_dim, embed_dim, 1)
        self.global_11conv = nn.Conv2d(half_dim, embed_dim, 1)
        self.GlobelBlock = ContextBlock(inplanes=embed_dim, ratio=2)
        self.local = ConvBranch(in_features=embed_dim, hidden_features=embed_dim, out_features=embed_dim)

    def forward(self, x):
        """Map ``x`` with ``input_dim`` channels to a tensor with ``embed_dim`` channels."""
        local_half, global_half = x.chunk(2, dim=1)

        # Local branch on the first half of the channels.
        local_feat = self.local(self.local_11conv(local_half))

        # Global branch on the second half.
        global_feat = self.GlobelBlock(self.global_11conv(global_half))

        # Concatenate along channels (local first) and project back to embed_dim.
        return self.conv1_1(torch.cat([local_feat, global_feat], dim=1))

class SBA(nn.Module):
    """Fuse a high-level and a low-level feature map with mutual sigmoid gating.

    Both inputs must have ``input_dim`` channels; the result has one channel at
    the spatial size of the low-level map.
    """

    def __init__(self,input_dim = 64):
        super().__init__()

        self.input_dim = input_dim
        half_dim = input_dim // 2

        self.d_in1 = BasicConv2d(half_dim, half_dim, 1)
        self.d_in2 = BasicConv2d(half_dim, half_dim, 1)

        self.conv = nn.Sequential(BasicConv2d(input_dim, input_dim, 3, 1, 1),
                                  nn.Conv2d(input_dim, 1, kernel_size=1, bias=False))
        self.fc1 = nn.Conv2d(input_dim, half_dim, kernel_size=1, bias=False)
        self.fc2 = nn.Conv2d(input_dim, half_dim, kernel_size=1, bias=False)

        self.Sigmoid = nn.Sigmoid()

    def forward(self, H_feature, L_feature):
        """Return the single-channel fusion of ``H_feature`` and ``L_feature``."""
        low = self.fc1(L_feature)
        high = self.fc2(H_feature)

        # Gates come from the reduced features *before* the d_in convolutions.
        gate_low = self.Sigmoid(low)
        gate_high = self.Sigmoid(high)

        low = self.d_in1(low)
        high = self.d_in2(high)

        low_size = low.size()[2:]
        high_size = high.size()[2:]

        # Order matters: the high-level update below reads the already-updated `low`.
        low = low + low * gate_low + (1 - gate_low) * Upsample(gate_high * high, size=low_size, align_corners=False)
        high = high + high * gate_high + (1 - gate_high) * Upsample(gate_low * low, size=high_size, align_corners=False)

        high = Upsample(high, size=low_size)
        return self.conv(torch.cat([high, low], dim=1))
        
            
class DuAT(nn.Module):
    """Two-headed segmentation network that also owns its loss function and optimizer.

    ``forward`` runs the ``pvt_v2_b2`` backbone, applies GLSA to the three
    deepest feature maps and SBA to the shallow + fused deep features, and
    returns two single-channel maps.  They are fed to the loss un-activated and
    thresholded at 0 for metrics, i.e. treated as logits.
    """

    def __init__(self, dim=32, dims= [64, 128, 320, 512], learning_rate=None,  loss_fn=None, optimizer=None, device=None, weight_decay=None):
        """
        Args:
            dim: channel width of the GLSA / SBA / fusion layers.
            dims: channel counts of the four backbone outputs (entries 0-3 are read).
            learning_rate: optimizer learning rate; ``None`` selects 1e-5.
            loss_fn: callable ``loss_fn(output, target)``; ``None`` selects ``structure_loss``.
            optimizer: factory called as ``optimizer(params, lr=..., weight_decay=...)``;
                ``None`` selects ``torch.optim.AdamW``.
            device: device the module is moved to; ``None`` selects the notebook-level ``DEVICE``.
            weight_decay: optimizer weight decay; ``None`` selects 1e-5.
        """
        super().__init__()

        self.device = DEVICE if device is None else device
        self.learning_rate = 1e-5 if learning_rate is None else learning_rate
        self.weight_decay = 1e-5 if weight_decay is None else weight_decay

        # Copy so the instance never aliases the shared default-argument list.
        self.dims = list(dims)

        # When True, forward() crops x.shape[2] // 4 rows from the top and bottom of both outputs.
        self.pad = False
        self.backbone = pvt_v2_b2()  # [64, 128, 320, 512]

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = dims[0], dims[1], dims[2], dims[3]

        # The embedding and batch-norm layers below are not used by forward();
        # they are kept so the parameter set and state_dict keys stay unchanged.
        self.embed4 = nn.Embedding(48, c4_in_channels)
        self.embed3 = nn.Embedding(96, c3_in_channels)
        self.embed2 = nn.Embedding(192, c2_in_channels)
        self.embed1 = nn.Embedding(384, c1_in_channels)

        self.bn4 = nn.BatchNorm2d(c4_in_channels)
        self.bn3 = nn.BatchNorm2d(c3_in_channels)
        self.bn2 = nn.BatchNorm2d(c2_in_channels)
        self.bn1 = nn.BatchNorm2d(c1_in_channels)

        self.GLSA_c4 = GLSA(input_dim=c4_in_channels, embed_dim=dim)
        self.GLSA_c3 = GLSA(input_dim=c3_in_channels, embed_dim=dim)
        self.GLSA_c2 = GLSA(input_dim=c2_in_channels, embed_dim=dim)
        self.L_feature = BasicConv2d(c1_in_channels, dim, 3, 1, 1)

        self.SBA = SBA(input_dim=dim)
        self.fuse = BasicConv2d(dim * 2, dim, 1)
        self.fuse2 = nn.Sequential(BasicConv2d(dim*3, dim, 1, 1), nn.BatchNorm2d(dim), nn.Conv2d(dim, 1, kernel_size=1, bias=False))

        self.loss_fn = structure_loss if loss_fn is None else loss_fn

        # Move the parameters first, then hand them to the optimizer, so the
        # optimizer is constructed against parameters in their final location.
        self.to(self.device)

        optimizer_factory = torch.optim.AdamW if optimizer is None else optimizer
        self.optimizer = optimizer_factory(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

    def forward(self, x, mask):
        """Return ``(output, output2)``, the two single-channel prediction maps.

        ``mask`` is accepted for interface compatibility but is not used here.
        """
        x = x.to(self.device)
        c1, c2, c3, c4 = self.backbone(x)

        _c4 = self.GLSA_c4(c4)
        _c4 = Upsample(_c4, c3.size()[2:])
        _c3 = self.GLSA_c3(c3)
        _c2 = self.GLSA_c2(c2)

        # Head 1: the three GLSA outputs, all at c2 resolution, concatenated.
        output = torch.cat([Upsample(_c4, c2.size()[2:]), Upsample(_c3, c2.size()[2:]), _c2], dim=1)

        # Head 2: SBA over the shallow feature and the fused deep features.
        L_feature = self.L_feature(c1)
        H_feature = self.fuse(torch.cat([_c4, _c3], dim=1))
        H_feature = Upsample(H_feature, c2.size()[2:])
        output2 = self.SBA(H_feature, L_feature)

        output = F.interpolate(output, scale_factor=8, mode='bilinear')
        output = self.fuse2(output)
        output2 = F.interpolate(output2, scale_factor=4, mode='bilinear')

        if (self.pad):
            h = x.shape[2] // 4
            output = output[:, :, h:-h, :]
            output2 = output2[:, :, h:-h, :]

        return output, output2

    @staticmethod
    def _confusion_rates(h1, h2, mask, y):
        """Return ``tensor([TP, FP, FN, TN])``, each count divided by ``mask.sum()``.

        The prediction is ``h1 + h2 > 0``.  Only the TN count is multiplied by
        ``mask``.  A zero ``mask.sum()`` yields non-finite entries.
        """
        with torch.no_grad():
            pred = (h2 + h1 > 0).int()
            numt = torch.sum(mask)
            TP = torch.sum(torch.minimum(y, pred)).item() / numt
            TN = torch.sum(torch.minimum(1-y, 1-pred) * mask).item() / numt
            FN = torch.sum(torch.minimum(y, 1-pred)).item() / numt
            FP = torch.sum(torch.minimum(1-y, pred)).item() / numt
            return torch.tensor([TP, FP, FN, TN])

    def fit(self, X, mask, y):
        """Run one optimisation step on a batch.

        Returns:
            ``(loss, rates)``: the float loss and ``tensor([TP, FP, FN, TN])``.

        Raises:
            AssertionError: if any of the four rates is negative.
        """
        X, mask, y = X.to(self.device), mask.to(self.device), y.to(self.device)
        h1, h2 = self(X, mask)
        self.optimizer.zero_grad()
        loss = self.loss_fn(h1, y) + self.loss_fn(h2, y)
        loss.backward()
        self.optimizer.step()

        rates = self._confusion_rates(h1, h2, mask, y)
        if torch.any(rates < 0):
            raise AssertionError()
        return (loss.item(), rates)

    def test(self, X, mask, y):
        """Evaluate a batch without updating parameters.

        Gradients are disabled here, so no autograd graph is built.  The
        train/eval mode is left as the caller set it.

        Returns:
            ``(loss, rates)`` in the same format as ``fit``.
        """
        X, mask, y = X.to(self.device), mask.to(self.device), y.to(self.device)
        with torch.no_grad():
            h1, h2 = self(X, mask)
            loss = (self.loss_fn(h1, y) + self.loss_fn(h2, y)).item()
        return (loss, self._confusion_rates(h1, h2, mask, y))


In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from lib.pvtv2 import pvt_v2_b2
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
# from mmcv.cnn import ConvModule
from torch.nn import Conv2d, UpsamplingBilinear2d
import warnings
import torch
# from mmcv.cnn import constant_init, kaiming_init
from torch import nn
from torchvision.transforms.functional import normalize
from PVTv2.pvtv2 import pvt_v2_b2
warnings.filterwarnings('ignore')

class SpatBlock(nn.Module):
    """Gate a feature map with a learned per-channel polynomial of the column position.

    For normalised column position ``p = w / W`` the gate is
    ``tanh(sum_i conv_i(p ** i))`` for ``i = 1 .. len(blocks)``, where every
    ``conv_i`` is a 1x1 convolution from 1 to ``dim`` channels (only the first
    has a bias).  The gate does not depend on the row or on the batch element.

    Args:
        dim: number of channels of the tensors passed to ``forward``.
        order: polynomial degree; ``max(order, 1)`` convolutions are created.
    """

    def __init__(self, dim=32, order=5):
        super().__init__()
        self.dim = dim
        self.order = order
        # Kept for backward compatibility; forward() no longer relies on it.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.blocks = nn.ModuleList([nn.Conv2d(1, dim, 1)] + [nn.Conv2d(1, dim, 1, bias=False) for _ in range(order-1)])

        self.to(self.device)

    def forward(self, x):
        """Return ``tanh(polynomial(column position)) * x`` for ``x`` of shape [B, C, H, W]."""
        B, C, H, W = x.shape

        # Build the position grid where the convolution weights actually live
        # (device *and* dtype), so the block keeps working after .to()/.double()
        # or when a parent module moves it to another device.
        weight = self.blocks[0].weight
        pos = torch.arange(W, device=weight.device, dtype=weight.dtype) / W
        pos = pos.reshape(1, 1, 1, W)

        # Accumulate on the small [1, dim, 1, W] tensor and broadcast once at the
        # end instead of adding into a full [B, C, H, W] buffer.
        func = None
        for power, block in enumerate(self.blocks, start=1):
            term = block(pos ** power)
            func = term if func is None else func + term

        # The cast matches the previous accumulation dtype (that of `x`).
        return torch.tanh(func.to(x.dtype)) * x
        

def structure_loss(pred, mask):
    """Boundary-weighted BCE plus weighted-IoU loss computed on raw logits.

    Args:
        pred: logits, indexed as a 4-D tensor (reductions run over dims 2 and 3).
        mask: float target with the same shape as ``pred``; presumably values
            in [0, 1] -- TODO confirm against the data loader.

    Returns:
        0-dim tensor: mean over the remaining dims of
        ``weighted_bce + weighted_iou_loss``.
    """
    # Pixels whose label differs from their 31x31 neighbourhood average (i.e.
    # pixels near a label transition) receive a weight above 1.
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)

    # `reduction='none'` keeps the per-pixel BCE map.  The previous
    # `reduce='none'` went to the deprecated boolean `reduce` flag; a non-empty
    # string is truthy, so the BCE was silently mean-reduced to a scalar and the
    # weighting below had no effect.
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    # +1 smoothing keeps the ratio finite when both sums are zero.
    wiou = 1 - (inter + 1) / (union - inter + 1)

    return (wbce + wiou).mean()

# def structure_loss(pred, mask):
#     pred = torch.sigmoid(pred)
#     loss1 = F.binary_cross_entropy(pred, mask)
#     tp = pred * mask
#     fp = pred * (1 - mask)
#     fn = (1 - pred) * mask
#     tp = torch.mean(tp, dim=0)
#     fp = torch.mean(fp, dim=0)
#     fn = torch.mean(fn, dim=0)
#     iou = tp / (tp + fp + fn)
#     iou = torch.mean(iou)
#     loss2 = 1 - iou
#     return loss1 + loss2

class BasicConv2d(nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation and ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, dilation=dilation, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU and return the activated tensor."""
        return self.relu(self.bn(self.conv(x)))
        
class Block(nn.Sequential):
    """(norm) -> ReLU -> 1x1 conv -> norm -> ReLU -> dilated 3x3 conv, then optional 2-D dropout."""

    def __init__(self, input_num, num1, num2, dilation_rate, drop_out, bn_start=True, norm_layer=nn.BatchNorm2d):
        super().__init__()
        stages = []
        if bn_start:
            # The leading normalisation is optional.
            stages.append(('norm1', norm_layer(input_num)))
        stages.append(('relu1', nn.ReLU(inplace=True)))
        stages.append(('conv1', nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1)))
        stages.append(('norm2', norm_layer(num1)))
        stages.append(('relu2', nn.ReLU(inplace=True)))
        # padding == dilation keeps the spatial size for the 3x3 kernel.
        stages.append(('conv2', nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3,
                                          dilation=dilation_rate, padding=dilation_rate)))
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)
        self.drop_rate = drop_out

    def forward(self, _input):
        """Run the sequential stages, then dropout2d when ``drop_rate`` is positive."""
        feature = super().forward(_input)
        if self.drop_rate > 0:
            return F.dropout2d(feature, p=self.drop_rate, training=self.training)
        return feature


def Upsample(x, size, align_corners = False):
    """Resize ``x`` to the spatial ``size`` with bilinear interpolation.

    Thin wrapper around ``nn.functional.interpolate``; ``align_corners`` is
    forwarded unchanged.
    """
    resize_options = {'size': size, 'mode': 'bilinear', 'align_corners': align_corners}
    return nn.functional.interpolate(x, **resize_options)

    
# def last_zero_init(m):
#     if isinstance(m, nn.Sequential):
#         constant_init(m[-1], val=0)
#     else:
#         constant_init(m, val=0)


class ContextBlock(nn.Module):
    """Pool the input into one context value per channel and fuse it back into the input.

    Pooling is either attention-weighted (``'att'``) or plain average
    (``'avg'``).  Fusion is a sigmoid-gated residual scaling (``'channel_mul'``)
    and/or an additive shift (``'channel_add'``).
    """

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_mul', )):
        super().__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'

        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)  # hidden width of the fusion transforms
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Built in the order add -> mul so parameter creation order is stable.
        self.channel_add_conv = self._make_transform() if 'channel_add' in fusion_types else None
        self.channel_mul_conv = self._make_transform() if 'channel_mul' in fusion_types else None

    def _make_transform(self):
        """1x1 conv -> LayerNorm -> ReLU -> 1x1 conv acting on a [N, C, 1, 1] context."""
        return nn.Sequential(
            nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.planes, self.inplanes, kernel_size=1))

    def spatial_pool(self, x):
        """Reduce ``x`` of shape [N, C, H, W] to a context tensor of shape [N, C, 1, 1]."""
        batch, channel, height, width = x.size()
        if self.pooling_type != 'att':
            return self.avg_pool(x)

        positions = height * width
        # [N, 1, C, H*W]
        flat_input = x.view(batch, channel, positions).unsqueeze(1)
        # Softmax over all spatial positions -> [N, 1, H*W, 1]
        attention = self.softmax(self.conv_mask(x).view(batch, 1, positions)).unsqueeze(-1)
        # [N, 1, C, 1] -> [N, C, 1, 1]
        return torch.matmul(flat_input, attention).view(batch, channel, 1, 1)

    def forward(self, x):
        """Return ``x`` after the configured multiplicative and additive fusions."""
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            gate = torch.sigmoid(self.channel_mul_conv(context))
            out = out + out * gate
        if self.channel_add_conv is not None:
            out = out + self.channel_add_conv(context)
        return out



class ChannelAttention(nn.Module):
    """Per-channel sigmoid attention built from global average- and max-pooled descriptors.

    Args:
        in_planes: number of input channels.
        ratio: channel-reduction factor of the shared bottleneck; the hidden
            width is ``in_planes // ratio``.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Honour `ratio` (it used to be ignored in favour of a hard-coded 16;
        # the default keeps the previous layer sizes).
        hidden_planes = in_planes // ratio
        self.fc1   = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a [N, C, 1, 1] attention map with values in (0, 1)."""
        # The same fc1 -> ReLU -> fc2 bottleneck is applied to both descriptors.
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """Single-channel spatial attention from the channel-wise mean and max of the input."""

    def __init__(self, kernel_size=7):
        super().__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "Same" padding for the two supported odd kernel sizes (7 -> 3, 3 -> 1).
        padding = kernel_size // 2

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a [N, 1, H, W] map in (0, 1) computed from ``x`` of shape [N, C, H, W]."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat([channel_mean, channel_max], dim=1)
        return self.sigmoid(self.conv1(stacked))


class ConvBranch(nn.Module):
    """Local branch: a stack of 1x1 and depthwise 3x3 convolutions that gates its own input.

    ``forward`` returns ``x + x * sigmoid(stack(x))``, so ``out_features`` must
    broadcast against the input channels.
    """

    def __init__(self, in_features, hidden_features = None, out_features = None):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features

        def pointwise(activation):
            # 1x1 conv -> BN -> activation, hidden -> hidden channels.
            return nn.Sequential(
                nn.Conv2d(hidden_features, hidden_features, 1, bias=False),
                nn.BatchNorm2d(hidden_features),
                activation
            )

        def depthwise():
            # Depthwise 3x3 conv (one group per channel) -> BN -> ReLU.
            return nn.Sequential(
                nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False),
                nn.BatchNorm2d(hidden_features),
                nn.ReLU(inplace=True)
            )

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_features, hidden_features, 1, bias=False),
            nn.BatchNorm2d(hidden_features),
            nn.ReLU(inplace=True)
        )
        self.conv2 = depthwise()
        self.conv3 = pointwise(nn.ReLU(inplace=True))
        self.conv4 = depthwise()
        self.conv5 = pointwise(nn.SiLU(inplace=True))
        self.conv6 = depthwise()
        self.conv7 = nn.Sequential(
            nn.Conv2d(hidden_features, out_features, 1, bias=False),
            nn.ReLU(inplace=True)
        )
        # `ca` and `sa` are constructed but never called in forward(); they still
        # contribute parameters / state_dict entries.
        self.ca = ChannelAttention(64)
        self.sa = SpatialAttention()
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        """Return the input plus the input scaled by the learned sigmoid gate."""
        feat = self.conv1(x)
        feat = feat + self.conv2(feat)
        # Two further (pointwise, residual depthwise) pairs.
        for mix, spatial in ((self.conv3, self.conv4), (self.conv5, self.conv6)):
            feat = mix(feat)
            feat = feat + spatial(feat)
        gate = self.sigmoid_spatial(self.conv7(feat))
        return x + x * gate

              
class GLSA(nn.Module):
    """Split the channels in two halves, run a local and a global branch, then fuse them.

    ``k_s`` and the ``conv1_1_1`` layer are not used by ``forward``.
    """

    def __init__(self, input_dim=512, embed_dim=32, k_s=3):
        super().__init__()

        half_dim = input_dim // 2
        self.conv1_1 = BasicConv2d(embed_dim*2, embed_dim, 1)
        self.conv1_1_1 = BasicConv2d(half_dim, embed_dim, 1)
        self.local_11conv = nn.Conv2d(half_dim, embed_dim, 1)
        self.global_11conv = nn.Conv2d(half_dim, embed_dim, 1)
        self.GlobelBlock = ContextBlock(inplanes=embed_dim, ratio=2)
        self.local = ConvBranch(in_features=embed_dim, hidden_features=embed_dim, out_features=embed_dim)

    def forward(self, x):
        """Map ``x`` with ``input_dim`` channels to a tensor with ``embed_dim`` channels."""
        local_half, global_half = x.chunk(2, dim=1)

        # Local branch on the first half of the channels.
        local_feat = self.local(self.local_11conv(local_half))

        # Global branch on the second half.
        global_feat = self.GlobelBlock(self.global_11conv(global_half))

        # Concatenate along channels (local first) and project back to embed_dim.
        return self.conv1_1(torch.cat([local_feat, global_feat], dim=1))

class SBA(nn.Module):
    """Fuse a high-level and a low-level feature map with mutual sigmoid gating.

    Both inputs must have ``input_dim`` channels; the result has one channel at
    the spatial size of the low-level map.
    """

    def __init__(self,input_dim = 64):
        super().__init__()

        self.input_dim = input_dim
        half_dim = input_dim // 2

        self.d_in1 = BasicConv2d(half_dim, half_dim, 1)
        self.d_in2 = BasicConv2d(half_dim, half_dim, 1)

        self.conv = nn.Sequential(BasicConv2d(input_dim, input_dim, 3, 1, 1),
                                  nn.Conv2d(input_dim, 1, kernel_size=1, bias=False))
        self.fc1 = nn.Conv2d(input_dim, half_dim, kernel_size=1, bias=False)
        self.fc2 = nn.Conv2d(input_dim, half_dim, kernel_size=1, bias=False)

        self.Sigmoid = nn.Sigmoid()

    def forward(self, H_feature, L_feature):
        """Return the single-channel fusion of ``H_feature`` and ``L_feature``."""
        low = self.fc1(L_feature)
        high = self.fc2(H_feature)

        # Gates come from the reduced features *before* the d_in convolutions.
        gate_low = self.Sigmoid(low)
        gate_high = self.Sigmoid(high)

        low = self.d_in1(low)
        high = self.d_in2(high)

        low_size = low.size()[2:]
        high_size = high.size()[2:]

        # Order matters: the high-level update below reads the already-updated `low`.
        low = low + low * gate_low + (1 - gate_low) * Upsample(gate_high * high, size=low_size, align_corners=False)
        high = high + high * gate_high + (1 - gate_high) * Upsample(gate_low * low, size=high_size, align_corners=False)

        high = Upsample(high, size=low_size)
        return self.conv(torch.cat([high, low], dim=1))
        
            
class DuAT(nn.Module):
    """Two-headed segmentation network with position-gated backbone features.

    ``forward`` runs the ``pvt_v2_b2`` backbone, passes every stage through a
    ``SpatBlock`` and a batch norm, applies GLSA to the three deepest maps and
    SBA to the shallow + fused deep features, and returns two single-channel
    maps.  They are fed to the loss un-activated and thresholded at 0 for
    metrics, i.e. treated as logits.  The module also owns its loss function
    and optimizer.
    """

    def __init__(self, dim=32, dims= [64, 128, 320, 512], learning_rate=None,  loss_fn=None, optimizer=None, device=None, weight_decay=None):
        """
        Args:
            dim: channel width of the GLSA / SBA / fusion layers.
            dims: channel counts of the four backbone outputs (entries 0-3 are read).
            learning_rate: optimizer learning rate; ``None`` selects 1e-5.
            loss_fn: callable ``loss_fn(output, target)``; ``None`` selects ``structure_loss``.
            optimizer: factory called as ``optimizer(params, lr=..., weight_decay=...)``;
                ``None`` selects ``torch.optim.AdamW``.
            device: device the module is moved to; ``None`` selects the notebook-level ``DEVICE``.
            weight_decay: optimizer weight decay; ``None`` selects 1e-5.
        """
        super().__init__()

        self.device = DEVICE if device is None else device
        self.learning_rate = 1e-5 if learning_rate is None else learning_rate
        self.weight_decay = 1e-5 if weight_decay is None else weight_decay

        # Copy so the instance never aliases the shared default-argument list.
        self.dims = list(dims)

        self.backbone = pvt_v2_b2()  # [64, 128, 320, 512]

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = dims[0], dims[1], dims[2], dims[3]

        # The shuffle convolutions are not used by forward(); they are kept so
        # the parameter set and state_dict keys stay unchanged.
        self.shuffle4 = nn.Conv2d(c4_in_channels, c4_in_channels, 1)
        self.shuffle3 = nn.Conv2d(c3_in_channels, c3_in_channels, 1)
        self.shuffle2 = nn.Conv2d(c2_in_channels, c2_in_channels, 1)
        self.shuffle1 = nn.Conv2d(c1_in_channels, c1_in_channels, 1)

        # Defined once, at the position of the former first definition, so the
        # module / parameter registration order is the same as before.
        self.bn4 = nn.BatchNorm2d(c4_in_channels)
        self.bn3 = nn.BatchNorm2d(c3_in_channels)
        self.bn2 = nn.BatchNorm2d(c2_in_channels)
        self.bn1 = nn.BatchNorm2d(c1_in_channels)

        self.embed4 = SpatBlock(dim=c4_in_channels)
        self.embed3 = SpatBlock(dim=c3_in_channels)
        self.embed2 = SpatBlock(dim=c2_in_channels)
        self.embed1 = SpatBlock(dim=c1_in_channels)

        self.GLSA_c4 = GLSA(input_dim=c4_in_channels, embed_dim=dim)
        self.GLSA_c3 = GLSA(input_dim=c3_in_channels, embed_dim=dim)
        self.GLSA_c2 = GLSA(input_dim=c2_in_channels, embed_dim=dim)
        self.L_feature = BasicConv2d(c1_in_channels, dim, 3, 1, 1)

        self.SBA = SBA(input_dim=dim)
        self.fuse = BasicConv2d(dim * 2, dim, 1)
        self.fuse2 = nn.Sequential(BasicConv2d(dim*3, dim, 1, 1), nn.BatchNorm2d(dim), nn.Conv2d(dim, 1, kernel_size=1, bias=False))

        self.loss_fn = structure_loss if loss_fn is None else loss_fn

        # Move the parameters first, then hand them to the optimizer, so the
        # optimizer is constructed against parameters in their final location.
        self.to(self.device)

        optimizer_factory = torch.optim.AdamW if optimizer is None else optimizer
        self.optimizer = optimizer_factory(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

    def forward(self, x, mask):
        """Return ``(output, output2)``, the two single-channel prediction maps.

        ``mask`` is accepted for interface compatibility but is not used here.
        """
        x = x.to(self.device)
        c1, c2, c3, c4 = self.backbone(x)

        # Position-dependent gating followed by batch norm on every stage.
        c1 = self.bn1(self.embed1(c1))
        c2 = self.bn2(self.embed2(c2))
        c3 = self.bn3(self.embed3(c3))
        c4 = self.bn4(self.embed4(c4))

        _c4 = self.GLSA_c4(c4)
        _c4 = Upsample(_c4, c3.size()[2:])
        _c3 = self.GLSA_c3(c3)
        _c2 = self.GLSA_c2(c2)

        # Head 1: the three GLSA outputs, all at c2 resolution, concatenated.
        output = torch.cat([Upsample(_c4, c2.size()[2:]), Upsample(_c3, c2.size()[2:]), _c2], dim=1)

        # Head 2: SBA over the shallow feature and the fused deep features.
        L_feature = self.L_feature(c1)
        H_feature = self.fuse(torch.cat([_c4, _c3], dim=1))
        H_feature = Upsample(H_feature, c2.size()[2:])
        output2 = self.SBA(H_feature, L_feature)

        output = F.interpolate(output, scale_factor=8, mode='bicubic')
        output = self.fuse2(output)
        output2 = F.interpolate(output2, scale_factor=4, mode='bicubic')

        return output, output2

    @staticmethod
    def _confusion_rates(h2, mask, y):
        """Return ``tensor([TP, FP, FN, TN])``, each count divided by ``mask.sum()``.

        The prediction is ``h2 > 0`` (second head only).  Only the TN count is
        multiplied by ``mask``.  A zero ``mask.sum()`` yields non-finite entries.
        """
        with torch.no_grad():
            pred = (h2 > 0).int()
            numt = torch.sum(mask)
            TP = torch.sum(torch.minimum(y, pred)).item() / numt
            TN = torch.sum(torch.minimum(1-y, 1-pred) * mask).item() / numt
            FN = torch.sum(torch.minimum(y, 1-pred)).item() / numt
            FP = torch.sum(torch.minimum(1-y, pred)).item() / numt
            return torch.tensor([TP, FP, FN, TN])

    def fit(self, X, mask, y):
        """Run one optimisation step on a batch.

        Returns:
            ``(loss, rates)``: the float loss and ``tensor([TP, FP, FN, TN])``.
        """
        X, mask, y = X.to(self.device), mask.to(self.device), y.to(self.device)
        h1, h2 = self(X, mask)
        self.optimizer.zero_grad()
        loss = self.loss_fn(h1, y) + self.loss_fn(h2, y)
        loss.backward()
        self.optimizer.step()
        return (loss.item(), self._confusion_rates(h2, mask, y))

    def test(self, X, mask, y):
        """Evaluate a batch without updating parameters.

        Gradients are disabled here, so no autograd graph is built.  The
        train/eval mode is left as the caller set it.

        Returns:
            ``(loss, rates)`` in the same format as ``fit``.
        """
        X, mask, y = X.to(self.device), mask.to(self.device), y.to(self.device)
        with torch.no_grad():
            h1, h2 = self(X, mask)
            loss = (self.loss_fn(h1, y) + self.loss_fn(h2, y)).item()
        return (loss, self._confusion_rates(h2, mask, y))


# Cartesian test

In [14]:

def process_tensor_cart(tensor):
    """Clean up a mask with three successive morphological operations.

    The input is eroded and then dilated with a 3x3 all-ones kernel, and the
    result is closed with a 10x10 all-ones kernel.

    Args:
        tensor: array accepted by ``cv.morphologyEx``.
            # the visible caller passes a 2-D uint8 0/1 mask - confirm for other callers

    Returns:
        The array produced by the last ``cv.morphologyEx`` call.
    """
    small_kernel = np.ones((3, 3), np.uint8)
    large_kernel = np.ones((10, 10), np.uint8)

    pipeline = (
        (cv.MORPH_ERODE, small_kernel),
        (cv.MORPH_DILATE, small_kernel),
        (cv.MORPH_CLOSE, large_kernel),
    )

    result = tensor
    for operation, structuring_element in pipeline:
        result = cv.morphologyEx(result, operation, structuring_element)
    return result


# Post-processing hook looked up by name in output().
process_tensor = process_tensor_cart

In [18]:
results = []

def output(model):
    """Score ``model`` on ``test_dl`` with post-processing and record the mean Dice.

    For every batch the two model outputs are summed, the first sample /
    first channel is taken, resized with ``cv.resize(h, (512, 384))``,
    thresholded at 0 and passed through ``process_tensor``. The ground truth
    is passed through ``to_cart`` (centre scaled by 512 and 384) before the
    Dice coefficient is computed.
    # NOTE(review): to_cart is applied to the ground truth only, not to the
    # prediction - presumably correct only when the two are in different
    # coordinate systems; verify against the dataset feeding test_dl.

    Only index ``[0, 0]`` of each batch is scored and ``center[i].item()`` is
    called, so ``test_dl`` has to yield batches of a single sample.

    Args:
        model: callable as ``model(X, mask_s)`` returning two tensors.

    Returns:
        The mean Dice over ``test_dl``. The same value is printed and
        appended to the global ``results`` list.
    """
    model.eval()
    dices = []
    with torch.no_grad():
        with tqdm(total=len(test_dl)) as pbar:
            pbar.set_description("Validation Dataset")
            for (X, mask_l, mask_s, y, filename, center) in test_dl:
                X = X.to(DEVICE)
                mask_s = mask_s.to(DEVICE)
                h1, h2 = model(X, mask_s)
                h = h1 + h2
                h = h.detach().cpu().numpy()[0,0]
                h = cv.resize(h, (512, 384))
                h = np.uint8(h > .0)
                h = process_tensor(h)
                y_tilde = np.round(h)
                y = to_cart(y.detach().cpu().numpy()[0,0], (int(center[0].item() * 512), int(center[1].item() * 384)))
                y = np.round(y)
                tp = np.sum(np.minimum(y_tilde, y))
                fp = np.sum(np.minimum(y_tilde, 1 - y))
                fn = np.sum(np.minimum(1 - y_tilde, y))
                denom = 2 * tp + fp + fn
                # Both masks empty: 0/0 would yield NaN and poison the mean;
                # score the sample as perfect agreement instead.
                dice = 2 * tp / denom if denom > 0 else 1.0
                dices.append(dice)
                # Uncomment to dump prediction / ground truth side by side:
                # img = np.concatenate([y_tilde * 255, y * 255], axis=1)
                # cv.imwrite(os.path.join('Final-Improv', f'{dice:.4f}-' + filename[0] + '.jpg'), img)
                pbar.update(1)
    mean_dice = np.mean(np.array(dices))
    print(mean_dice)
    results.append(mean_dice)
    return mean_dice

In [19]:
def raw_output(model):
    """Score ``model`` on ``test_dl`` without post-processing and record the mean Dice.

    For every batch the two model outputs are summed, the first sample /
    first channel is taken, resized with ``cv.resize(h, (512, 384))`` and
    thresholded at 0. The ground truth is used as loaded (rounded only).

    Only index ``[0, 0]`` of each batch is scored, so ``test_dl`` is expected
    to yield batches of a single sample.

    Args:
        model: callable as ``model(X, mask_s)`` returning two tensors.

    Returns:
        The mean Dice over ``test_dl``. The same value is printed and
        appended to the global ``results_pol`` list, which must already exist
        when this function is called (it is created in the training cell).
    """
    model.eval()
    dices = []
    with torch.no_grad():
        with tqdm(total=len(test_dl)) as pbar:
            pbar.set_description("Test Dataset")
            for (X, mask_l, mask_s, y, filename, center) in test_dl:
                X = X.to(DEVICE)
                mask_s = mask_s.to(DEVICE)
                h1, h2 = model(X, mask_s)
                h = h1 + h2
                h = h.detach().cpu().numpy()[0,0]
                h = cv.resize(h, (512, 384))
                h = np.uint8(h > .0)
                y_tilde = np.round(h)
                y = y.detach().cpu().numpy()[0,0]
                y = np.round(y)
                tp = np.sum(np.minimum(y_tilde, y))
                fp = np.sum(np.minimum(y_tilde, 1 - y))
                fn = np.sum(np.minimum(1 - y_tilde, y))
                denom = 2 * tp + fp + fn
                # Both masks empty: 0/0 would yield NaN and poison the mean;
                # score the sample as perfect agreement instead.
                dice = 2 * tp / denom if denom > 0 else 1.0
                dices.append(dice)
                # Uncomment to dump prediction / ground truth side by side:
                # img = np.concatenate([y_tilde * 255, y * 255], axis=1)
                # cv.imwrite(os.path.join('Final-Improv', f'{dice:.4f}-' + filename[0] + '.jpg'), img)
                pbar.update(1)
    mean_dice = np.mean(np.array(dices))
    print(mean_dice)
    results_pol.append(mean_dice)
    return mean_dice

# Training

In [33]:
# Five independent training runs; each keeps the checkpoint with the best
# validation Dice and is then scored by output() and raw_output().
results_pol = []  # filled by raw_output()
results = []      # filled by output()
model_name = 'Cart-FocalLoss'

# Checkpoints are written into this folder; create it up front so a fresh
# checkout does not crash on the first torch.save after a full epoch.
os.makedirs('FinalModels', exist_ok=True)

# NOTE(review): `iter` shadows the Python builtin in the notebook namespace.
# The name is kept because cells outside this one may read it - confirm and rename.
for iter in range(1,6):
    print(f'Run {iter} ===========')
    net = DuAT(dim=32, optimizer=torch.optim.AdamW, learning_rate=1e-4, weight_decay=3e-4, loss_fn=FocalLoss(1))
    maxdice = 0.
    
    for t in range(30):
        torch.cuda.empty_cache()
        print(f"Epoch {t+1} ---------------------")
        print("Training Set -----")
        train_loss, _ = train(net)
        print("Validation Set -----")
        test_loss, dice = test(net)
        print(f"\rDice: {dice: .4f}, Loss: {test_loss : .4f}")
        # Keep only the checkpoint with the best validation Dice of this run.
        if dice > maxdice:
            maxdice = dice
            torch.save(net, f'FinalModels/{model_name}-{iter}.pth')
            
    print("Test Set -----")
    # NOTE(review): the whole module object was pickled above; recent PyTorch
    # releases default torch.load to weights_only=True, which presumably
    # rejects such files - confirm against the installed version.
    net = torch.load(f'FinalModels/{model_name}-{iter}.pth')
    output(net)
    raw_output(net)

print('FINAL RESULTS:')
# The reported figure aggregates raw_output() scores (results_pol);
# the output() scores collected in `results` are not summarised here.
print(f'CARTESIAN={np.mean(np.array(results_pol)):.4f}+/-{np.std(np.array(results_pol)):.4f}')

Epoch 1 ---------------------
Training Set -----


Avg. Loss: 0.3366, Avg. Accuracy: 0.9073, Avg. Dice: 0.7793: 100%|███████████████████| 519/519 [15:44<00:00,  1.82s/it]


Validation Set -----


Avg. Loss: 0.2166, Avg. Accuracy: 0.9380, Avg. Dice: 0.8319: 100%|█████████████████████| 65/65 [01:27<00:00,  1.35s/it]


Dice:  0.8319, Loss:  0.2166
Epoch 2 ---------------------
Training Set -----


Avg. Loss: 0.2411, Avg. Accuracy: 0.9313, Avg. Dice: 0.8256: 100%|███████████████████| 519/519 [16:24<00:00,  1.90s/it]


Validation Set -----


Avg. Loss: 0.1857, Avg. Accuracy: 0.9416, Avg. Dice: 0.8520: 100%|█████████████████████| 65/65 [01:29<00:00,  1.37s/it]


Dice:  0.8520, Loss:  0.1857
Epoch 3 ---------------------
Training Set -----


Avg. Loss: 0.1986, Avg. Accuracy: 0.9382, Avg. Dice: 0.8445: 100%|███████████████████| 519/519 [16:24<00:00,  1.90s/it]


Validation Set -----


Avg. Loss: 0.1867, Avg. Accuracy: 0.9403, Avg. Dice: 0.8518: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8518, Loss:  0.1867
Epoch 4 ---------------------
Training Set -----


Avg. Loss: 0.1773, Avg. Accuracy: 0.9429, Avg. Dice: 0.8551: 100%|███████████████████| 519/519 [16:26<00:00,  1.90s/it]


Validation Set -----


Avg. Loss: 0.1778, Avg. Accuracy: 0.9441, Avg. Dice: 0.8625: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8625, Loss:  0.1778
Epoch 5 ---------------------
Training Set -----


Avg. Loss: 0.1567, Avg. Accuracy: 0.9488, Avg. Dice: 0.8699: 100%|███████████████████| 519/519 [16:26<00:00,  1.90s/it]


Validation Set -----


Avg. Loss: 0.1801, Avg. Accuracy: 0.9381, Avg. Dice: 0.8445: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8445, Loss:  0.1801
Epoch 6 ---------------------
Training Set -----


Avg. Loss: 0.1511, Avg. Accuracy: 0.9488, Avg. Dice: 0.8706: 100%|███████████████████| 519/519 [16:28<00:00,  1.91s/it]


Validation Set -----


Avg. Loss: 0.1629, Avg. Accuracy: 0.9431, Avg. Dice: 0.8380: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8380, Loss:  0.1629
Epoch 7 ---------------------
Training Set -----


Avg. Loss: 0.1350, Avg. Accuracy: 0.9528, Avg. Dice: 0.8780: 100%|███████████████████| 519/519 [16:30<00:00,  1.91s/it]


Validation Set -----


Avg. Loss: 0.1686, Avg. Accuracy: 0.9425, Avg. Dice: 0.8593: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8593, Loss:  0.1686
Epoch 8 ---------------------
Training Set -----


Avg. Loss: 0.1228, Avg. Accuracy: 0.9566, Avg. Dice: 0.8916: 100%|███████████████████| 519/519 [16:29<00:00,  1.91s/it]


Validation Set -----


Avg. Loss: 0.1624, Avg. Accuracy: 0.9438, Avg. Dice: 0.8597: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8597, Loss:  0.1624
Epoch 9 ---------------------
Training Set -----


Avg. Loss: 0.1226, Avg. Accuracy: 0.9557, Avg. Dice: 0.8881: 100%|███████████████████| 519/519 [16:33<00:00,  1.91s/it]


Validation Set -----


Avg. Loss: 0.1906, Avg. Accuracy: 0.9222, Avg. Dice: 0.8192: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8192, Loss:  0.1906
Epoch 10 ---------------------
Training Set -----


Avg. Loss: 0.1151, Avg. Accuracy: 0.9580, Avg. Dice: 0.8931: 100%|███████████████████| 519/519 [16:41<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1728, Avg. Accuracy: 0.9414, Avg. Dice: 0.8466: 100%|█████████████████████| 65/65 [01:28<00:00,  1.37s/it]


Dice:  0.8466, Loss:  0.1728
Epoch 11 ---------------------
Training Set -----


Avg. Loss: 0.1049, Avg. Accuracy: 0.9607, Avg. Dice: 0.9011: 100%|███████████████████| 519/519 [16:31<00:00,  1.91s/it]


Validation Set -----


Avg. Loss: 0.1893, Avg. Accuracy: 0.9410, Avg. Dice: 0.8573: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8573, Loss:  0.1893
Epoch 12 ---------------------
Training Set -----


Avg. Loss: 0.1015, Avg. Accuracy: 0.9620, Avg. Dice: 0.9049: 100%|███████████████████| 519/519 [16:32<00:00,  1.91s/it]


Validation Set -----


Avg. Loss: 0.1855, Avg. Accuracy: 0.9410, Avg. Dice: 0.8590: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8590, Loss:  0.1855
Epoch 13 ---------------------
Training Set -----


Avg. Loss: 0.0898, Avg. Accuracy: 0.9658, Avg. Dice: 0.9137: 100%|███████████████████| 519/519 [16:42<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1636, Avg. Accuracy: 0.9440, Avg. Dice: 0.8654: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8654, Loss:  0.1636
Epoch 14 ---------------------
Training Set -----


Avg. Loss: 0.0892, Avg. Accuracy: 0.9657, Avg. Dice: 0.9131: 100%|███████████████████| 519/519 [16:37<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.1940, Avg. Accuracy: 0.9390, Avg. Dice: 0.8414: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8414, Loss:  0.1940
Epoch 15 ---------------------
Training Set -----


Avg. Loss: 0.0869, Avg. Accuracy: 0.9666, Avg. Dice: 0.9165: 100%|███████████████████| 519/519 [16:37<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.1594, Avg. Accuracy: 0.9467, Avg. Dice: 0.8752: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8752, Loss:  0.1594
Epoch 16 ---------------------
Training Set -----


Avg. Loss: 0.0788, Avg. Accuracy: 0.9693, Avg. Dice: 0.9223: 100%|███████████████████| 519/519 [16:36<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.2053, Avg. Accuracy: 0.9419, Avg. Dice: 0.8553: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8553, Loss:  0.2053
Epoch 17 ---------------------
Training Set -----


Avg. Loss: 0.0830, Avg. Accuracy: 0.9679, Avg. Dice: 0.9190: 100%|███████████████████| 519/519 [16:34<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.1679, Avg. Accuracy: 0.9440, Avg. Dice: 0.8582: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8582, Loss:  0.1679
Epoch 18 ---------------------
Training Set -----


Avg. Loss: 0.0790, Avg. Accuracy: 0.9691, Avg. Dice: 0.9216: 100%|███████████████████| 519/519 [16:40<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.2250, Avg. Accuracy: 0.9390, Avg. Dice: 0.8464: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8464, Loss:  0.2250
Epoch 19 ---------------------
Training Set -----


Avg. Loss: 0.0735, Avg. Accuracy: 0.9713, Avg. Dice: 0.9260: 100%|███████████████████| 519/519 [16:38<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.2082, Avg. Accuracy: 0.9414, Avg. Dice: 0.8484: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8484, Loss:  0.2082
Epoch 20 ---------------------
Training Set -----


Avg. Loss: 0.0723, Avg. Accuracy: 0.9721, Avg. Dice: 0.9286: 100%|███████████████████| 519/519 [16:35<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.2023, Avg. Accuracy: 0.9367, Avg. Dice: 0.8320: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8320, Loss:  0.2023
Epoch 21 ---------------------
Training Set -----


Avg. Loss: 0.0751, Avg. Accuracy: 0.9704, Avg. Dice: 0.9244: 100%|███████████████████| 519/519 [16:41<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.2129, Avg. Accuracy: 0.9412, Avg. Dice: 0.8563: 100%|█████████████████████| 65/65 [01:30<00:00,  1.38s/it]


Dice:  0.8563, Loss:  0.2129
Epoch 22 ---------------------
Training Set -----


Avg. Loss: 0.0895, Avg. Accuracy: 0.9659, Avg. Dice: 0.9120: 100%|███████████████████| 519/519 [16:38<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.1872, Avg. Accuracy: 0.9420, Avg. Dice: 0.8503: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8503, Loss:  0.1872
Epoch 23 ---------------------
Training Set -----


Avg. Loss: 0.0674, Avg. Accuracy: 0.9736, Avg. Dice: 0.9341: 100%|███████████████████| 519/519 [16:38<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.1940, Avg. Accuracy: 0.9439, Avg. Dice: 0.8635: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8635, Loss:  0.1940
Epoch 24 ---------------------
Training Set -----


Avg. Loss: 0.0632, Avg. Accuracy: 0.9754, Avg. Dice: 0.9377: 100%|███████████████████| 519/519 [16:40<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.2021, Avg. Accuracy: 0.9447, Avg. Dice: 0.8596: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8596, Loss:  0.2021
Epoch 25 ---------------------
Training Set -----


Avg. Loss: 0.0582, Avg. Accuracy: 0.9770, Avg. Dice: 0.9413: 100%|███████████████████| 519/519 [16:38<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.1953, Avg. Accuracy: 0.9454, Avg. Dice: 0.8722: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8722, Loss:  0.1953
Epoch 26 ---------------------
Training Set -----


Avg. Loss: 0.0553, Avg. Accuracy: 0.9781, Avg. Dice: 0.9430: 100%|███████████████████| 519/519 [16:37<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.2052, Avg. Accuracy: 0.9461, Avg. Dice: 0.8651: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8651, Loss:  0.2052
Epoch 27 ---------------------
Training Set -----


Avg. Loss: 0.0643, Avg. Accuracy: 0.9750, Avg. Dice: 0.9361: 100%|███████████████████| 519/519 [16:35<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.2013, Avg. Accuracy: 0.9445, Avg. Dice: 0.8565: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8565, Loss:  0.2013
Epoch 28 ---------------------
Training Set -----


Avg. Loss: 0.0615, Avg. Accuracy: 0.9758, Avg. Dice: 0.9379: 100%|███████████████████| 519/519 [16:45<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.2083, Avg. Accuracy: 0.9412, Avg. Dice: 0.8522: 100%|█████████████████████| 65/65 [01:29<00:00,  1.37s/it]


Dice:  0.8522, Loss:  0.2083
Epoch 29 ---------------------
Training Set -----


Avg. Loss: 0.0522, Avg. Accuracy: 0.9793, Avg. Dice: 0.9473: 100%|███████████████████| 519/519 [16:39<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.2115, Avg. Accuracy: 0.9467, Avg. Dice: 0.8672: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8672, Loss:  0.2115
Epoch 30 ---------------------
Training Set -----


Avg. Loss: 0.0526, Avg. Accuracy: 0.9790, Avg. Dice: 0.9463: 100%|███████████████████| 519/519 [16:42<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1901, Avg. Accuracy: 0.9491, Avg. Dice: 0.8693: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8693, Loss:  0.1901
Test Set -----


Validation Dataset: 100%|████████████████████████████████████████████████████████████| 259/259 [01:24<00:00,  3.06it/s]


0.1758590597200691


Test Dataset: 100%|██████████████████████████████████████████████████████████████████| 259/259 [01:21<00:00,  3.18it/s]


0.8799338612853709
Epoch 1 ---------------------
Training Set -----


Avg. Loss: 0.3458, Avg. Accuracy: 0.9062, Avg. Dice: 0.7780: 100%|███████████████████| 519/519 [16:37<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.2566, Avg. Accuracy: 0.9224, Avg. Dice: 0.8027: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8027, Loss:  0.2566
Epoch 2 ---------------------
Training Set -----


Avg. Loss: 0.2452, Avg. Accuracy: 0.9300, Avg. Dice: 0.8197: 100%|███████████████████| 519/519 [16:44<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.2026, Avg. Accuracy: 0.9373, Avg. Dice: 0.8365: 100%|█████████████████████| 65/65 [01:28<00:00,  1.37s/it]


Dice:  0.8365, Loss:  0.2026
Epoch 3 ---------------------
Training Set -----


Avg. Loss: 0.2057, Avg. Accuracy: 0.9369, Avg. Dice: 0.8399: 100%|███████████████████| 519/519 [16:39<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1848, Avg. Accuracy: 0.9416, Avg. Dice: 0.8521: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8521, Loss:  0.1848
Epoch 4 ---------------------
Training Set -----


Avg. Loss: 0.1821, Avg. Accuracy: 0.9403, Avg. Dice: 0.8481: 100%|███████████████████| 519/519 [16:34<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.1570, Avg. Accuracy: 0.9439, Avg. Dice: 0.8584: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8584, Loss:  0.1570
Epoch 5 ---------------------
Training Set -----


Avg. Loss: 0.1595, Avg. Accuracy: 0.9463, Avg. Dice: 0.8644: 100%|███████████████████| 519/519 [16:45<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1598, Avg. Accuracy: 0.9461, Avg. Dice: 0.8677: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8677, Loss:  0.1598
Epoch 6 ---------------------
Training Set -----


Avg. Loss: 0.1417, Avg. Accuracy: 0.9515, Avg. Dice: 0.8770: 100%|███████████████████| 519/519 [16:41<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.2857, Avg. Accuracy: 0.8961, Avg. Dice: 0.7714: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.7714, Loss:  0.2857
Epoch 7 ---------------------
Training Set -----


Avg. Loss: 0.1373, Avg. Accuracy: 0.9513, Avg. Dice: 0.8788: 100%|███████████████████| 519/519 [16:36<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.1745, Avg. Accuracy: 0.9376, Avg. Dice: 0.8308: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8308, Loss:  0.1745
Epoch 8 ---------------------
Training Set -----


Avg. Loss: 0.1173, Avg. Accuracy: 0.9583, Avg. Dice: 0.8929: 100%|███████████████████| 519/519 [16:41<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1900, Avg. Accuracy: 0.9380, Avg. Dice: 0.8484: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8484, Loss:  0.1900
Epoch 9 ---------------------
Training Set -----


Avg. Loss: 0.1166, Avg. Accuracy: 0.9577, Avg. Dice: 0.8926: 100%|███████████████████| 519/519 [16:38<00:00,  1.92s/it]


Validation Set -----


Avg. Loss: 0.2905, Avg. Accuracy: 0.9128, Avg. Dice: 0.8017: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8017, Loss:  0.2905
Epoch 10 ---------------------
Training Set -----


Avg. Loss: 0.1240, Avg. Accuracy: 0.9550, Avg. Dice: 0.8859: 100%|███████████████████| 519/519 [16:50<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1845, Avg. Accuracy: 0.9391, Avg. Dice: 0.8492: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8492, Loss:  0.1845
Epoch 11 ---------------------
Training Set -----


Avg. Loss: 0.1093, Avg. Accuracy: 0.9601, Avg. Dice: 0.8981: 100%|███████████████████| 519/519 [16:40<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1793, Avg. Accuracy: 0.9428, Avg. Dice: 0.8506: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8506, Loss:  0.1793
Epoch 12 ---------------------
Training Set -----


Avg. Loss: 0.1012, Avg. Accuracy: 0.9623, Avg. Dice: 0.9043: 100%|███████████████████| 519/519 [16:41<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.2376, Avg. Accuracy: 0.9266, Avg. Dice: 0.7989: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.7989, Loss:  0.2376
Epoch 13 ---------------------
Training Set -----


Avg. Loss: 0.1020, Avg. Accuracy: 0.9619, Avg. Dice: 0.9037: 100%|███████████████████| 519/519 [16:43<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1658, Avg. Accuracy: 0.9450, Avg. Dice: 0.8572: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8572, Loss:  0.1658
Epoch 14 ---------------------
Training Set -----


Avg. Loss: 0.0882, Avg. Accuracy: 0.9665, Avg. Dice: 0.9165: 100%|███████████████████| 519/519 [16:43<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1722, Avg. Accuracy: 0.9452, Avg. Dice: 0.8558: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8558, Loss:  0.1722
Epoch 15 ---------------------
Training Set -----


Avg. Loss: 0.0890, Avg. Accuracy: 0.9667, Avg. Dice: 0.9132: 100%|███████████████████| 519/519 [16:44<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.2239, Avg. Accuracy: 0.9348, Avg. Dice: 0.8261: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8261, Loss:  0.2239
Epoch 16 ---------------------
Training Set -----


Avg. Loss: 0.0869, Avg. Accuracy: 0.9667, Avg. Dice: 0.9153: 100%|███████████████████| 519/519 [16:40<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1788, Avg. Accuracy: 0.9460, Avg. Dice: 0.8705: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8705, Loss:  0.1788
Epoch 17 ---------------------
Training Set -----


Avg. Loss: 0.0753, Avg. Accuracy: 0.9709, Avg. Dice: 0.9255: 100%|███████████████████| 519/519 [16:44<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1721, Avg. Accuracy: 0.9457, Avg. Dice: 0.8620: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8620, Loss:  0.1721
Epoch 18 ---------------------
Training Set -----


Avg. Loss: 0.0845, Avg. Accuracy: 0.9673, Avg. Dice: 0.9175: 100%|███████████████████| 519/519 [16:44<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1877, Avg. Accuracy: 0.9465, Avg. Dice: 0.8727: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8727, Loss:  0.1877
Epoch 19 ---------------------
Training Set -----


Avg. Loss: 0.1000, Avg. Accuracy: 0.9628, Avg. Dice: 0.9086: 100%|███████████████████| 519/519 [16:46<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1724, Avg. Accuracy: 0.9437, Avg. Dice: 0.8624: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8624, Loss:  0.1724
Epoch 20 ---------------------
Training Set -----


Avg. Loss: 0.0741, Avg. Accuracy: 0.9713, Avg. Dice: 0.9268: 100%|███████████████████| 519/519 [16:45<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1624, Avg. Accuracy: 0.9473, Avg. Dice: 0.8805: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8805, Loss:  0.1624
Epoch 21 ---------------------
Training Set -----


Avg. Loss: 0.0689, Avg. Accuracy: 0.9732, Avg. Dice: 0.9315: 100%|███████████████████| 519/519 [16:42<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1824, Avg. Accuracy: 0.9454, Avg. Dice: 0.8588: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8588, Loss:  0.1824
Epoch 22 ---------------------
Training Set -----


Avg. Loss: 0.0643, Avg. Accuracy: 0.9748, Avg. Dice: 0.9357: 100%|███████████████████| 519/519 [16:46<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1735, Avg. Accuracy: 0.9472, Avg. Dice: 0.8666: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8666, Loss:  0.1735
Epoch 23 ---------------------
Training Set -----


Avg. Loss: 0.0658, Avg. Accuracy: 0.9741, Avg. Dice: 0.9331: 100%|███████████████████| 519/519 [16:41<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1846, Avg. Accuracy: 0.9452, Avg. Dice: 0.8539: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8539, Loss:  0.1846
Epoch 24 ---------------------
Training Set -----


Avg. Loss: 0.0628, Avg. Accuracy: 0.9754, Avg. Dice: 0.9364: 100%|███████████████████| 519/519 [16:43<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.2208, Avg. Accuracy: 0.9382, Avg. Dice: 0.8470: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8470, Loss:  0.2208
Epoch 25 ---------------------
Training Set -----


Avg. Loss: 0.0618, Avg. Accuracy: 0.9756, Avg. Dice: 0.9378: 100%|███████████████████| 519/519 [16:46<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1781, Avg. Accuracy: 0.9444, Avg. Dice: 0.8628: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8628, Loss:  0.1781
Epoch 26 ---------------------
Training Set -----


Avg. Loss: 0.0637, Avg. Accuracy: 0.9748, Avg. Dice: 0.9368: 100%|███████████████████| 519/519 [16:43<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1809, Avg. Accuracy: 0.9440, Avg. Dice: 0.8588: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8588, Loss:  0.1809
Epoch 27 ---------------------
Training Set -----


Avg. Loss: 0.0730, Avg. Accuracy: 0.9715, Avg. Dice: 0.9269: 100%|███████████████████| 519/519 [16:49<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.2061, Avg. Accuracy: 0.9430, Avg. Dice: 0.8540: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8540, Loss:  0.2061
Epoch 28 ---------------------
Training Set -----


Avg. Loss: 0.0566, Avg. Accuracy: 0.9778, Avg. Dice: 0.9433: 100%|███████████████████| 519/519 [16:51<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1973, Avg. Accuracy: 0.9453, Avg. Dice: 0.8672: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8672, Loss:  0.1973
Epoch 29 ---------------------
Training Set -----


Avg. Loss: 0.0510, Avg. Accuracy: 0.9797, Avg. Dice: 0.9480: 100%|███████████████████| 519/519 [16:48<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1945, Avg. Accuracy: 0.9471, Avg. Dice: 0.8631: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8631, Loss:  0.1945
Epoch 30 ---------------------
Training Set -----


Avg. Loss: 0.0514, Avg. Accuracy: 0.9796, Avg. Dice: 0.9477: 100%|███████████████████| 519/519 [16:39<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.1913, Avg. Accuracy: 0.9462, Avg. Dice: 0.8586: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8586, Loss:  0.1913
Test Set -----


Validation Dataset: 100%|████████████████████████████████████████████████████████████| 259/259 [01:24<00:00,  3.06it/s]


0.1686605494285631


Test Dataset: 100%|██████████████████████████████████████████████████████████████████| 259/259 [01:22<00:00,  3.15it/s]


0.8776641064544358
Epoch 1 ---------------------
Training Set -----


Avg. Loss: 0.3364, Avg. Accuracy: 0.9071, Avg. Dice: 0.7798: 100%|███████████████████| 519/519 [16:45<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.2393, Avg. Accuracy: 0.9348, Avg. Dice: 0.8380: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8380, Loss:  0.2393
Epoch 2 ---------------------
Training Set -----


Avg. Loss: 0.2381, Avg. Accuracy: 0.9321, Avg. Dice: 0.8219: 100%|███████████████████| 519/519 [16:48<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.2123, Avg. Accuracy: 0.9353, Avg. Dice: 0.8445: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8445, Loss:  0.2123
Epoch 3 ---------------------
Training Set -----


Avg. Loss: 0.1956, Avg. Accuracy: 0.9382, Avg. Dice: 0.8453: 100%|███████████████████| 519/519 [16:53<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.2481, Avg. Accuracy: 0.9242, Avg. Dice: 0.8169: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8169, Loss:  0.2481
Epoch 4 ---------------------
Training Set -----


Avg. Loss: 0.1754, Avg. Accuracy: 0.9427, Avg. Dice: 0.8553: 100%|███████████████████| 519/519 [16:48<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1861, Avg. Accuracy: 0.9350, Avg. Dice: 0.8392: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8392, Loss:  0.1861
Epoch 5 ---------------------
Training Set -----


Avg. Loss: 0.1646, Avg. Accuracy: 0.9439, Avg. Dice: 0.8570: 100%|███████████████████| 519/519 [16:51<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1778, Avg. Accuracy: 0.9373, Avg. Dice: 0.8487: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8487, Loss:  0.1778
Epoch 6 ---------------------
Training Set -----


Avg. Loss: 0.1409, Avg. Accuracy: 0.9520, Avg. Dice: 0.8773: 100%|███████████████████| 519/519 [16:51<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.2030, Avg. Accuracy: 0.9325, Avg. Dice: 0.8211: 100%|█████████████████████| 65/65 [01:29<00:00,  1.38s/it]


Dice:  0.8211, Loss:  0.2030
Epoch 7 ---------------------
Training Set -----


Avg. Loss: 0.1297, Avg. Accuracy: 0.9542, Avg. Dice: 0.8850: 100%|███████████████████| 519/519 [16:55<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1869, Avg. Accuracy: 0.9422, Avg. Dice: 0.8614: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8614, Loss:  0.1869
Epoch 8 ---------------------
Training Set -----


Avg. Loss: 0.1251, Avg. Accuracy: 0.9547, Avg. Dice: 0.8870: 100%|███████████████████| 519/519 [16:48<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1699, Avg. Accuracy: 0.9419, Avg. Dice: 0.8489: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8489, Loss:  0.1699
Epoch 9 ---------------------
Training Set -----


Avg. Loss: 0.1135, Avg. Accuracy: 0.9584, Avg. Dice: 0.8966: 100%|███████████████████| 519/519 [16:50<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1826, Avg. Accuracy: 0.9447, Avg. Dice: 0.8715: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8715, Loss:  0.1826
Epoch 10 ---------------------
Training Set -----


Avg. Loss: 0.1077, Avg. Accuracy: 0.9604, Avg. Dice: 0.8973: 100%|███████████████████| 519/519 [16:52<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1818, Avg. Accuracy: 0.9423, Avg. Dice: 0.8663: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8663, Loss:  0.1818
Epoch 11 ---------------------
Training Set -----


Avg. Loss: 0.0976, Avg. Accuracy: 0.9635, Avg. Dice: 0.9078: 100%|███████████████████| 519/519 [16:52<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1846, Avg. Accuracy: 0.9440, Avg. Dice: 0.8550: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8550, Loss:  0.1846
Epoch 12 ---------------------
Training Set -----


Avg. Loss: 0.0961, Avg. Accuracy: 0.9639, Avg. Dice: 0.9089: 100%|███████████████████| 519/519 [16:47<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1737, Avg. Accuracy: 0.9454, Avg. Dice: 0.8618: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8618, Loss:  0.1737
Epoch 13 ---------------------
Training Set -----


Avg. Loss: 0.0909, Avg. Accuracy: 0.9653, Avg. Dice: 0.9137: 100%|███████████████████| 519/519 [16:45<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1822, Avg. Accuracy: 0.9409, Avg. Dice: 0.8437: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8437, Loss:  0.1822
Epoch 14 ---------------------
Training Set -----


Avg. Loss: 0.0859, Avg. Accuracy: 0.9672, Avg. Dice: 0.9168: 100%|███████████████████| 519/519 [16:52<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1818, Avg. Accuracy: 0.9418, Avg. Dice: 0.8525: 100%|█████████████████████| 65/65 [01:35<00:00,  1.47s/it]


Dice:  0.8525, Loss:  0.1818
Epoch 15 ---------------------
Training Set -----


Avg. Loss: 0.1069, Avg. Accuracy: 0.9600, Avg. Dice: 0.8995: 100%|███████████████████| 519/519 [17:24<00:00,  2.01s/it]


Validation Set -----


Avg. Loss: 0.1563, Avg. Accuracy: 0.9458, Avg. Dice: 0.8667: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8667, Loss:  0.1563
Epoch 16 ---------------------
Training Set -----


Avg. Loss: 0.0904, Avg. Accuracy: 0.9653, Avg. Dice: 0.9121: 100%|███████████████████| 519/519 [16:52<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1704, Avg. Accuracy: 0.9444, Avg. Dice: 0.8632: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8632, Loss:  0.1704
Epoch 17 ---------------------
Training Set -----


Avg. Loss: 0.0933, Avg. Accuracy: 0.9642, Avg. Dice: 0.9107: 100%|███████████████████| 519/519 [16:47<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1689, Avg. Accuracy: 0.9457, Avg. Dice: 0.8685: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8685, Loss:  0.1689
Epoch 18 ---------------------
Training Set -----


Avg. Loss: 0.0756, Avg. Accuracy: 0.9707, Avg. Dice: 0.9251: 100%|███████████████████| 519/519 [16:45<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1821, Avg. Accuracy: 0.9470, Avg. Dice: 0.8731: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8731, Loss:  0.1821
Epoch 19 ---------------------
Training Set -----


Avg. Loss: 0.0742, Avg. Accuracy: 0.9712, Avg. Dice: 0.9259: 100%|███████████████████| 519/519 [16:52<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1992, Avg. Accuracy: 0.9451, Avg. Dice: 0.8651: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8651, Loss:  0.1992
Epoch 20 ---------------------
Training Set -----


Avg. Loss: 0.0677, Avg. Accuracy: 0.9735, Avg. Dice: 0.9310: 100%|███████████████████| 519/519 [16:51<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1862, Avg. Accuracy: 0.9479, Avg. Dice: 0.8727: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8727, Loss:  0.1862
Epoch 21 ---------------------
Training Set -----


Avg. Loss: 0.0701, Avg. Accuracy: 0.9725, Avg. Dice: 0.9301: 100%|███████████████████| 519/519 [16:47<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1874, Avg. Accuracy: 0.9433, Avg. Dice: 0.8575: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8575, Loss:  0.1874
Epoch 22 ---------------------
Training Set -----


Avg. Loss: 0.0659, Avg. Accuracy: 0.9741, Avg. Dice: 0.9344: 100%|███████████████████| 519/519 [16:44<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.2133, Avg. Accuracy: 0.9426, Avg. Dice: 0.8537: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8537, Loss:  0.2133
Epoch 23 ---------------------
Training Set -----


Avg. Loss: 0.0603, Avg. Accuracy: 0.9761, Avg. Dice: 0.9391: 100%|███████████████████| 519/519 [16:43<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.2025, Avg. Accuracy: 0.9456, Avg. Dice: 0.8599: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8599, Loss:  0.2025
Epoch 24 ---------------------
Training Set -----


Avg. Loss: 0.0593, Avg. Accuracy: 0.9765, Avg. Dice: 0.9403: 100%|███████████████████| 519/519 [16:46<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.2097, Avg. Accuracy: 0.9432, Avg. Dice: 0.8577: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8577, Loss:  0.2097
Epoch 25 ---------------------
Training Set -----


Avg. Loss: 0.0686, Avg. Accuracy: 0.9733, Avg. Dice: 0.9327: 100%|███████████████████| 519/519 [16:45<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1829, Avg. Accuracy: 0.9457, Avg. Dice: 0.8659: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8659, Loss:  0.1829
Epoch 26 ---------------------
Training Set -----


Avg. Loss: 0.0661, Avg. Accuracy: 0.9739, Avg. Dice: 0.9337: 100%|███████████████████| 519/519 [16:41<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.2048, Avg. Accuracy: 0.9419, Avg. Dice: 0.8509: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8509, Loss:  0.2048
Epoch 27 ---------------------
Training Set -----


Avg. Loss: 0.0604, Avg. Accuracy: 0.9759, Avg. Dice: 0.9387: 100%|███████████████████| 519/519 [16:44<00:00,  1.93s/it]


Validation Set -----


Avg. Loss: 0.2099, Avg. Accuracy: 0.9443, Avg. Dice: 0.8667: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8667, Loss:  0.2099
Epoch 28 ---------------------
Training Set -----


Avg. Loss: 0.0628, Avg. Accuracy: 0.9752, Avg. Dice: 0.9376: 100%|███████████████████| 519/519 [16:53<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1998, Avg. Accuracy: 0.9420, Avg. Dice: 0.8542: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8542, Loss:  0.1998
Epoch 29 ---------------------
Training Set -----


Avg. Loss: 0.0615, Avg. Accuracy: 0.9755, Avg. Dice: 0.9391: 100%|███████████████████| 519/519 [16:48<00:00,  1.94s/it]


Validation Set -----


Avg. Loss: 0.1981, Avg. Accuracy: 0.9431, Avg. Dice: 0.8494: 100%|█████████████████████| 65/65 [01:31<00:00,  1.42s/it]


Dice:  0.8494, Loss:  0.1981
Epoch 30 ---------------------
Training Set -----


Avg. Loss: 0.0520, Avg. Accuracy: 0.9793, Avg. Dice: 0.9472: 100%|███████████████████| 519/519 [16:52<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.2401, Avg. Accuracy: 0.9424, Avg. Dice: 0.8491: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8491, Loss:  0.2401
Test Set -----


Validation Dataset: 100%|████████████████████████████████████████████████████████████| 259/259 [01:26<00:00,  2.98it/s]


0.17095472710045068


Test Dataset: 100%|██████████████████████████████████████████████████████████████████| 259/259 [01:22<00:00,  3.12it/s]


0.8764322970592537
Epoch 1 ---------------------
Training Set -----


Avg. Loss: 0.3268, Avg. Accuracy: 0.9099, Avg. Dice: 0.7853: 100%|███████████████████| 519/519 [16:52<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.2262, Avg. Accuracy: 0.9316, Avg. Dice: 0.8108: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8108, Loss:  0.2262
Epoch 2 ---------------------
Training Set -----


Avg. Loss: 0.2322, Avg. Accuracy: 0.9340, Avg. Dice: 0.8318: 100%|███████████████████| 519/519 [16:53<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1883, Avg. Accuracy: 0.9401, Avg. Dice: 0.8506: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8506, Loss:  0.1883
Epoch 3 ---------------------
Training Set -----


Avg. Loss: 0.1960, Avg. Accuracy: 0.9388, Avg. Dice: 0.8448: 100%|███████████████████| 519/519 [16:56<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1940, Avg. Accuracy: 0.9346, Avg. Dice: 0.8332: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8332, Loss:  0.1940
Epoch 4 ---------------------
Training Set -----


Avg. Loss: 0.1680, Avg. Accuracy: 0.9453, Avg. Dice: 0.8607: 100%|███████████████████| 519/519 [17:00<00:00,  1.97s/it]


Validation Set -----


Avg. Loss: 0.2042, Avg. Accuracy: 0.9378, Avg. Dice: 0.8443: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8443, Loss:  0.2042
Epoch 5 ---------------------
Training Set -----


Avg. Loss: 0.1639, Avg. Accuracy: 0.9467, Avg. Dice: 0.8632: 100%|███████████████████| 519/519 [16:51<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1733, Avg. Accuracy: 0.9455, Avg. Dice: 0.8697: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8697, Loss:  0.1733
Epoch 6 ---------------------
Training Set -----


Avg. Loss: 0.1424, Avg. Accuracy: 0.9510, Avg. Dice: 0.8752: 100%|███████████████████| 519/519 [16:55<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1959, Avg. Accuracy: 0.9364, Avg. Dice: 0.8404: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8404, Loss:  0.1959
Epoch 7 ---------------------
Training Set -----


Avg. Loss: 0.1365, Avg. Accuracy: 0.9517, Avg. Dice: 0.8775: 100%|███████████████████| 519/519 [16:57<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1830, Avg. Accuracy: 0.9395, Avg. Dice: 0.8552: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8552, Loss:  0.1830
Epoch 8 ---------------------
Training Set -----


Avg. Loss: 0.1223, Avg. Accuracy: 0.9571, Avg. Dice: 0.8923: 100%|███████████████████| 519/519 [16:57<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1916, Avg. Accuracy: 0.9401, Avg. Dice: 0.8597: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8597, Loss:  0.1916
Epoch 9 ---------------------
Training Set -----


Avg. Loss: 0.1112, Avg. Accuracy: 0.9593, Avg. Dice: 0.8977: 100%|███████████████████| 519/519 [16:55<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1739, Avg. Accuracy: 0.9420, Avg. Dice: 0.8496: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8496, Loss:  0.1739
Epoch 10 ---------------------
Training Set -----


Avg. Loss: 0.1232, Avg. Accuracy: 0.9552, Avg. Dice: 0.8878: 100%|███████████████████| 519/519 [16:55<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1673, Avg. Accuracy: 0.9443, Avg. Dice: 0.8598: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8598, Loss:  0.1673
Epoch 11 ---------------------
Training Set -----


Avg. Loss: 0.0991, Avg. Accuracy: 0.9630, Avg. Dice: 0.9062: 100%|███████████████████| 519/519 [16:51<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1938, Avg. Accuracy: 0.9290, Avg. Dice: 0.8134: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8134, Loss:  0.1938
Epoch 12 ---------------------
Training Set -----


Avg. Loss: 0.0946, Avg. Accuracy: 0.9641, Avg. Dice: 0.9094: 100%|███████████████████| 519/519 [16:54<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1857, Avg. Accuracy: 0.9425, Avg. Dice: 0.8543: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8543, Loss:  0.1857
Epoch 13 ---------------------
Training Set -----


Avg. Loss: 0.0927, Avg. Accuracy: 0.9653, Avg. Dice: 0.9110: 100%|███████████████████| 519/519 [16:53<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1662, Avg. Accuracy: 0.9427, Avg. Dice: 0.8560: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8560, Loss:  0.1662
Epoch 14 ---------------------
Training Set -----


Avg. Loss: 0.0932, Avg. Accuracy: 0.9651, Avg. Dice: 0.9117: 100%|███████████████████| 519/519 [16:57<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1760, Avg. Accuracy: 0.9429, Avg. Dice: 0.8554: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8554, Loss:  0.1760
Epoch 15 ---------------------
Training Set -----


Avg. Loss: 0.0888, Avg. Accuracy: 0.9659, Avg. Dice: 0.9149: 100%|███████████████████| 519/519 [16:56<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1752, Avg. Accuracy: 0.9451, Avg. Dice: 0.8732: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8732, Loss:  0.1752
Epoch 16 ---------------------
Training Set -----


Avg. Loss: 0.0803, Avg. Accuracy: 0.9687, Avg. Dice: 0.9208: 100%|███████████████████| 519/519 [16:55<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1806, Avg. Accuracy: 0.9449, Avg. Dice: 0.8639: 100%|█████████████████████| 65/65 [01:30<00:00,  1.39s/it]


Dice:  0.8639, Loss:  0.1806
Epoch 17 ---------------------
Training Set -----


Avg. Loss: 0.0849, Avg. Accuracy: 0.9673, Avg. Dice: 0.9175: 100%|███████████████████| 519/519 [16:52<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1881, Avg. Accuracy: 0.9345, Avg. Dice: 0.8425: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8425, Loss:  0.1881
Epoch 18 ---------------------
Training Set -----


Avg. Loss: 0.0762, Avg. Accuracy: 0.9704, Avg. Dice: 0.9236: 100%|███████████████████| 519/519 [16:54<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1754, Avg. Accuracy: 0.9445, Avg. Dice: 0.8629: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8629, Loss:  0.1754
Epoch 19 ---------------------
Training Set -----


Avg. Loss: 0.0814, Avg. Accuracy: 0.9682, Avg. Dice: 0.9217: 100%|███████████████████| 519/519 [16:55<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1673, Avg. Accuracy: 0.9473, Avg. Dice: 0.8696: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8696, Loss:  0.1673
Epoch 20 ---------------------
Training Set -----


Avg. Loss: 0.0703, Avg. Accuracy: 0.9726, Avg. Dice: 0.9305: 100%|███████████████████| 519/519 [16:53<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1873, Avg. Accuracy: 0.9466, Avg. Dice: 0.8683: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8683, Loss:  0.1873
Epoch 21 ---------------------
Training Set -----


Avg. Loss: 0.0683, Avg. Accuracy: 0.9733, Avg. Dice: 0.9324: 100%|███████████████████| 519/519 [16:51<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1814, Avg. Accuracy: 0.9473, Avg. Dice: 0.8619: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8619, Loss:  0.1814
Epoch 22 ---------------------
Training Set -----


Avg. Loss: 0.0612, Avg. Accuracy: 0.9760, Avg. Dice: 0.9393: 100%|███████████████████| 519/519 [16:54<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1750, Avg. Accuracy: 0.9469, Avg. Dice: 0.8661: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8661, Loss:  0.1750
Epoch 23 ---------------------
Training Set -----


Avg. Loss: 0.0613, Avg. Accuracy: 0.9759, Avg. Dice: 0.9384: 100%|███████████████████| 519/519 [16:52<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1939, Avg. Accuracy: 0.9467, Avg. Dice: 0.8738: 100%|█████████████████████| 65/65 [01:33<00:00,  1.44s/it]


Dice:  0.8738, Loss:  0.1939
Epoch 24 ---------------------
Training Set -----


Avg. Loss: 0.0633, Avg. Accuracy: 0.9749, Avg. Dice: 0.9355: 100%|███████████████████| 519/519 [16:55<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1995, Avg. Accuracy: 0.9471, Avg. Dice: 0.8661: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8661, Loss:  0.1995
Epoch 25 ---------------------
Training Set -----


Avg. Loss: 0.0676, Avg. Accuracy: 0.9736, Avg. Dice: 0.9321: 100%|███████████████████| 519/519 [16:58<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1791, Avg. Accuracy: 0.9465, Avg. Dice: 0.8575: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8575, Loss:  0.1791
Epoch 26 ---------------------
Training Set -----


Avg. Loss: 0.0657, Avg. Accuracy: 0.9741, Avg. Dice: 0.9336: 100%|███████████████████| 519/519 [16:53<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1726, Avg. Accuracy: 0.9467, Avg. Dice: 0.8675: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8675, Loss:  0.1726
Epoch 27 ---------------------
Training Set -----


Avg. Loss: 0.0582, Avg. Accuracy: 0.9771, Avg. Dice: 0.9410: 100%|███████████████████| 519/519 [16:56<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1933, Avg. Accuracy: 0.9448, Avg. Dice: 0.8576: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8576, Loss:  0.1933
Epoch 28 ---------------------
Training Set -----


Avg. Loss: 0.0561, Avg. Accuracy: 0.9776, Avg. Dice: 0.9440: 100%|███████████████████| 519/519 [16:53<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1787, Avg. Accuracy: 0.9438, Avg. Dice: 0.8612: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8612, Loss:  0.1787
Epoch 29 ---------------------
Training Set -----


Avg. Loss: 0.0534, Avg. Accuracy: 0.9787, Avg. Dice: 0.9458: 100%|███████████████████| 519/519 [16:52<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1992, Avg. Accuracy: 0.9459, Avg. Dice: 0.8560: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8560, Loss:  0.1992
Epoch 30 ---------------------
Training Set -----


Avg. Loss: 0.0526, Avg. Accuracy: 0.9792, Avg. Dice: 0.9476: 100%|███████████████████| 519/519 [16:50<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.2136, Avg. Accuracy: 0.9424, Avg. Dice: 0.8602: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]

Dice:  0.8602, Loss:  0.2136
Test Set -----



Validation Dataset: 100%|████████████████████████████████████████████████████████████| 259/259 [01:27<00:00,  2.98it/s]


0.16854038856779724


Test Dataset: 100%|██████████████████████████████████████████████████████████████████| 259/259 [01:23<00:00,  3.10it/s]


0.8767459968290348
Epoch 1 ---------------------
Training Set -----


Avg. Loss: 0.3191, Avg. Accuracy: 0.9121, Avg. Dice: 0.7824: 100%|███████████████████| 519/519 [16:59<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.2428, Avg. Accuracy: 0.9349, Avg. Dice: 0.8412: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8412, Loss:  0.2428
Epoch 2 ---------------------
Training Set -----


Avg. Loss: 0.2381, Avg. Accuracy: 0.9291, Avg. Dice: 0.8200: 100%|███████████████████| 519/519 [16:54<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1883, Avg. Accuracy: 0.9389, Avg. Dice: 0.8403: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8403, Loss:  0.1883
Epoch 3 ---------------------
Training Set -----


Avg. Loss: 0.1946, Avg. Accuracy: 0.9394, Avg. Dice: 0.8452: 100%|███████████████████| 519/519 [16:55<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.2135, Avg. Accuracy: 0.9307, Avg. Dice: 0.8205: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8205, Loss:  0.2135
Epoch 4 ---------------------
Training Set -----


Avg. Loss: 0.1705, Avg. Accuracy: 0.9437, Avg. Dice: 0.8570: 100%|███████████████████| 519/519 [16:56<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.2166, Avg. Accuracy: 0.9172, Avg. Dice: 0.7765: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.7765, Loss:  0.2166
Epoch 5 ---------------------
Training Set -----


Avg. Loss: 0.1590, Avg. Accuracy: 0.9463, Avg. Dice: 0.8621: 100%|███████████████████| 519/519 [16:59<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1784, Avg. Accuracy: 0.9387, Avg. Dice: 0.8428: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8428, Loss:  0.1784
Epoch 6 ---------------------
Training Set -----


Avg. Loss: 0.1475, Avg. Accuracy: 0.9491, Avg. Dice: 0.8691: 100%|███████████████████| 519/519 [16:56<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.2314, Avg. Accuracy: 0.9174, Avg. Dice: 0.8048: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8048, Loss:  0.2314
Epoch 7 ---------------------
Training Set -----


Avg. Loss: 0.1330, Avg. Accuracy: 0.9538, Avg. Dice: 0.8810: 100%|███████████████████| 519/519 [16:55<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1914, Avg. Accuracy: 0.9433, Avg. Dice: 0.8597: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8597, Loss:  0.1914
Epoch 8 ---------------------
Training Set -----


Avg. Loss: 0.1297, Avg. Accuracy: 0.9541, Avg. Dice: 0.8843: 100%|███████████████████| 519/519 [17:03<00:00,  1.97s/it]


Validation Set -----


Avg. Loss: 0.1769, Avg. Accuracy: 0.9384, Avg. Dice: 0.8458: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8458, Loss:  0.1769
Epoch 9 ---------------------
Training Set -----


Avg. Loss: 0.1197, Avg. Accuracy: 0.9571, Avg. Dice: 0.8912: 100%|███████████████████| 519/519 [16:52<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1973, Avg. Accuracy: 0.9359, Avg. Dice: 0.8385: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8385, Loss:  0.1973
Epoch 10 ---------------------
Training Set -----


Avg. Loss: 0.1102, Avg. Accuracy: 0.9599, Avg. Dice: 0.8967: 100%|███████████████████| 519/519 [16:52<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1729, Avg. Accuracy: 0.9441, Avg. Dice: 0.8600: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8600, Loss:  0.1729
Epoch 11 ---------------------
Training Set -----


Avg. Loss: 0.1045, Avg. Accuracy: 0.9611, Avg. Dice: 0.9013: 100%|███████████████████| 519/519 [16:57<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.2108, Avg. Accuracy: 0.9302, Avg. Dice: 0.8348: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8348, Loss:  0.2108
Epoch 12 ---------------------
Training Set -----


Avg. Loss: 0.0974, Avg. Accuracy: 0.9630, Avg. Dice: 0.9062: 100%|███████████████████| 519/519 [16:54<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1669, Avg. Accuracy: 0.9434, Avg. Dice: 0.8650: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8650, Loss:  0.1669
Epoch 13 ---------------------
Training Set -----


Avg. Loss: 0.0978, Avg. Accuracy: 0.9631, Avg. Dice: 0.9042: 100%|███████████████████| 519/519 [16:54<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1656, Avg. Accuracy: 0.9454, Avg. Dice: 0.8618: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8618, Loss:  0.1656
Epoch 14 ---------------------
Training Set -----


Avg. Loss: 0.0853, Avg. Accuracy: 0.9672, Avg. Dice: 0.9169: 100%|███████████████████| 519/519 [17:01<00:00,  1.97s/it]


Validation Set -----


Avg. Loss: 0.1876, Avg. Accuracy: 0.9434, Avg. Dice: 0.8630: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8630, Loss:  0.1876
Epoch 15 ---------------------
Training Set -----


Avg. Loss: 0.0842, Avg. Accuracy: 0.9684, Avg. Dice: 0.9192: 100%|███████████████████| 519/519 [17:01<00:00,  1.97s/it]


Validation Set -----


Avg. Loss: 0.1702, Avg. Accuracy: 0.9436, Avg. Dice: 0.8570: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8570, Loss:  0.1702
Epoch 16 ---------------------
Training Set -----


Avg. Loss: 0.0893, Avg. Accuracy: 0.9658, Avg. Dice: 0.9136: 100%|███████████████████| 519/519 [16:57<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.2085, Avg. Accuracy: 0.9410, Avg. Dice: 0.8522: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8522, Loss:  0.2085
Epoch 17 ---------------------
Training Set -----


Avg. Loss: 0.0842, Avg. Accuracy: 0.9677, Avg. Dice: 0.9171: 100%|███████████████████| 519/519 [16:53<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1873, Avg. Accuracy: 0.9433, Avg. Dice: 0.8624: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8624, Loss:  0.1873
Epoch 18 ---------------------
Training Set -----


Avg. Loss: 0.0773, Avg. Accuracy: 0.9705, Avg. Dice: 0.9245: 100%|███████████████████| 519/519 [16:56<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1829, Avg. Accuracy: 0.9449, Avg. Dice: 0.8657: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8657, Loss:  0.1829
Epoch 19 ---------------------
Training Set -----


Avg. Loss: 0.0867, Avg. Accuracy: 0.9663, Avg. Dice: 0.9147: 100%|███████████████████| 519/519 [17:05<00:00,  1.98s/it]


Validation Set -----


Avg. Loss: 0.1936, Avg. Accuracy: 0.9453, Avg. Dice: 0.8684: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8684, Loss:  0.1936
Epoch 20 ---------------------
Training Set -----


Avg. Loss: 0.0714, Avg. Accuracy: 0.9719, Avg. Dice: 0.9287: 100%|███████████████████| 519/519 [17:02<00:00,  1.97s/it]


Validation Set -----


Avg. Loss: 0.2131, Avg. Accuracy: 0.9412, Avg. Dice: 0.8486: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8486, Loss:  0.2131
Epoch 21 ---------------------
Training Set -----


Avg. Loss: 0.0719, Avg. Accuracy: 0.9720, Avg. Dice: 0.9276: 100%|███████████████████| 519/519 [16:56<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.2029, Avg. Accuracy: 0.9453, Avg. Dice: 0.8656: 100%|█████████████████████| 65/65 [01:31<00:00,  1.40s/it]


Dice:  0.8656, Loss:  0.2029
Epoch 22 ---------------------
Training Set -----


Avg. Loss: 0.0642, Avg. Accuracy: 0.9749, Avg. Dice: 0.9357: 100%|███████████████████| 519/519 [16:56<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1890, Avg. Accuracy: 0.9453, Avg. Dice: 0.8676: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8676, Loss:  0.1890
Epoch 23 ---------------------
Training Set -----


Avg. Loss: 0.0746, Avg. Accuracy: 0.9714, Avg. Dice: 0.9280: 100%|███████████████████| 519/519 [16:56<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.1965, Avg. Accuracy: 0.9464, Avg. Dice: 0.8661: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8661, Loss:  0.1965
Epoch 24 ---------------------
Training Set -----


Avg. Loss: 0.0662, Avg. Accuracy: 0.9741, Avg. Dice: 0.9348: 100%|███████████████████| 519/519 [16:54<00:00,  1.95s/it]


Validation Set -----


Avg. Loss: 0.1853, Avg. Accuracy: 0.9408, Avg. Dice: 0.8622: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8622, Loss:  0.1853
Epoch 25 ---------------------
Training Set -----


Avg. Loss: 0.0644, Avg. Accuracy: 0.9747, Avg. Dice: 0.9358: 100%|███████████████████| 519/519 [17:00<00:00,  1.97s/it]


Validation Set -----


Avg. Loss: 0.1947, Avg. Accuracy: 0.9452, Avg. Dice: 0.8658: 100%|█████████████████████| 65/65 [01:30<00:00,  1.40s/it]


Dice:  0.8658, Loss:  0.1947
Epoch 26 ---------------------
Training Set -----


Avg. Loss: 0.0598, Avg. Accuracy: 0.9764, Avg. Dice: 0.9396: 100%|███████████████████| 519/519 [17:00<00:00,  1.97s/it]


Validation Set -----


Avg. Loss: 0.1849, Avg. Accuracy: 0.9448, Avg. Dice: 0.8713: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8713, Loss:  0.1849
Epoch 27 ---------------------
Training Set -----


Avg. Loss: 0.0561, Avg. Accuracy: 0.9777, Avg. Dice: 0.9431: 100%|███████████████████| 519/519 [16:59<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.2276, Avg. Accuracy: 0.9434, Avg. Dice: 0.8650: 100%|█████████████████████| 65/65 [01:31<00:00,  1.42s/it]


Dice:  0.8650, Loss:  0.2276
Epoch 28 ---------------------
Training Set -----


Avg. Loss: 0.0596, Avg. Accuracy: 0.9764, Avg. Dice: 0.9398: 100%|███████████████████| 519/519 [17:01<00:00,  1.97s/it]


Validation Set -----


Avg. Loss: 0.2540, Avg. Accuracy: 0.9369, Avg. Dice: 0.8331: 100%|█████████████████████| 65/65 [01:31<00:00,  1.42s/it]


Dice:  0.8331, Loss:  0.2540
Epoch 29 ---------------------
Training Set -----


Avg. Loss: 0.0563, Avg. Accuracy: 0.9776, Avg. Dice: 0.9430: 100%|███████████████████| 519/519 [17:06<00:00,  1.98s/it]


Validation Set -----


Avg. Loss: 0.2071, Avg. Accuracy: 0.9469, Avg. Dice: 0.8727: 100%|█████████████████████| 65/65 [01:32<00:00,  1.42s/it]


Dice:  0.8727, Loss:  0.2071
Epoch 30 ---------------------
Training Set -----


Avg. Loss: 0.0525, Avg. Accuracy: 0.9791, Avg. Dice: 0.9469: 100%|███████████████████| 519/519 [16:58<00:00,  1.96s/it]


Validation Set -----


Avg. Loss: 0.2315, Avg. Accuracy: 0.9429, Avg. Dice: 0.8523: 100%|█████████████████████| 65/65 [01:31<00:00,  1.41s/it]


Dice:  0.8523, Loss:  0.2315
Test Set -----


Validation Dataset: 100%|████████████████████████████████████████████████████████████| 259/259 [01:26<00:00,  2.99it/s]


0.1598851496285095


Test Dataset: 100%|██████████████████████████████████████████████████████████████████| 259/259 [01:24<00:00,  3.08it/s]

0.8779404137439657
FINAL RESULTS:
POLAR=0.8777+/-0.0012
POLAR=0.1688+/-0.0052



