### 3D U-Net


In [2]:
# import libraries
import os
import random
import time
import cv2
import shutil
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
try:
    from itertools import  ifilterfalse
except ImportError: # py3k
    from itertools import  filterfalse as ifilterfalse

import torchvision.transforms as transforms
from dataprepare3D import get_data
from dataprepare3D import load_labels




Change parameters here.

In [None]:
# Directory that get_data lists to find the image volumes.
base_image_path = "../Training2/"
# Directory holding the label volumes; get_data looks for
# "<image name without extension>-label.nii" in here.
base_label_path = "../Label2/"
# In-plane target sizes handed to cv2.resize in data_prepare.
# get_data overwrites `depth` and `height` through `global` statements.
depth = 240
height = 240
# NOTE(review): data_prepare assigns its own local `width`, so this
# module-level value does not appear to be read there - confirm before relying on it.
width = 240
# Number of consecutive slices grouped into one training sample
# (also overwritten by get_data).
slices = 8
# When True, data_prepare additionally resamples along the slice axis so the
# slice count becomes a multiple of `slices`.
use_resize_2 = True

## Prepare data:
- take every 8 slices as one training sample

In [7]:
def data_prepare(path, is_label_data):
    """Read one volume from disk and cut it into stacks of `slices` slices.

    Steps, in order:
      1. Read the file with ``sitk.ReadImage`` and convert it to an array.
      2. Resize every slice (first axis) to ``height`` rows x ``depth``
         columns with cv2.
      3. If the module-level ``use_resize_2`` flag is set, resample along
         the first axis so that the slice count becomes the next multiple
         of ``slices``.
      4. Regroup the volume into chunks of ``slices`` consecutive slices.

    Args:
        path: path of the volume file to read.
        is_label_data: when truthy, nearest-neighbour interpolation is used
            for every resize; otherwise cubic interpolation is used.

    Returns:
        A nested Python list (``ndarray.tolist()``) with shape
        ``(chunks, slices, height, depth)``.

    NOTE(review): ``sitk`` is not imported in the visible cells - presumably
    ``import SimpleITK as sitk`` is expected; confirm.
    """
    heart = sitk.ReadImage(path)
    heartArray = sitk.GetArrayFromImage(heart)
    # print(heartArray.shape)

    # Step 2 target buffer: one (height, depth) image per input slice.
    img_stack_sm = np.zeros((len(heartArray), height, depth))
    # Slice count rounded UP to a multiple of `slices`. This is a local name;
    # it shadows (and never updates) the module-level `width`.
    width = ((heartArray.shape[0]+slices-1)//slices)*slices

    for idx in range(len(heartArray)):
        img = heartArray[idx, :, :]
        # cv2.resize takes dsize as (columns, rows), hence (depth, height).
        if is_label_data:
            img_sm = cv2.resize(img, (depth, height), interpolation=cv2.INTER_NEAREST)
        else:
            img_sm = cv2.resize(img, (depth, height), interpolation=cv2.INTER_CUBIC)
        img_stack_sm[idx, :, :] = img_sm

    if (use_resize_2):
        # Step 3: resample along the slice axis, one fixed-row plane at a time.
        img_stack_sm2 = np.zeros((width, height, depth))

        for idx in range(height):
            img = img_stack_sm[:, idx, :]
            if is_label_data:
                img_sm = cv2.resize(img, (depth, width), interpolation=cv2.INTER_NEAREST)

            else:
                img_sm = cv2.resize(img, (depth, width), interpolation=cv2.INTER_CUBIC)
            img_stack_sm2[:, idx, :] = img_sm
        img_stack_sm = img_stack_sm2

    # print(img_stack_sm.shape)
    # In-place ndarray.resize, not reshape: the element count may change.
    # NOTE(review): with use_resize_2 False and a slice count that is not a
    # multiple of `slices`, the requested shape holds fewer elements than the
    # array, so trailing data would presumably be dropped - confirm intended.
    img_stack_sm.resize((img_stack_sm.shape[0]//slices,slices,img_stack_sm.shape[1],img_stack_sm.shape[2]))
    return img_stack_sm.tolist()


def get_data(mini_dim, dim1, dim2):
    """Load every image/label pair and return them as stacked arrays.

    The module-level ``slices``, ``height`` and ``depth`` settings are
    overwritten with ``mini_dim``, ``dim1`` and ``dim2`` before any file is
    read, because ``data_prepare`` reads them as globals.

    Args:
        mini_dim: number of slices per training sample.
        dim1: target height of every slice.
        dim2: target depth (second in-plane size) of every slice.

    Returns:
        A tuple ``(images, labels, boundaries)``:
          * images: float32 array with a singleton channel axis inserted at
            position 1,
          * labels: array of the label chunks,
          * boundaries: one ``(cumulative_chunk_count, patient_digit)`` tuple
            per file, in directory-listing order. The digit is the character
            just before the 4-character extension of the image path.
    """
    global depth, width, height, slices

    slices, height, depth = mini_dim, dim1, dim2

    image_chunks = []
    label_chunks = []
    boundaries = []

    # Hidden entries (names starting with a dot) are skipped.
    visible_names = (name for name in os.listdir(base_image_path) if name[0] != '.')
    for name in visible_names:
        image_path = base_image_path + name
        label_path = base_label_path + name[:-4] + "-label.nii"
        print(image_path)
        print(label_path)

        image_stack = np.array(data_prepare(image_path, False))
        image_chunks.extend(np.expand_dims(image_stack, axis=1).tolist())
        label_chunks.extend(data_prepare(label_path, True))

        patient_digit = int(image_path[-5])
        boundaries.append((len(label_chunks), patient_digit))

    images = np.array(image_chunks).astype(np.float32)
    labels = np.array(label_chunks)
    return images, labels, boundaries


def load_labels():
    """Read every label volume at its original resolution.

    Returns:
        A 10-element list. Entry ``i`` holds the array of the label file
        whose 11th character from the end is the digit ``i``; entries with
        no matching file keep the placeholder ``0``.
    """
    per_patient = [0] * 10
    for fname in os.listdir(base_label_path):
        # Skip hidden entries.
        if fname[0] != '.':
            print(fname)
            volume = sitk.ReadImage(base_label_path + fname)
            per_patient[int(fname[-11])] = sitk.GetArrayFromImage(volume)
    return per_patient

## Define unit convolutions that will be used in the 3D U-Net model.

In [15]:
def conv_block_3d(in_dim,out_dim,act_fn):
    """Build a size-preserving 3x3x3 convolution unit.

    The unit is ``Conv3d(kernel 3, stride 1, padding 1)`` followed by
    ``BatchNorm3d`` and the supplied activation module.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        act_fn: activation module appended as the last layer.

    Returns:
        An ``nn.Sequential`` with the three layers in that order.
    """
    layers = [
        nn.Conv3d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm3d(out_dim),
        act_fn,
    ]
    return nn.Sequential(*layers)


def conv_trans_block_3d(in_dim,out_dim,act_fn):
    """Build an upsampling unit based on a transposed 3D convolution.

    The unit is ``ConvTranspose3d(kernel 3, stride 2, padding 1,
    output_padding 1)`` followed by ``BatchNorm3d`` and the supplied
    activation module.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        act_fn: activation module appended as the last layer.

    Returns:
        An ``nn.Sequential`` with the three layers in that order.
    """
    upsample = nn.ConvTranspose3d(
        in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1
    )
    normalise = nn.BatchNorm3d(out_dim)
    return nn.Sequential(upsample, normalise, act_fn)


def maxpool_3d():
    """Return a 2x2x2 max-pooling layer with stride 2 and no padding."""
    return nn.MaxPool3d(kernel_size=2, stride=2, padding=0)


def conv_block_2_3d(in_dim,out_dim,act_fn):
    """Build a double-convolution unit that widens channels in two steps.

    The first stage is a full ``conv_block_3d`` mapping ``in_dim`` to
    ``out_dim // 2`` channels; the second stage is a 3x3x3 convolution to
    ``out_dim`` channels followed by batch normalisation. No activation is
    applied after the second stage.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        act_fn: activation module used inside the first stage.

    Returns:
        An ``nn.Sequential`` of (nested first stage, Conv3d, BatchNorm3d).
    """
    half_width = out_dim // 2
    first_stage = conv_block_3d(in_dim, half_width, act_fn)
    second_conv = nn.Conv3d(half_width, out_dim, kernel_size=3, stride=1, padding=1)
    second_norm = nn.BatchNorm3d(out_dim)
    return nn.Sequential(first_stage, second_conv, second_norm)


def conv_block_3_3d(in_dim,out_dim,act_fn):
    """Build a double-convolution unit that keeps ``out_dim`` channels.

    The first stage is a full ``conv_block_3d`` mapping ``in_dim`` to
    ``out_dim`` channels; the second stage is a 3x3x3 convolution with
    ``out_dim`` channels in and out followed by batch normalisation. No
    activation is applied after the second stage.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        act_fn: activation module used inside the first stage.

    Returns:
        An ``nn.Sequential`` of (nested first stage, Conv3d, BatchNorm3d).
    """
    first_stage = conv_block_3d(in_dim, out_dim, act_fn)
    second_conv = nn.Conv3d(out_dim, out_dim, kernel_size=3, stride=1, padding=1)
    second_norm = nn.BatchNorm3d(out_dim)
    return nn.Sequential(first_stage, second_conv, second_norm)

def conv_block_4_3d(in_dim,out_dim,act_fn):
    """Build a pointwise (1x1x1) convolution unit.

    The unit is ``Conv3d(kernel 1, stride 1, padding 0)`` followed by
    ``BatchNorm3d`` and the supplied activation module.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        act_fn: activation module appended as the last layer.

    Returns:
        An ``nn.Sequential`` with the three layers in that order.
    """
    layers = [
        nn.Conv3d(in_dim, out_dim, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm3d(out_dim),
        act_fn,
    ]
    return nn.Sequential(*layers)

## Difference between our U-Net and original paper

   
    UNet class is based on https://arxiv.org/abs/1505.04597 with 
    all the convolutions converted to 3D convolutions.
    
    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).
    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by upmode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens inside the transpose convolution
        itself when upmode='transpose' is specified.

In [20]:
class UnetGenerator_3d(nn.Module):
    """Three-level 3D U-Net with two alternative output heads.

    Encoder: three ``conv_block_2_3d`` stages, each followed by 2x2x2 max
    pooling. A bridge stage widens the features to ``8 * num_filter``
    channels. Decoder: three transposed-convolution upsamplings, each
    concatenated (along the channel axis) with the encoder output of the same
    level and reduced by a ``conv_block_3_3d`` stage.

    Two 1x1x1 heads with separate weights are created:
      * ``out``: ends in ``LogSoftmax`` (for ``nn.NLLLoss``),
      * ``out_lovasz``: ends in ``Softmax`` (for the Lovasz-Softmax loss).

    Args:
        in_dim: number of input channels.
        out_dim: number of output classes.
        num_filter: channel width of the first encoder level.
        use_lovasz: optional bool choosing the head used by ``forward``.
            When left as ``None`` (the default) the choice is read from the
            module-level ``args`` object on every forward pass, as before.
    """

    def __init__(self, in_dim, out_dim, num_filter, use_lovasz=None):
        super(UnetGenerator_3d, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_filter = num_filter
        self.use_lovasz = use_lovasz
        # A single ReLU module is shared by all blocks; it holds no state.
        act_fn = nn.ReLU()

        print("\n------Initiating U-Net------\n")

        self.down_1 = conv_block_2_3d(self.in_dim, self.num_filter, act_fn)
        self.pool_1 = maxpool_3d()
        self.down_2 = conv_block_2_3d(self.num_filter, self.num_filter * 2, act_fn)
        self.pool_2 = maxpool_3d()
        self.down_3 = conv_block_2_3d(self.num_filter * 2, self.num_filter * 4, act_fn)
        self.pool_3 = maxpool_3d()

        self.bridge = conv_block_2_3d(self.num_filter * 4, self.num_filter * 8, act_fn)

        # Each up_* input width is the upsampled channels plus the skip
        # connection channels, e.g. 8f + 4f = 12f for up_1.
        self.trans_1 = conv_trans_block_3d(self.num_filter * 8, self.num_filter * 8, act_fn)
        self.up_1 = conv_block_3_3d(self.num_filter * 12, self.num_filter * 4, act_fn)
        self.trans_2 = conv_trans_block_3d(self.num_filter * 4, self.num_filter * 4, act_fn)
        self.up_2 = conv_block_3_3d(self.num_filter * 6, self.num_filter * 2, act_fn)
        self.trans_3 = conv_trans_block_3d(self.num_filter * 2, self.num_filter * 2, act_fn)
        self.up_3 = conv_block_3_3d(self.num_filter * 3, self.num_filter * 1, act_fn)

        # dim=1 is the channel axis of the (N, C, D, H, W) output; stating it
        # explicitly avoids the implicit-dimension deprecation warning.
        self.out = conv_block_4_3d(self.num_filter, out_dim, nn.LogSoftmax(dim=1))
        self.out_lovasz = conv_block_4_3d(self.num_filter, out_dim, nn.Softmax(dim=1))
        self.reset_params()

    @staticmethod
    def weight_init(m):
        """Xavier-normal init for 3D (transposed) conv weights, zero biases."""
        if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
            # The underscore-suffixed functions are the in-place, non-deprecated API.
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def reset_params(self):
        """Apply ``weight_init`` to every sub-module."""
        for m in self.modules():
            self.weight_init(m)

    def _lovasz_enabled(self):
        """Decide which output head ``forward`` should use."""
        if self.use_lovasz is not None:
            return bool(self.use_lovasz)
        # Fall back to the notebook-level settings. Both a plain dict and an
        # attribute-style namespace are accepted.
        if isinstance(args, dict):
            return bool(args.get("use_lovasz", args.get("use-lovasz", False)))
        return bool(args.use_lovasz)

    def forward(self, x):
        """Run the network on a batch shaped (N, in_dim, D, H, W).

        Returns class probabilities (Lovasz head) or log-probabilities
        (default head) with ``out_dim`` channels.
        """
        down_1 = self.down_1(x)
        pool_1 = self.pool_1(down_1)
        down_2 = self.down_2(pool_1)
        pool_2 = self.pool_2(down_2)
        down_3 = self.down_3(pool_2)
        pool_3 = self.pool_3(down_3)

        bridge = self.bridge(pool_3)

        trans_1 = self.trans_1(bridge)
        concat_1 = torch.cat([trans_1, down_3], dim=1)
        up_1 = self.up_1(concat_1)
        trans_2 = self.trans_2(up_1)
        concat_2 = torch.cat([trans_2, down_2], dim=1)
        up_2 = self.up_2(concat_2)
        trans_3 = self.trans_3(up_2)
        concat_3 = torch.cat([trans_3, down_1], dim=1)
        up_3 = self.up_3(concat_3)

        head = self.out_lovasz if self._lovasz_enabled() else self.out
        return head(up_3)

Define the dataset that will be passed to the dataloader to generate each training sample.

In [8]:
class MyCustomDataset(Dataset):
    """Leave-one-subject-out view over the module-level image/label arrays.

    ``heart_index[i][0]`` is read as the cumulative sample count after
    subject ``i``, so subject ``dev_heart`` owns the samples between the
    previous boundary (or 0 for the first subject) and its own boundary.

    Args:
        type: ``'Train'`` selects every sample except the held-out subject;
            any other value selects only the held-out subject.
        dev_heart: position of the held-out subject in ``heart_index``.
    """

    def __init__(self, type, dev_heart):
        start = 0 if dev_heart == 0 else heart_index[dev_heart-1][0]
        stop = heart_index[dev_heart][0]

        if type == 'Train':
            # Everything before and after the held-out range.
            image_parts = (total_image[:start, :, :, :, :], total_image[stop:, :, :, :, :])
            label_parts = (total_label[:start, :, :, :], total_label[stop:, :, :, :])
            self.image = np.concatenate(image_parts)
            self.label = np.concatenate(label_parts)
        else:
            held_out = slice(start, stop)
            self.image = total_image[held_out, :, :, :, :]
            self.label = total_label[held_out, :, :, :]

        print(self.image.shape)
        print(self.label.shape)

    def __len__(self):
        """Number of samples in this split."""
        return len(self.image)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at position ``idx``."""
        return (self.image[idx], self.label[idx])

def get_loss(dl, model):
    """Return the mean per-batch loss of ``model`` over the loader ``dl``.

    The loss matches the one used for training: Lovasz-Softmax when
    ``args.use_lovasz`` is set, ``nn.NLLLoss`` otherwise.

    Fixes over the previous version:
      * the Lovasz branch evaluated the names ``data`` and ``target``, which
        are not defined in this function, instead of the current batch;
      * batches are moved to the device the model lives on rather than
        unconditionally to CUDA;
      * evaluation runs under ``torch.no_grad()`` so no graph is kept.

    Args:
        dl: iterable of ``(X, y)`` batches with a usable ``len()``.
        model: the network to evaluate.

    Returns:
        Sum of the per-batch losses divided by ``len(dl)``.

    Raises:
        ZeroDivisionError: if ``dl`` is empty.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    criterion = nn.NLLLoss()

    total = 0.0
    with torch.no_grad():
        for X, y in dl:
            X, y = X.to(device), y.to(device)
            output = model(X)
            if args.use_lovasz:
                vprobas, vlabels = flatten_probas(output, y.long())
                total += lovasz_softmax_flat(vprobas, vlabels).item()
            else:
                total += criterion(output, y.long()).item()
    return total / len(dl)


Define binary dice score and binary Jaccard score to evaluate accuracy
- The binary Dice/Jaccard scores use a one-vs-all strategy. For example, when scoring the blood pool (label = 2), both the ventricular myocardium and the background are treated as 0.

In [1]:
'''
main_class: the class that you want to predict as one, must be a single value
'''


def binary_vector(x, main_class):
    """Return a one-vs-all indicator vector for ``main_class``.

    Args:
        x: array (or plain sequence) of labels.
        main_class: the single label value that should map to 1.

    Returns:
        An integer array with the shape of ``x`` holding 1 where
        ``x == main_class`` and 0 elsewhere.
    """
    # The former version allocated a zeros array and immediately discarded
    # it; the comparison alone produces the result.
    return (np.asarray(x) == main_class).astype(int)


'''
classes: a list of labels that we want for binary comparison
e.g. [1, 2] will return a list of two scores. The first index
is the score of regarding 1 as 1 and 2 as 0. The second is the
score of regarding 2 as 1 and 1 as 0. 
*WARNING: label and target must be of the same dimension. 
'''


def binary_dice_score(label, target, classes):
    """Compute a smoothed one-vs-all Dice score for each requested class.

    For every class the two inputs are flattened and binarised with
    ``binary_vector``; the score is
    ``(2 * |A and B| + 1) / (|A| + |B| + 1)``.

    Args:
        label: array of reference labels.
        target: array of predicted labels, same number of elements.
        classes: iterable of label values to score.

    Returns:
        A list with one score per entry of ``classes``, in the same order.
    """
    smooth = 1

    def _dice_for(cls):
        label_mask = binary_vector(label.flatten(), cls)
        target_mask = binary_vector(target.flatten(), cls)
        overlap = np.sum(label_mask * target_mask)
        combined = np.sum(label_mask + target_mask)
        numerator = (2. * overlap + smooth).sum()
        denominator = (combined + smooth).sum()
        return numerator / denominator

    return [_dice_for(cls) for cls in classes]

def get_dice_score(dl, model):
    """Score the model's prediction for the held-out subject with Dice.

    The prediction is brought back to the resolution of the original label
    volume before scoring: every predicted slice is first resized in-plane,
    then every fixed-row plane is resized along the slice axis, both with
    nearest-neighbour interpolation.

    Reads the module-level ``label_original``, ``heart_index`` and
    ``dev_heart`` to find the reference volume.

    Args:
        dl: iterable of ``(X, y)`` batches; ``y`` is not used.
        model: the network to evaluate.

    Returns:
        The list returned by ``binary_dice_score`` for classes 1 and 2,
        computed from the LAST batch of ``dl`` (earlier batches are
        overwritten), or 0 when ``dl`` yields nothing.

    NOTE(review): inputs are sent to CUDA unconditionally, so this presumably
    fails on a CPU-only machine - confirm.
    """

    score = 0

    for X, y in dl:
        X = Variable(X).cuda()
        output = model(X).cpu()
        # Class index with the highest score along axis 1 (channels).
        predicted = np.argmax(output.data.numpy(),axis=1)
        # Merge the first two axes in place so the result is one long stack
        # of 2D slices.
        predicted.resize((predicted.shape[0]*predicted.shape[1],predicted.shape[2],predicted.shape[3]))

        # Pass 1: resize every slice to the reference in-plane size.
        # cv2.resize takes dsize as (columns, rows) = (shape[2], shape[1]).
        predicted_origin = [0]*predicted.shape[0]
        for idx in range(len(predicted)):
            img = predicted[idx, :, :]
            img_sm = cv2.resize(img, (label_original[heart_index[dev_heart][1]].shape[2], label_original[heart_index[dev_heart][1]].shape[1]), interpolation=cv2.INTER_NEAREST)
            predicted_origin[idx] = img_sm

        # Pass 2: for each row index, resize the (slice, column) plane so the
        # slice axis matches the reference slice count (shape[0]).
        predicted_origin = np.array(predicted_origin)
        predicted_origin2 = np.zeros((label_original[heart_index[dev_heart][1]].shape[0], label_original[heart_index[dev_heart][1]].shape[1], label_original[heart_index[dev_heart][1]].shape[2]))
        for idx in range(label_original[heart_index[dev_heart][1]].shape[1]):
            img = predicted_origin[:, idx, :]
            # Original author's note: the roles of shape[2] and shape[0] in
            # this call were flagged as possibly confused - needs re-checking.
            img_sm = cv2.resize(img, (label_original[heart_index[dev_heart][1]].shape[2], label_original[heart_index[dev_heart][1]].shape[0]), interpolation=cv2.INTER_NEAREST)
            predicted_origin2[:, idx, :] = img_sm

        ground_truth = label_original[heart_index[dev_heart][1]].astype("int64")
        score = binary_dice_score(ground_truth,predicted_origin2.astype("int64"), [1,2])
    return score


def binary_jaccard_index(label, target, classes):
    """Compute the one-vs-all Jaccard index (IoU) for each requested class.

    For a class ``c`` the score is ``|A and B| / |A or B|`` where ``A`` and
    ``B`` are the sets of positions labelled ``c`` in ``label`` and
    ``target``.

    The previous implementation divided the number of agreeing positions by
    (disagreeing positions + total length). That counts true negatives and is
    a function of plain accuracy, not the Jaccard index.

    Args:
        label: array (or sequence) of reference labels.
        target: array (or sequence) of predicted labels with the same number
            of elements.
        classes: iterable of label values to score.

    Returns:
        A list with one score per entry of ``classes``, in the same order.
        A class that is absent from both inputs scores 1.0 (the two empty
        sets agree perfectly).

    Raises:
        AssertionError: if the two inputs differ in element count.
    """
    flat_label = np.asarray(label).ravel()
    flat_target = np.asarray(target).ravel()
    assert (len(flat_label) == len(flat_target))

    scores = []
    for cls in classes:
        label_mask = flat_label == cls
        target_mask = flat_target == cls

        union = np.logical_or(label_mask, target_mask).sum()
        if union == 0:
            # Avoid 0/0 when the class occurs in neither input.
            scores.append(1.0)
            continue
        intersection = np.logical_and(label_mask, target_mask).sum()
        scores.append(intersection / union)
    return scores


def get_jaccard_score(dl, model):
    """Score the model's prediction for the held-out subject with Jaccard.

    The prediction is brought back to the resolution of the original label
    volume before scoring: every predicted slice is first resized in-plane,
    then every fixed-row plane is resized along the slice axis, both with
    nearest-neighbour interpolation.

    Reads the module-level ``label_original``, ``heart_index`` and
    ``dev_heart`` to find the reference volume.

    Fixes over the previous version:
      * it called ``jaccard_index``, which is not defined; it now calls
        ``binary_jaccard_index`` for classes 1 and 2, mirroring
        ``get_dice_score``;
      * inputs go to the device the model lives on instead of always CUDA;
      * the reference volume is looked up once instead of on every use.

    Args:
        dl: iterable of ``(X, y)`` batches; ``y`` is not used.
        model: the network to evaluate.

    Returns:
        The list returned by ``binary_jaccard_index`` for classes 1 and 2,
        computed from the LAST batch of ``dl`` (earlier batches are
        overwritten), or 0 when ``dl`` yields nothing.
    """
    score = 0

    reference = label_original[heart_index[dev_heart][1]]
    ref_slices, ref_rows, ref_cols = reference.shape[0], reference.shape[1], reference.shape[2]
    ground_truth = reference.astype("int64")

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        for X, y in dl:
            output = model(X.to(device)).cpu()
            # Class index with the highest score along axis 1 (channels).
            predicted = np.argmax(output.numpy(), axis=1)
            # Merge the first two axes: one long stack of 2D slices.
            predicted = predicted.reshape(
                predicted.shape[0] * predicted.shape[1],
                predicted.shape[2],
                predicted.shape[3],
            )

            # Pass 1: resize every slice to the reference in-plane size.
            # cv2.resize takes dsize as (columns, rows).
            predicted_origin = np.array([
                cv2.resize(img, (ref_cols, ref_rows), interpolation=cv2.INTER_NEAREST)
                for img in predicted
            ])

            # Pass 2: for each row index, resize the (slice, column) plane so
            # the slice axis matches the reference slice count.
            predicted_origin2 = np.zeros((ref_slices, ref_rows, ref_cols))
            for idx in range(ref_rows):
                img = predicted_origin[:, idx, :]
                predicted_origin2[:, idx, :] = cv2.resize(
                    img, (ref_cols, ref_slices), interpolation=cv2.INTER_NEAREST
                )

            score = binary_jaccard_index(ground_truth, predicted_origin2.astype("int64"), [1, 2])

    return score

Lovasz softmax functions
- source:https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py

In [None]:
def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension with respect to sorted errors
    (Alg. 1 in the Lovasz-Softmax paper).

    gt_sorted: 1-D tensor of ground-truth indicators, ordered by
               decreasing error.
    Returns a tensor of the same length holding the successive differences
    of the Jaccard loss evaluated on growing prefixes.
    """
    count = len(gt_sorted)
    positives = gt_sorted.sum()
    running_hits = gt_sorted.float().cumsum(0)
    running_misses = (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - (positives - running_hits) / (positives + running_misses)
    if count > 1:  # a single element has no predecessor to subtract
        jaccard[1:count] = jaccard[1:count] - jaccard[0:-1]
    return jaccard

def isnan(x):
    """Return ``x != x``, which is true for a float NaN.

    NaN is the only float value that compares unequal to itself. For
    array-like or tensor inputs the comparison is whatever ``!=`` yields for
    that type.
    """
    return x != x

def mean(l, ignore_nan=False, empty=0):
    """
    Arithmetic mean that also works on generators.

    l: any iterable of values supporting ``+=`` and ``/``.
    ignore_nan: when true, values for which ``isnan`` is true are skipped.
    empty: value returned for an empty input; the string 'raise' makes an
           empty input raise ValueError instead.
    A single value is returned as is, without dividing.
    """
    values = iter(l)
    if ignore_nan:
        values = ifilterfalse(isnan, values)

    try:
        total = next(values)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty

    count = 1
    for item in values:
        total += item
        count += 1

    return total if count == 1 else total / count



def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] tensor, class probabilities at each prediction (between 0 and 1)
      labels: [P] tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.

    Returns the mean of the per-class losses (see ``mean``).
    Raises ValueError when a single-channel input is combined with more than
    one class to sum over.

    Changes over the upstream version this was copied from:
      * ``classes`` is compared with ``==`` instead of ``is``; identity
        comparison with a string literal only works by interning accident;
      * the single-channel check counts the classes actually iterated, so
        the default ``classes='present'`` no longer raises for C == 1
        (``len('present')`` used to be measured);
      * the deprecated ``Variable`` wrappers are gone.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground indicator for class c
        if classes == 'present' and fg.sum() == 0:
            continue
        if C == 1:
            if len(class_to_sum) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (fg - class_pred).abs()
        # Sort by decreasing error and reorder the indicators to match.
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
    return mean(losses)


def flatten_probas(probas, labels, ignore=None):
    """
    Flatten a batch of volumetric predictions to one row per voxel.

    probas: [B, C, D, H, W] tensor of per-class scores.
    labels: tensor of labels, viewed as a flat [B * D * H * W] vector.
    ignore: accepted for interface compatibility; not used.
    Returns (probas as [P, C], labels as [P]) with P = B * D * H * W.
    """
    num_classes = probas.size(1)
    # Move the class axis last so each voxel's scores become one row.
    channels_last = probas.permute(0, 2, 3, 4, 1).contiguous()
    flat_probas = channels_last.view(-1, num_classes)
    flat_labels = labels.view(-1)
    return flat_probas, flat_labels


def get_accuracy(dl, model):
    """Return the fraction of voxels whose arg-max class equals the label.

    Inputs are moved to the device the model's parameters live on (CPU when
    the model has no parameters). The previous version always called
    ``.cuda()``, which fails on CPU-only machines even though the training
    loop selects its device dynamically. Evaluation runs under
    ``torch.no_grad()``.

    Args:
        dl: iterable of ``(X, y)`` batches; the model output for ``X`` must
            have the class axis at position 1 and ``y`` the matching shape
            without that axis.
        model: the network to evaluate.

    Returns:
        ``correct / total`` as a float.

    Raises:
        ZeroDivisionError: if ``dl`` yields no elements.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_num = 0
    correct_num = 0

    with torch.no_grad():
        for X, y in dl:
            output = model(X.to(device)).cpu()
            predicted = np.argmax(output.numpy(), axis=1)
            truth = y.detach().cpu().numpy().astype("int64")
            correct_num += int((predicted == truth).sum())
            total_num += truth.size
    return correct_num/total_num



Define parameters that will be used in the training routine

In [11]:
# Run settings. The rest of the notebook reads them as attributes
# (args.batch_size, args.use_lovasz, ...), which a plain dict does not
# support, so they are stored in an argparse.Namespace. The former hyphenated
# keys are spelled with underscores and the misspelled "load_mode" key is now
# "load_model".
args = argparse.Namespace(
    batch_size=4,
    test_batch_size=1000,
    slices_depth=8,      # slices per training sample, passed to get_data
    epochs=50,
    figuresize1=200,     # target slice height, passed to get_data
    figuresize2=160,     # target slice depth, passed to get_data
    lr=0.001,
    seed=1,              # NOTE(review): not applied anywhere in the visible cells
    channel_base=8,      # channel width of the first U-Net level
    log_interval=1,      # evaluate every this many epochs
    save_model=True,
    use_lovasz=False,    # True: Lovasz-Softmax loss, False: NLL loss
    test_model="",
    load_model=None,     # path of a checkpoint to warm-start from, or None
)

# Label volumes at their original resolution, used to score predictions.
label_original = load_labels()
total_image, total_label, heart_index = get_data(args.slices_depth, args.figuresize1, args.figuresize2)

# Position (in heart_index) of the first subject to hold out.
dev_heart = 7

# All checkpoints of this run go below "<timestamp>model".
timeStr = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
os.mkdir(timeStr + "model")

../../Training2/training_axial_crop_pat4.nii
../../Label2/training_axial_crop_pat4-label.nii
../../Training2/training_axial_crop_pat3.nii
../../Label2/training_axial_crop_pat3-label.nii
../../Training2/training_axial_crop_pat2.nii
../../Label2/training_axial_crop_pat2-label.nii
../../Training2/training_axial_crop_pat5.nii
../../Label2/training_axial_crop_pat5-label.nii
../../Training2/training_axial_crop_pat6.nii
../../Label2/training_axial_crop_pat6-label.nii
../../Training2/training_axial_crop_pat1.nii
../../Label2/training_axial_crop_pat1-label.nii
../../Training2/training_axial_crop_pat8.nii
../../Label2/training_axial_crop_pat8-label.nii
../../Training2/training_axial_crop_pat9.nii
../../Label2/training_axial_crop_pat9-label.nii
../../Training2/training_axial_crop_pat0.nii
../../Label2/training_axial_crop_pat0-label.nii
../../Training2/training_axial_crop_pat7.nii
../../Label2/training_axial_crop_pat7-label.nii


The training routine is defined below. Here are a few notes:
1. An entire subject is left out for cross validation


In [50]:


# Leave-one-subject-out training: for every remaining subject position a
# fresh model is trained on all other subjects and evaluated on the held-out
# one.
#
# Changes over the previous version:
#   * a checkpoint is only loaded when args.load_model is a non-empty path
#     (the old `is not None` test let an empty string reach torch.load);
#   * the checkpoint is mapped onto the selected device while loading;
#   * the dev loader's batch size is the length of the dev dataset, which is
#     the same number the two hand-written branches computed;
#   * output directories are created with exist_ok=True so a re-run of the
#     cell does not fail on an existing directory.
while(dev_heart < 10):

    heart_id = heart_index[dev_heart][1]
    print("We are using heart "+str(heart_id))
    train_loader = torch.utils.data.DataLoader(MyCustomDataset('Train', dev_heart), batch_size=args.batch_size, shuffle=True)
    # The whole held-out subject forms a single batch so the scoring helpers
    # see the complete volume at once.
    dev_dataset = MyCustomDataset('Dev', dev_heart)
    dev_loader = torch.utils.data.DataLoader(dev_dataset, batch_size=len(dev_dataset), shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = UnetGenerator_3d(1, 3, args.channel_base)
    if args.load_model:
        # Overlay the stored tensors on the fresh state so a partial
        # checkpoint still leaves every key populated.
        exist_dict = torch.load(args.load_model, map_location=device)
        total_dict = model.state_dict()
        total_dict.update(exist_dict)
        model.load_state_dict(total_dict)
    print(device)
    model = model.to(device)
    summary(model, input_size=(1, args.slices_depth, args.figuresize1, args.figuresize2))
    model.train()

    # Best Dice seen so far for class 1 and class 2, tracked independently.
    best_dice = [0.0,0.0]
    best_jaccard = 0
    optim = torch.optim.Adam(model.parameters(),lr=args.lr)
    nll_criterion = nn.NLLLoss()

    dice_dir = timeStr + "model/dice"+str(heart_id)
    os.makedirs(dice_dir, exist_ok=True)
    os.makedirs(timeStr + "model/jaccard"+str(heart_id), exist_ok=True)

    for epoch in range(args.epochs):

        for batch_idx, (data, label) in enumerate(train_loader):

            data, target = data.to(device), label.to(device)
            if (args.use_lovasz):
                softmax_output_z = model(data)
                vprobas, vlabels = flatten_probas(softmax_output_z, target.long())
                loss = lovasz_softmax_flat(vprobas, vlabels)
            else:
                logsoftmax_output_z = model(data)
                loss = nll_criterion(logsoftmax_output_z, target.long())

            optim.zero_grad()
            loss.backward()
            optim.step()

        if (epoch + 1) % args.log_interval == 0:

            print("Epoch : "+str(epoch))
            # Evaluation mode: batch norm uses its running statistics.
            model.eval()

            train_loss = get_loss(train_loader, model)
            print(train_loss)
            train_acc = get_accuracy(train_loader, model)
            print("Training accuracy : " + str(train_acc))
            dev_dice = get_dice_score(dev_loader, model)
            print("Dev dice score : " + str(dev_dice))
            # Jaccard scoring is currently disabled:
            #dev_jaccard = get_jaccard_score(dev_loader, model)
            #print("Dev jaccard score : " + str(dev_jaccard))
            dev_acc = get_accuracy(dev_loader, model)
            print("Dev accuracy : " + str(dev_acc))

            if(args.save_model):
                # A checkpoint is written whenever EITHER class improves.
                if(dev_dice[0] > best_dice[0] or dev_dice[1] > best_dice[1]):
                    print("Best model found")
                    torch.save(model.state_dict(), dice_dir + "/" + str(epoch) + ":" + str(dev_dice) + ".pt")
                if(dev_dice[0] > best_dice[0]):
                    best_dice[0] = dev_dice[0]
                if(dev_dice[1] > best_dice[1]):
                    best_dice[1] = dev_dice[1]

            #if(args.save_model and (dev_jaccard > best_jaccard)):
            #    torch.save(model.state_dict(), timeStr + "model/jaccard"+str(heart_index[dev_heart][1])+"/" + str(epoch) + ":" + str(dev_jaccard) + ".pt")
            #    best_jaccard = dev_jaccard

            model.train()

    print("Done")

    dev_heart += 1

We are using heart 4
(1378, 1, 240, 240)
(1378, 240, 240)
(119, 1, 240, 240)
(119, 240, 240)
cuda




Epoch : 0
0.40772685905297595
Training accuracy : 0.777442259816965
Dev dice score : 0.7793324533885955
Dev jaccard score : 0.6384472804008483
Dev accuracy : 0.7793675595238095
Epoch : 1
0.3381048151332399
Training accuracy : 0.877344619315433
Dev dice score : 0.7793324533885955
Dev jaccard score : 0.6384472804008483
Dev accuracy : 0.8218246673669468
Epoch : 2
0.25656376422747323
Training accuracy : 0.9041406375987744
Dev dice score : 0.8602283984135417
Dev jaccard score : 0.7547371816749496
Dev accuracy : 0.8618782096171802
Epoch : 3
0.20420077044678772
Training accuracy : 0.9225048757256894
Dev dice score : 0.8418785132735064
Dev jaccard score : 0.7269341657134601
Dev accuracy : 0.8472862686741364
Epoch : 4
0.28036202850981035
Training accuracy : 0.8995257821319143
Dev dice score : 0.8968727502926784
Dev jaccard score : 0.8130271930409311
Dev accuracy : 0.898969712885154
Epoch : 5
0.14960730340385783
Training accuracy : 0.9394048338977584
Dev dice score : 0.8809367628087111
Dev jacca

Epoch : 46
0.031244004725654057
Training accuracy : 0.9867031023222061
Dev dice score : 0.8972926512536016
Dev jaccard score : 0.8137175783594459
Dev accuracy : 0.8971921685340803
Epoch : 47
0.0387804584874623
Training accuracy : 0.9836972237340752
Dev dice score : 0.9037009367308365
Dev jaccard score : 0.8243194636979256
Dev accuracy : 0.9038993930905695
Epoch : 48
0.030514749547169692
Training accuracy : 0.9869135900459604
Dev dice score : 0.9018731952384936
Dev jaccard score : 0.8212830300012908
Dev accuracy : 0.9015740254435107
Epoch : 49
0.03368544667837736
Training accuracy : 0.9856060388445412
Dev dice score : 0.9020128069793082
Dev jaccard score : 0.8215146109032883
Dev accuracy : 0.9029119981325864
Done
We are using heart 3
(1349, 1, 240, 240)
(1349, 240, 240)
(148, 1, 240, 240)
(148, 240, 240)
cuda
Epoch : 0
0.8930123582394165
Training accuracy : 0.7746394319866567
Dev dice score : 0.8044554502506215
Dev jaccard score : 0.6728777333260262
Dev accuracy : 0.8044527730855856
Epo

Epoch : 42
0.6644856704939046
Training accuracy : 0.7746469993204843
Dev dice score : 0.8044659802565314
Dev jaccard score : 0.6728924676778659
Dev accuracy : 0.8044682573198199
Epoch : 43
0.6629068696287257
Training accuracy : 0.7746469993204843
Dev dice score : 0.8044659802565314
Dev jaccard score : 0.6728924676778659
Dev accuracy : 0.8044682573198199
Epoch : 44
0.6641430516214766
Training accuracy : 0.7746469993204843
Dev dice score : 0.8044659802565314
Dev jaccard score : 0.6728924676778659
Dev accuracy : 0.8044682573198199
Epoch : 45
0.6642470975246655
Training accuracy : 0.7746469993204843
Dev dice score : 0.8044659802565314
Dev jaccard score : 0.6728924676778659
Dev accuracy : 0.8044682573198199
Epoch : 46
0.6641520112752914
Training accuracy : 0.7746469993204843
Dev dice score : 0.8044659802565314
Dev jaccard score : 0.6728924676778659
Dev accuracy : 0.8044682573198199
Epoch : 47
0.6637686172709663
Training accuracy : 0.7746469993204843
Dev dice score : 0.8044659802565314
Dev j

Epoch : 38
0.05073627191694211
Training accuracy : 0.9798244073864113
Dev dice score : 0.9221272836929608
Dev jaccard score : 0.8555066141338601
Dev accuracy : 0.925364383780332
Epoch : 39
0.033038083489230215
Training accuracy : 0.9861799519190392
Dev dice score : 0.9254311740291027
Dev jaccard score : 0.8612115927515087
Dev accuracy : 0.9284859913793103
Epoch : 40
0.035660669495973435
Training accuracy : 0.9854289231124549
Dev dice score : 0.941841198837798
Dev jaccard score : 0.8900754433439675
Dev accuracy : 0.9460313896871009
Epoch : 41
0.030732603294893066
Training accuracy : 0.9871034239313009
Dev dice score : 0.9301352644751564
Dev jaccard score : 0.8693951602334373
Dev accuracy : 0.9333434107598978
Epoch : 42
0.03129366440895461
Training accuracy : 0.9868002540522381
Dev dice score : 0.9273782976573267
Dev jaccard score : 0.8645902397898779
Dev accuracy : 0.9305104565772669
Epoch : 43
0.0370478995631131
Training accuracy : 0.9847262508398421
Dev dice score : 0.9380515290929472

Epoch : 34
0.034901240292635416
Training accuracy : 0.9856300403225806
Dev dice score : 0.9346358584201144
Dev jaccard score : 0.8772922876859008
Dev accuracy : 0.9439495469173442
Epoch : 35
0.037420910986664585
Training accuracy : 0.9846352864257731
Dev dice score : 0.9350623298466829
Dev jaccard score : 0.8780440802437521
Dev accuracy : 0.9450550474254743
Epoch : 36
0.03246328229932282
Training accuracy : 0.9865130735808952
Dev dice score : 0.9391970073045558
Dev jaccard score : 0.885364113599446
Dev accuracy : 0.9484269139566396
Epoch : 37
0.03082733140642691
Training accuracy : 0.9871603057014253
Dev dice score : 0.9398884826522579
Dev jaccard score : 0.8865938747344715
Dev accuracy : 0.9487122078252033
Epoch : 38
0.03093935194575858
Training accuracy : 0.9871630668083687
Dev dice score : 0.9380993980405117
Dev jaccard score : 0.8834153484185375
Dev accuracy : 0.9476928777100271
Epoch : 39
0.0315863484974036
Training accuracy : 0.9868947575435526
Dev dice score : 0.9386306384591563

Epoch : 30
0.03429392832331359
Training accuracy : 0.9854183554409236
Dev dice score : 0.9174252826713539
Dev jaccard score : 0.8474474710947987
Dev accuracy : 0.9241402529761905
Epoch : 31
0.035080575500505265
Training accuracy : 0.9851112032260706
Dev dice score : 0.9122880891001885
Dev jaccard score : 0.8387220912080857
Dev accuracy : 0.9209365079365079
Epoch : 32
0.04283907721366952
Training accuracy : 0.9821525219028904
Dev dice score : 0.9037359788198573
Dev jaccard score : 0.8243779542483918
Dev accuracy : 0.9130731646825396
Epoch : 33
0.03433486489880392
Training accuracy : 0.9854215538770162
Dev dice score : 0.9130733034931884
Dev jaccard score : 0.840050416033377
Dev accuracy : 0.9215648561507936
Epoch : 34
0.03605910924756352
Training accuracy : 0.9846724417628756
Dev dice score : 0.9153190566171878
Dev jaccard score : 0.843860108550126
Dev accuracy : 0.9236908482142857
Epoch : 35
0.03468052832519307
Training accuracy : 0.9853423861868501
Dev dice score : 0.9159997089578543


Epoch : 26
0.045023807126666066
Training accuracy : 0.9810846267553585
Dev dice score : 0.949168237076157
Dev jaccard score : 0.903254195433598
Dev accuracy : 0.9560689139660494
Epoch : 27
0.04489063579797921
Training accuracy : 0.9814620801511046
Dev dice score : 0.9399195106241702
Dev jaccard score : 0.8866491506185662
Dev accuracy : 0.9471344521604939
Epoch : 28
0.038927454646995464
Training accuracy : 0.9834871581670362
Dev dice score : 0.9482437095029782
Dev jaccard score : 0.901581173209271
Dev accuracy : 0.9556475453317901
Epoch : 29
0.03609908480131257
Training accuracy : 0.9846404979674797
Dev dice score : 0.9476095733065026
Dev jaccard score : 0.9004353418921986
Dev accuracy : 0.9546449411651234
Epoch : 30
0.0360223289900583
Training accuracy : 0.9846954299088445
Dev dice score : 0.9480836491737715
Dev jaccard score : 0.9012918271756591
Dev accuracy : 0.9555035927854938
Epoch : 31
0.034716148083276804
Training accuracy : 0.9852348048575182
Dev dice score : 0.9497580447900827


Epoch : 22
0.04033819439921998
Training accuracy : 0.9836079290892799
Dev dice score : 0.9383466458042248
Dev jaccard score : 0.8838540214018153
Dev accuracy : 0.9456148182957393
Epoch : 23
0.04471298625905755
Training accuracy : 0.9818602481467905
Dev dice score : 0.9388243648287716
Dev jaccard score : 0.8847020933977455
Dev accuracy : 0.9442362416457811
Epoch : 24
0.039888772498980124
Training accuracy : 0.9837359481915934
Dev dice score : 0.9375925857462205
Dev jaccard score : 0.8825169262940477
Dev accuracy : 0.9443308792815371
Epoch : 25
0.03825815670776708
Training accuracy : 0.9842796157135875
Dev dice score : 0.9357941542115856
Dev jaccard score : 0.879335606210025
Dev accuracy : 0.943047070802005
Epoch : 26
0.04512412648353059
Training accuracy : 0.9813705477150537
Dev dice score : 0.9330307438770101
Dev jaccard score : 0.8744681951817094
Dev accuracy : 0.940672123015873
Epoch : 27
0.03903236248478555
Training accuracy : 0.9838610271057348
Dev dice score : 0.9358316357327697
D

Epoch : 18
0.04299514115388904
Training accuracy : 0.9822866805142195
Dev dice score : 0.935347992492408
Dev jaccard score : 0.8785480551113394
Dev accuracy : 0.9447426470588235
Epoch : 19
0.03958448793301137
Training accuracy : 0.9834568245701059
Dev dice score : 0.9381018348631243
Dev jaccard score : 0.8834197340360975
Dev accuracy : 0.9472575118010167
Epoch : 20
0.03924090826296292
Training accuracy : 0.9836981362640542
Dev dice score : 0.936204339430519
Dev jaccard score : 0.8800602716551232
Dev accuracy : 0.9442105800653595
Epoch : 21
0.03766358642384321
Training accuracy : 0.9842546606316137
Dev dice score : 0.9390638155675864
Dev jaccard score : 0.8851274842840967
Dev accuracy : 0.9464741058460421
Epoch : 22
0.056396046852959056
Training accuracy : 0.9772899486400463
Dev dice score : 0.9394299665964858
Dev jaccard score : 0.8857783056074476
Dev accuracy : 0.9459519108569354
Epoch : 23
0.03686820148556892
Training accuracy : 0.9845272972470238
Dev dice score : 0.9385222986782086


Epoch : 14
0.06980580831589424
Training accuracy : 0.9737666185267125
Dev dice score : 0.9465173067353737
Dev jaccard score : 0.8984649408091583
Dev accuracy : 0.9502934150906225
Epoch : 15
0.0655948990230671
Training accuracy : 0.975144189405113
Dev dice score : 0.9436951583334244
Dev jaccard score : 0.8933927763732306
Dev accuracy : 0.9484425482663514
Epoch : 16
0.06086075113355498
Training accuracy : 0.9767247931006228
Dev dice score : 0.947004256153739
Dev jaccard score : 0.8993428709996513
Dev accuracy : 0.9501467691095351
Epoch : 17
0.05573608821318389
Training accuracy : 0.9786473953212061
Dev dice score : 0.948940723148213
Dev jaccard score : 0.9028422133088365
Dev accuracy : 0.9526225128053586
Epoch : 18
0.06286756384425458
Training accuracy : 0.9760963746517535
Dev dice score : 0.9481718556455311
Dev jaccard score : 0.9014512692546519
Dev accuracy : 0.9527614016942474
Epoch : 19
0.050721372513424685
Training accuracy : 0.9804145413389053
Dev dice score : 0.9469594729939337
De

Epoch : 10
0.13480273624004563
Training accuracy : 0.9598605306484296
Dev dice score : 0.951673615298071
Dev jaccard score : 0.9078027700250552
Dev accuracy : 0.9597679749846532
Epoch : 11
0.11789813084035297
Training accuracy : 0.9637226522922999
Dev dice score : 0.9493908187840782
Dev jaccard score : 0.9036574356211586
Dev accuracy : 0.9596999693063228
Epoch : 12
0.09242226305703624
Training accuracy : 0.9707780310705842
Dev dice score : 0.9540287036373506
Dev jaccard score : 0.912098341973992
Dev accuracy : 0.9634965853284223
Epoch : 13
0.09406197852084248
Training accuracy : 0.9696740174349882
Dev dice score : 0.9487812417844228
Dev jaccard score : 0.9025535492883751
Dev accuracy : 0.9572428445365255
Epoch : 14
0.07930975366598929
Training accuracy : 0.9728278822610604
Dev dice score : 0.9464163464145812
Dev jaccard score : 0.8982830379134813
Dev accuracy : 0.9571822245242481
Epoch : 15
0.0703266045690081
Training accuracy : 0.9755798162571766
Dev dice score : 0.9525556876469172
De