In [None]:
## Importing necessary libraries

import torch as tor
import torch.nn as nn
import torchvision as tv
import torch.optim
import torch.utils.data
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF

import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
from dist_utils.nn_layers.cnn_utils import *
import math
import time
from PIL import Image as im

from sklearn.metrics import f1_score 
import dist_utils.postprocessing as pp
from sklearn.metrics import jaccard_score 
import random

import statistics as stats
import copy

import albumentations as A
from scipy import ndimage

In [None]:
def calc_total_mean(datafiles,num_chn = 3,verbose = False,to_rgb = True):

    """
        Compute the per-pixel mean image of a list of image files.

        Arguments:
            datafiles : sequence of image file paths (e.g. dataset.data_files), not a directory
            num_chn   : 3 -> colour images, resized to 256x256 (default)
                        1 -> grayscale images, read at their native size
            verbose   : print the index and path of every file as it is read
            to_rgb    : for num_chn == 3, reorder cv2's BGR channels to RGB so the mean lines up
                        with images loaded through PIL / torchvision (default True)

        Returns:
            float32 ndarray holding the mean image ((256, 256, 3) when num_chn == 3)

        Raises:
            ValueError        : if num_chn is not 1 or 3, or datafiles is empty
            FileNotFoundError : if cv2 cannot read one of the files
    """

    # The old `assert "<message>"` asserted a non-empty string and therefore never fired.
    if num_chn not in (1, 3):
        raise ValueError("Incorrect number of channels")

    num_files = len(datafiles)
    if num_files == 0:
        raise ValueError("calc_total_mean needs at least one image file")

    # Accumulate in float64: adding uint8 images to a uint8 running sum wraps around at 255.
    img_sum = None
    for e, file in enumerate(datafiles):
        img = cv2.imread(file) if num_chn == 3 else cv2.imread(file, 0)
        if img is None:
            raise FileNotFoundError("cv2 could not read image: {}".format(file))

        if num_chn == 3:
            img = cv2.resize(img, (256, 256))
            if to_rgb:
                # cv2.imread returns BGR; the dataset loads images with PIL (RGB).
                img = img[:, :, ::-1]

        img = img.astype(np.float64)
        img_sum = img if img_sum is None else img_sum + img

        if verbose:
            print(e, file)

    return np.float32(img_sum / num_files)

In [None]:
class dice(nn.Module):
    """
    DiCE block: dimension-wise (volume-wise) separable convolution.

    The C x H x W input volume is convolved separately along three dimensions:
    a depthwise conv over the channel axis, and two depthwise convs applied after
    transposing the height (resp. width) axis into the channel position. The
    three results are concatenated (3C channels), fused back to C channels,
    projected to `channel_out` channels and multiplied by a per-channel gate
    computed by global average pooling + two 1x1 convs + sigmoid.
    """
    def __init__(self, channel_in, channel_out, height, width, kernel_size=3, dilation=[1, 1, 1], shuffle=True):
        
        """
            Constructor to initialize the DiCE block instance.

            Arguments:
                channel_in  : number of input channels C
                channel_out : number of output channels
                height      : height H the height-wise conv is built for (it has H input
                              channels / groups); inputs of another height are resized to H
                              around that conv and back afterwards
                width       : width W the width-wise conv is built for (W channels / groups),
                              resized the same way
                kernel_size : kernel size of the three dimension-wise convolutions
                dilation    : three dilation rates, in the order
                              [channel-wise conv, width-wise conv, height-wise conv];
                              the list is only read, never mutated
                shuffle     : if True, apply Shuffle(3) to the concatenated 3C maps before fusion

            Returns:
                None (module constructor)
        """
        
        super().__init__()

        assert len(dilation) == 3
        # "same" padding for odd kernel sizes: (k - 1) / 2 * dilation on each side
        padding_1 = int((kernel_size - 1) / 2) *dilation[0] 
        padding_2 = int((kernel_size - 1) / 2) *dilation[1] 
        padding_3 = int((kernel_size - 1) / 2) *dilation[2] 
        # Depthwise (groups == in == out) convolutions. The width/height variants receive the
        # width/height axis as their channel axis, which is why their channel count is W / H.
        self.conv_channel = nn.Conv2d(channel_in, channel_in, kernel_size=kernel_size, stride=1, groups=channel_in,
                                      padding=padding_1, bias=False, dilation=dilation[0])
        self.conv_width = nn.Conv2d(width, width, kernel_size=kernel_size, stride=1, groups=width,
                               padding=padding_2, bias=False, dilation=dilation[1])
        self.conv_height = nn.Conv2d(height, height, kernel_size=kernel_size, stride=1, groups=height,
                               padding=padding_3, bias=False, dilation=dilation[2])

        # BR / CBR / Shuffle come from the star import of dist_utils.nn_layers.cnn_utils.
        # NOTE(review): presumably BR = batch-norm + activation and CBR = conv + batch-norm +
        # activation (kSize = kernel size) -- confirm in cnn_utils.
        self.br_act = BR(3*channel_in)
        # Grouped 1x1 fusion 3C -> C with C groups: each output channel mixes 3 input channels.
        self.weight_avg_layer = CBR(3*channel_in, channel_in, kSize=1, stride=1, groups=channel_in)

        # project from channel_in to Channel_out
        groups_proj = math.gcd(channel_in, channel_out)
        self.proj_layer = CBR(channel_in, channel_out, kSize=3, stride=1, groups=groups_proj)
        # Channel gate: (B, C, H, W) -> (B, channel_out, 1, 1) weights in (0, 1).
        # NOTE(review): the bottleneck width is channel_in // 4, which is 0 for channel_in < 4.
        self.linear_comb_layer = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Conv2d(channel_in, channel_in // 4, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel_in //4, channel_out, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

        # NOTE(review): presumably a channel shuffle with 3 groups, interleaving the channel-,
        # height- and width-wise responses before the grouped fusion -- confirm in cnn_utils.
        self.vol_shuffle = Shuffle(3)

        self.width = width
        self.height = height
        self.channel_in = channel_in
        self.channel_out = channel_out
        self.shuffle = shuffle
        self.ksize=kernel_size
        self.dilation = dilation

    def forward(self, x):
        
        """
            Forward pass of the DiCE block.

            Arguments:
                x: input volume of shape (batch, channel_in, H, W)

            Returns:
                proj_out * linear_wts: the projected maps (batch, channel_out, ...) scaled
                channel-wise by the gate of shape (batch, channel_out, 1, 1). The spatial size is
                presumably (H, W) if CBR's 3x3 conv pads to "same" -- TODO confirm in cnn_utils.
        """
        
        bsz, channels, height, width = x.size()
        
        # process across channel: depthwise conv over each (H, W) plane. Input/Output: C x H x W
        out_ch_wise = self.conv_channel(x)

        # process across height. conv_height expects exactly self.height channels, so first
        # resize H to self.height (bilinear upsampling if smaller, adaptive average pooling if larger).
        x_h_wise = x.clone()
        if height != self.height:
            if height < self.height:
                x_h_wise = F.interpolate(x_h_wise, mode='bilinear', size=(self.height, width), align_corners=True)
            else:
                x_h_wise = F.adaptive_avg_pool2d(x_h_wise, output_size=(self.height, width))

        # transpose(1, 2): C x H x W -> H x C x W, convolve, then transpose back to C x H x W
        x_h_wise = x_h_wise.transpose(1, 2).contiguous()
        out_h_wise = self.conv_height(x_h_wise).transpose(1, 2).contiguous()

        # resize the height-wise result back to the input height
        h_wise_height = out_h_wise.size(2)
        if height != h_wise_height:
            if h_wise_height < height:
                out_h_wise = F.interpolate(out_h_wise, mode='bilinear', size=(height, width), align_corners=True)
            else:
                out_h_wise = F.adaptive_avg_pool2d(out_h_wise, output_size=(height, width))

        # process across width: same scheme as the height branch, with W resized to self.width
        x_w_wise = x.clone()
        if width != self.width:
            if width < self.width:
                x_w_wise = F.interpolate(x_w_wise, mode='bilinear', size=(height, self.width), align_corners=True)
            else:
                x_w_wise = F.adaptive_avg_pool2d(x_w_wise, output_size=(height, self.width))

        # transpose(1, 3): C x H x W -> W x H x C, convolve, then transpose back to C x H x W
        x_w_wise = x_w_wise.transpose(1, 3).contiguous()
        out_w_wise = self.conv_width(x_w_wise).transpose(1, 3).contiguous()
        w_wise_width = out_w_wise.size(3)
        if width != w_wise_width:
            if w_wise_width < width:
                out_w_wise = F.interpolate(out_w_wise, mode='bilinear', size=(height, width), align_corners=True)
            else:
                out_w_wise = F.adaptive_avg_pool2d(out_w_wise, output_size=(height, width))

        # Merge. Output will be 3C x H X W
        # (`torch` is bound by the `import torch.optim` / `import torch.utils.data` lines.)
        outputs = torch.cat((out_ch_wise, out_h_wise, out_w_wise), 1)
        outputs = self.br_act(outputs)

        if self.shuffle:
            outputs = self.vol_shuffle(outputs)
        # fuse 3C -> C, then gate and project C -> channel_out
        outputs = self.weight_avg_layer(outputs)
        linear_wts = self.linear_comb_layer(outputs)
        proj_out = self.proj_layer(outputs)
        return proj_out * linear_wts

    def __repr__(self):
        # Formatted from self.__dict__, which holds the hyper-parameters stored in __init__
        # (str.format ignores the extra nn.Module entries).
        s = '{name}(in_channels={channel_in}, out_channels={channel_out}, kernel_size={ksize}, vol_shuffle={shuffle}, ' \
            'width={width}, height={height}, dilation={dilation})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

In [None]:
class aspp_dice(nn.Module):

    """
        ASPP (atrous spatial pyramid pooling) block built from DiCE (dimension-wise) convolutions.

        Four parallel `dice` branches (one 1x1, three 3x3 with increasing dilation) and a
        global-average-pooling branch are resized to `prev_dim`, concatenated along the
        channel axis and fused back to `in_channels` channels by a final 1x1 `dice` block.
    """
    
    def __init__(self,in_channels,mid_channels,prev_dim,rates):

        """
            Constructor to initialize the ASPP_DiCE block instance.

            Arguments:
                in_channels  : channels of the dense feature map produced by the encoder
                mid_channels : channels produced by each of the four dice branches
                prev_dim     : spatial size of that feature map, an int or (height, width).
                               It is used as the pooling kernel, the upsampling target and the
                               height/width the dice branches are built for (previously
                               hard-coded to 32).
                rates        : three dilation rates for branches 2-4, applied to the width-
                               and height-wise convolutions of each dice block
        """
        
        super().__init__()

        r1,r2,r3 = rates
        height, width = self._as_hw(prev_dim)

        # Module creation order is kept as before so parameter / state_dict order is unchanged.
        self.branch1 = dice(in_channels,mid_channels,height = height,width = width,kernel_size = 1)
        self.branch2 = dice(in_channels,mid_channels,height = height,width = width,kernel_size = 3,dilation = [1,r1,r1])
        self.branch3 = dice(in_channels,mid_channels,height = height,width = width,kernel_size = 3,dilation = [1,r2,r2])
        self.branch4 = dice(in_channels,mid_channels,height = height,width = width,kernel_size = 3,dilation = [1,r3,r3])
        # Image-level branch: averages the whole prev_dim map down to 1x1 (keeps in_channels).
        self.branch5 = nn.AvgPool2d(kernel_size = prev_dim)

        self.prev_dim = prev_dim

        self.upsample = nn.UpsamplingBilinear2d(size = prev_dim)

        # 4 dice branches (mid_channels each) + pooled branch (in_channels) -> in_channels
        self.final_layer = dice(mid_channels * 4 + in_channels,in_channels,height = height,width = width,kernel_size = 1)

    @staticmethod
    def _as_hw(dim):
        """Normalise an int or (height, width) pair to a (height, width) tuple."""
        if isinstance(dim, int):
            return dim, dim
        height, width = dim
        return height, width

    def forward(self,x):

        """
            Forward pass of the ASPP_DiCE block.

            Arguments:
                x   : encoder feature map, (batch, in_channels, *prev_dim)

            Returns:
                out : fused feature map with in_channels channels
        """

        branches = (self.branch1, self.branch2, self.branch3, self.branch4, self.branch5)
        # Every branch output is brought to prev_dim before concatenation along channels.
        out = tor.cat([self.upsample(branch(x)) for branch in branches],dim = 1)

        return self.final_layer(out)

In [None]:
class dist_dice(nn.Module):

    """
        KidneyNet: a U-Net style encoder/decoder built from DiCE blocks, with an ASPP_DiCE
        block at the bottleneck. It predicts a single-channel, non-negative distance map.
    """
    
    def __init__(self,in_channels,nfeat,mid_channels = 256):
        
        """
            Constructor to initialize the KidneyNet instance.

            Arguments:
                in_channels  : number of channels of the input tissue image
                nfeat        : base width; the encoder uses nfeat, 2*nfeat, 4*nfeat, 8*nfeat channels
                mid_channels : channels of each branch inside the ASPP_DiCE block

            The dice blocks are built for 256/128/64/32-pixel feature maps, i.e. a 256x256 input.
        """

        super().__init__()

        # NOTE: the creation order below defines the parameter / state_dict order; keep it.
        self.conv1a = nn.Conv2d(in_channels = in_channels,out_channels = nfeat,kernel_size = (3,3),padding = 1)
        self.bn1a = nn.BatchNorm2d(nfeat)
        self.conv1b = dice(nfeat,nfeat,height = 256,width = 256)
        self.bn1b = nn.BatchNorm2d(nfeat)

        self.maxpool1 = nn.MaxPool2d(kernel_size = (2,2), stride = 2)

        self.conv2a = dice(nfeat,2 * nfeat,width = 128,height = 128)
        self.bn2a = nn.BatchNorm2d(2 * nfeat)
        self.conv2b = dice(2 * nfeat,2 * nfeat,width = 128,height = 128)
        self.bn2b = nn.BatchNorm2d(2 * nfeat)

        self.maxpool2 = nn.MaxPool2d(kernel_size = (2,2), stride = 2)

        self.conv3a = dice(2 * nfeat,4 * nfeat,width = 64,height = 64)
        self.bn3a = nn.BatchNorm2d(4 * nfeat)
        self.conv3b = dice(4 * nfeat,4 * nfeat,width = 64,height = 64)
        self.bn3b = nn.BatchNorm2d(4 * nfeat)

        self.maxpool3 = nn.MaxPool2d(kernel_size = (2,2),stride = 2)

        self.conv4a = dice(4 * nfeat,8 * nfeat,width = 32,height = 32)
        self.bn4a = nn.BatchNorm2d(8 * nfeat)
        self.conv4b = dice(8 * nfeat,8 * nfeat,width = 32,height = 32)
        self.bn4b = nn.BatchNorm2d(8 * nfeat)

        # Transposed convs double the spatial size (stride 2, output_padding 1).
        self.upconv1 = nn.ConvTranspose2d(in_channels = 8 * nfeat,out_channels = 4 * nfeat, kernel_size = (3,3), stride = 2, padding = 1, output_padding = 1)

        self.conv6a = dice(8 * nfeat,4 * nfeat,width = 64,height = 64)
        self.bn6a = nn.BatchNorm2d(4 * nfeat)
        self.conv6b = dice(4 * nfeat,4 * nfeat,width = 64,height = 64)
        self.bn6b = nn.BatchNorm2d(4 * nfeat)

        self.upconv2 = nn.ConvTranspose2d(in_channels = 4 * nfeat,out_channels = 2 * nfeat, kernel_size = (3,3), stride = 2, padding = 1, output_padding = 1)

        self.conv7a = dice(4 * nfeat,2 * nfeat,width = 128,height = 128)
        self.bn7a = nn.BatchNorm2d(2 * nfeat)
        self.conv7b = dice(2 * nfeat,2 * nfeat,width = 128,height = 128)
        self.bn7b = nn.BatchNorm2d(2 * nfeat)

        self.upconv3 = nn.ConvTranspose2d(in_channels = 2 * nfeat,out_channels = nfeat, kernel_size = (3,3), stride = 2, padding = 1, output_padding = 1)

        self.conv8a = dice(2 * nfeat,nfeat,width = 256,height = 256)
        self.bn8a = nn.BatchNorm2d(nfeat)
        self.conv8b = dice(nfeat,nfeat,width = 256,height = 256)
        self.bn8b = nn.BatchNorm2d(nfeat)

        self.seg_map_conv = nn.Conv2d(in_channels = nfeat,out_channels = 1, kernel_size = (1,1))

        self.aspp = aspp_dice(in_channels = 8 * nfeat,mid_channels = mid_channels,prev_dim = (32,32),rates = [2,4,6])

    @staticmethod
    def _conv_bn_relu(conv, bn, x):
        """Apply conv -> batch-norm -> ReLU (F.relu avoids building an nn.ReLU module per call)."""
        return F.relu(bn(conv(x)))

    def forward(self,inp):
        
        """
            Forward pass of KidneyNet.

            Arguments:
                inp   : input tissue image batch, (batch, in_channels, 256, 256)

            Returns:
                out   : predicted distance map, (batch, 1, H, W), clamped to >= 0 by a final ReLU
        """

        # Encoder: two conv-bn-relu stages per level, max-pooling between levels.
        out1 = self._conv_bn_relu(self.conv1a, self.bn1a, inp)
        out1 = self._conv_bn_relu(self.conv1b, self.bn1b, out1)
        out2 = self.maxpool1(out1)

        out2 = self._conv_bn_relu(self.conv2a, self.bn2a, out2)
        out2 = self._conv_bn_relu(self.conv2b, self.bn2b, out2)
        out3 = self.maxpool2(out2)

        out3 = self._conv_bn_relu(self.conv3a, self.bn3a, out3)
        out3 = self._conv_bn_relu(self.conv3b, self.bn3b, out3)
        out4 = self.maxpool3(out3)

        out4 = self._conv_bn_relu(self.conv4a, self.bn4a, out4)
        out4 = self._conv_bn_relu(self.conv4b, self.bn4b, out4)

        # Bottleneck context aggregation.
        out_aspp = self.aspp(out4)

        # Decoder: upsample, concatenate the matching encoder features (skip connection), refine.
        out6 = tor.cat((self.upconv1(out_aspp),out3),dim = 1)
        del out4

        out6 = self._conv_bn_relu(self.conv6a, self.bn6a, out6)
        out6 = self._conv_bn_relu(self.conv6b, self.bn6b, out6)

        out7 = tor.cat((self.upconv2(out6),out2),dim = 1)
        del out3
        del out6

        out7 = self._conv_bn_relu(self.conv7a, self.bn7a, out7)
        out7 = self._conv_bn_relu(self.conv7b, self.bn7b, out7)

        out8 = tor.cat((self.upconv3(out7),out1),dim = 1)
        del out2
        del out7

        out8 = self._conv_bn_relu(self.conv8a, self.bn8a, out8)
        out8 = self._conv_bn_relu(self.conv8b, self.bn8b, out8)

        # Single-channel map; ReLU keeps the predicted distances non-negative.
        return F.relu(self.seg_map_conv(out8))

In [None]:
def get_f1(gt,pred):

    """
        Mean per-image F1 score over a batch of segmentation maps.

        Each image is binarised on its own:
        - Pixels equal to that image's maximum become 1 and all others 0. This matches the
          previous min-max-normalise-then-truncate-to-uint8 behaviour.
        - A constant image becomes all 1 if its value is positive, else all 0. Previously
          this case divided 0 / 0.

        Arguments:
            gt   : ground-truth maps, torch tensor or ndarray whose first axis is the batch
                   (e.g. (B, 1, H, W) from the dataloader)
            pred : predicted maps with the same batch size, tensor or ndarray (e.g. (B, H, W))

        Returns  :
            Mean F1 score over the batch
    """

    def _binarize(mask):
        # Pixels at the image maximum -> 1; constant images handled explicitly to avoid 0 / 0.
        lo, hi = mask.min(), mask.max()
        if hi == lo:
            return (mask > 0).astype(np.uint8)
        return (mask == hi).astype(np.uint8)

    m = gt.shape[0]

    if(not isinstance(gt,np.ndarray)):
        gt = gt.detach().cpu().numpy()

    if(not isinstance(pred,np.ndarray)):
        pred = pred.detach().cpu().numpy()

    # Flatten per sample rather than squeeze(): squeeze() also removed the batch axis for a
    # batch of one, so zip() iterated over image rows and f1_score got mismatched lengths.
    gt = np.asarray(gt).reshape(m, -1)
    pred = np.asarray(pred).reshape(m, -1)

    f1 = [f1_score(_binarize(ground_truth), _binarize(predicted))
          for predicted, ground_truth in zip(pred, gt)]

    return np.mean(f1)

In [None]:
def get_ji(gt,pred):
    
    """
        Mean per-image Jaccard index over a batch of segmentation maps.

        Each image is binarised on its own:
        - Pixels equal to that image's maximum become 1 and all others 0. This matches the
          previous min-max-normalise-then-truncate-to-uint8 behaviour.
        - A constant image becomes all 1 if its value is positive, else all 0. Previously
          this case divided 0 / 0.

        Arguments:
            gt   : ground-truth maps, torch tensor or ndarray whose first axis is the batch
                   (e.g. (B, 1, H, W) from the dataloader)
            pred : predicted maps with the same batch size, tensor or ndarray (e.g. (B, H, W))

        Returns  :
            Mean Jaccard index over the batch
    """

    def _binarize(mask):
        # Pixels at the image maximum -> 1; constant images handled explicitly to avoid 0 / 0.
        lo, hi = mask.min(), mask.max()
        if hi == lo:
            return (mask > 0).astype(np.uint8)
        return (mask == hi).astype(np.uint8)

    m = gt.shape[0]

    if(not isinstance(gt,np.ndarray)):
        gt = gt.detach().cpu().numpy()

    if(not isinstance(pred,np.ndarray)):
        pred = pred.detach().cpu().numpy()

    # Flatten per sample rather than squeeze(): squeeze() also removed the batch axis for a
    # batch of one, so zip() iterated over image rows and jaccard_score got mismatched lengths.
    gt = np.asarray(gt).reshape(m, -1)
    pred = np.asarray(pred).reshape(m, -1)

    ji = [jaccard_score(_binarize(ground_truth), _binarize(predicted))
          for predicted, ground_truth in zip(pred, gt)]

    return np.mean(ji)

In [None]:
## Initializing the elastic transform method
# A single elastic deformation, applied jointly to the image ("image"), the segmentation mask
# ("mask") and the distance map. The distance map is registered as an extra mask target
# "mask2", so albumentations treats it like the mask and applies the same random deformation.
elastic = A.Compose(
    [A.ElasticTransform(alpha=1)],
    additional_targets=dict(mask2="mask"),
)

In [None]:
class dataset(torch.utils.data.Dataset):

    """
        Kidney tissue dataset yielding (data, label, gt) triads.

        `files_dir` must contain a `data` folder (tissue images) and a `gts` folder (ground-truth
        masks) whose files share the same names. The distance-map label is computed on the fly
        from the mask.
    """

    # Mean image shared by all instances. It is set (as a C x H x W tensor) when an instance with
    # phase == "train" is created, and is subtracted from every sample. It stays 0 until then.
    total_mean = 0

    def __init__(self,files_dir,data_size = -1,phase = "",apply_transforms = True):
        
        """
            Constructor that initializes the dataset class instance.

            Arguments:
                files_dir        : directory holding the `data` and `gts` sub-folders
                data_size        : number of samples reported by __len__; -1 means all files
                phase            : "train" computes dataset.total_mean from this set's images;
                                   any other value reuses the current dataset.total_mean
                apply_transforms : True  -> random crop / flips / rotation / colour jitter / elastic
                                            augmentation
                                   False -> plain resize to 256x256
        """

        data_dir = os.path.join(files_dir,"data")
        gt_dir = os.path.join(files_dir,"gts")

        # os.listdir order is arbitrary; sort so sample indices are reproducible across runs.
        files = sorted(os.listdir(gt_dir))

        # Images and masks are paired by file name.
        self.data_files = [os.path.join(data_dir,x) for x in files]
        self.gt_files = [os.path.join(gt_dir,x) for x in files]
        self.data_size = len(self.data_files) if data_size == -1 else data_size
        self.apply_transforms = apply_transforms

        if(phase == "train"):
            # calc_total_mean returns an H x W x C array; permute to C x H x W like ToTensor output.
            dataset.total_mean = tor.from_numpy(calc_total_mean(self.data_files)).permute(2,0,1)

        if tor.is_tensor(dataset.total_mean):
            print("shape of dataset.total_mean: ",dataset.total_mean.size())
        else:
            # total_mean is still the int 0 (no "train" set yet); calling .size() on it used to crash.
            print("dataset.total_mean has not been computed yet (no \"train\" phase dataset created)")

    def __len__(self):
        
        """
            Method to return the dataset size.
        """
        
        return self.data_size

    def transforms(self,data,label,gt):
        
        """
            Apply the same random geometric augmentation to image, distance map and mask,
            plus colour jitter to the image only.

            Arguments:
                data  : tissue image (PIL image)
                label : ground-truth distance map (PIL image)
                gt    : ground-truth segmentation mask (PIL image)

            Returns:
                (data, label, gt) after:
                - a shared random 256x256 crop
                - random horizontal and vertical flips
                - a random rotation by 0/90/180/270 degrees (label and gt filled with 0)
                - brightness/contrast jitter on data only
        """

        ## Random Crop (same window for all three)
        i, j, h, w = transforms.RandomCrop.get_params(
            data, output_size=(256,256))
        data = TF.crop(data, i, j, h, w)
        label = TF.crop(label, i, j, h, w)
        gt = TF.crop(gt,i,j,h,w)

        ## Random horizontal flip
        if(random.random() > 0.5):
            data = TF.hflip(data)
            label = TF.hflip(label)
            gt = TF.hflip(gt)

        ## Random Vertical Flip
        if(random.random() > 0.5):
            data = TF.vflip(data)
            label = TF.vflip(label)
            gt = TF.vflip(gt)

        ## Random rotate in multiples of 90
        angle = random.choice([0,90,180,270])
        data = TF.rotate(data,angle)
        label = TF.rotate(label,angle,fill = (0,))
        gt = TF.rotate(gt,angle,fill = (0,))
        
        ## Applying color-jitter to data
        data = transforms.ColorJitter(brightness=0.3, contrast=0.3)(data)

        return data,label,gt

    @staticmethod
    def _min_max_scale(tensor, scale):
        """Linearly map `tensor` onto [0, scale]; a constant tensor becomes zeros instead of NaN (0 / 0)."""
        lo, hi = tensor.min(), tensor.max()
        if hi > lo:
            return (tensor - lo) / (hi - lo) * scale
        return tor.zeros_like(tensor)

    def __getitem__(self,idx):

        """
            Load and pre-process one sample.

            Arguments:
                idx  : index of the sample

            Returns:
                data  : float tensor (3, 256, 256) for RGB(A) images, min-max scaled to [0, 255],
                        minus dataset.total_mean
                label : float tensor (1, 256, 256), distance map min-max scaled to [0, 255]
                        (all zeros for a mask without foreground)
                gt    : long tensor (1, 256, 256). ToTensor scales uint8 by 1/255 and the long cast
                        truncates, so only mask pixels equal to 255 become 1.
        """

        gt_path = self.gt_files[idx]
        data = im.open(self.data_files[idx])
        gt = cv2.imread(gt_path,0)
        if gt is None:
            raise FileNotFoundError("cv2 could not read mask: {}".format(gt_path))
        
        ## Dynamic transformation of distance map
        label = ndimage.distance_transform_edt(gt)
        
        gt = im.fromarray(gt)
        label = im.fromarray(label)

        if(self.apply_transforms):
            data,label,gt = self.transforms(data,label,gt)

            # PIL -> numpy for albumentations. np.array already gives (H, W[, C]); the previous
            # reshape(size[0], size[1]) used PIL's (width, height) order, wrong for non-square images.
            data = np.array(data)
            if data.ndim == 2:
                data = data[:, :, np.newaxis]
            data = data[:,:,:3]
            label = np.array(label)
            gt = np.array(gt)

            result = elastic(image = data,mask = gt,mask2 = label)

            data = result["image"]
            label = result["mask2"]
            gt = result["mask"]
        else:
            data = data.resize((256,256))
            label = label.resize((256,256))
            gt = gt.resize((256,256))

        ## Convert PILs / arrays to tensors
        data = transforms.ToTensor()(data)[:3,:,:]
        label = transforms.ToTensor()(label)
        gt = transforms.ToTensor()(gt)

        label = label.type(tor.FloatTensor)
        gt = gt.type(tor.LongTensor)
        
        data = self._min_max_scale(data, 255)
        label = self._min_max_scale(label, 255)

        ## subtracting the mean
        data = data - dataset.total_mean

        return data,label,gt

In [None]:
## Training data. phase="train" also computes dataset.total_mean, which later datasets reuse.
train_dir = "./dataset_dir/train/"
trainset = dataset(
    files_dir=train_dir,
    phase="train",
    apply_transforms=True,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=8,
    shuffle=True,
)

In [None]:
## Validation data. phase is left at its default, so it reuses the dataset.total_mean computed
## by the "train" phase dataset (must be created first).
val_dir = "./dataset_dir/val"

valset = dataset(files_dir=val_dir,apply_transforms=False)
# No shuffling: validation metrics are averaged per batch, so a fixed order keeps them
# deterministic and comparable from epoch to epoch.
valloader = torch.utils.data.DataLoader(valset,batch_size=8,shuffle = False)

In [None]:
def get_segments(distmaps,param = 7,thresh1 = 0.5,thresh2 = 5):

    """
        Convert a batch of predicted distance maps into binary segmentation masks.

        For each map:
        1. Min-max rescale it to 0-255 and cast to uint8.
        2. Set values below `thresh2` to 0.
        3. Pass the result through pp.PostProcess.
        4. Set every non-zero output pixel to 1.

        Arguments:
            distmaps : iterable of distance maps (torch tensors or ndarrays), e.g. a (B, 1, H, W) batch
            param    : `param` argument forwarded to pp.PostProcess
            thresh1  : `thresh` argument forwarded to pp.PostProcess
            thresh2  : cut-off on the 0-255 rescaled map; smaller values are zeroed before post-processing

        Returns:
            np.ndarray stacking the per-map masks (1 where post-processing kept a pixel, 0 elsewhere)
    """
    
    segments = []

    for res in distmaps:

        if(not isinstance(res,np.ndarray)):
            res = res.detach().cpu().numpy().squeeze()

        lo, hi = res.min(), res.max()
        if hi > lo:
            res = ((res - lo) / (hi - lo)) * 255
        else:
            # Constant map: 0 / 0 would give NaN, whose uint8 cast is undefined; treat it as empty.
            res = np.zeros(res.shape)
        res = np.uint8(res)
        res[res<thresh2] = 0 

        res = pp.PostProcess(res,param = param,thresh = thresh1)
        res[res!=0] = 1

        segments.append(res)

    return np.array(segments)

In [None]:
def train(seg,epochs,dataloaders,hyper_params,reset = True,save = False):

    """
        Train a KidneyNet (`dist_dice`) model and report metrics every epoch.

        Training uses Adam on an MSE loss between predicted and ground-truth distance maps.
        Each epoch prints, for the training batches, the mean loss, F1, Jaccard index (JI) and
        the harmonic mean of F1 and JI. It then computes the same metrics on the validation
        loader under torch.no_grad(). F1 / JI are measured on masks derived from the predicted
        maps with get_segments.

        The passed-in `seg` is trained in place, so the caller's reference holds the trained
        weights afterwards. The one exception is the reset fallback described below.

        Arguments:
            seg          : dist_dice instance to train (may be None when reset is True)
            epochs       : number of epochs
            dataloaders  : [trainloader, valloader], each yielding (data, label, gt) batches
            hyper_params : [lr, reg, nfeat, postproc_params]
                           - reg is Adam's weight_decay
                           - nfeat is only used when reset is True
                           - postproc_params is [param, thresh1, thresh2] for get_segments,
                             or a falsy value for the defaults (7, 0.5, 20)
            reset        : re-initialise `seg` with the weights of a freshly built
                           dist_dice(in_channels=3, nfeat=nfeat) before training. If the
                           architectures differ, the fresh model is trained instead and the
                           caller's object is left untouched.
            save         : accepted for backward compatibility; not used here

        Returns:
            epoch_losses : list with the mean training loss of every epoch
    """
    
    trainloader,valloader = dataloaders

    lr,reg,nfeat,postproc_params = hyper_params

    if(postproc_params):
        param,thresh1,thresh2 = postproc_params
    else:
        param,thresh1,thresh2 = 7,0.5,20

    # `device` is the module-level torch.device defined in the set-up cell.
    if(reset):
        fresh = dist_dice(in_channels=3,nfeat = nfeat).to(device)
        if seg is None:
            seg = fresh
        else:
            try:
                # Copy the fresh weights into the caller's module instead of rebinding the local
                # name: rebinding trained a private model and left the caller's `seg` untrained.
                seg.load_state_dict(fresh.state_dict())
                seg.to(device)
            except RuntimeError:
                # Different architecture (e.g. another nfeat): nothing to copy into.
                seg = fresh
        del fresh
        print("/////////////////// Weights have been reset \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\")

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(seg.parameters(),lr = lr,weight_decay = reg)

    epoch_losses = []
    for epoch in range(epochs):

        # Ensure train mode (BatchNorm statistics) even if the model was handed in in eval mode.
        seg.train()

        batch_losses = []
        batch_f1 = []
        batch_ji = []

        for data,label,gt in trainloader:

            data,label = data.to(device),label.to(device)

            optimizer.zero_grad()

            out = seg(data) 
            loss = criterion(out,label)
            loss.backward()
            optimizer.step()

            segment_maps = get_segments(out,param = param,thresh1 = thresh1,thresh2 = thresh2)

            batch_f1.append(get_f1(gt,segment_maps))
            batch_ji.append(get_ji(gt,segment_maps))

            batch_losses.append(loss.item())

        epoch_losses.append(np.mean(batch_losses))

        print("Epoch: ",epoch,"\tEpoch Loss: ",epoch_losses[-1],"Mean f1: ",np.mean(batch_f1))
        print("Mean ji: ",np.mean(batch_ji),"HM: ",stats.harmonic_mean([np.mean(batch_f1),np.mean(batch_ji)]))

        seg.eval()
        with tor.no_grad():

            batch_losses = []
            batch_f1 = []
            batch_ji = []

            for valdata,vallabel,valgt in valloader:

                valdata,vallabel = valdata.to(device),vallabel.to(device)

                valout = seg(valdata)
                loss = criterion(valout,vallabel)

                segment_maps = get_segments(valout,param = param,thresh1 = thresh1,thresh2 = thresh2)
                batch_losses.append(loss.item())
                batch_f1.append(get_f1(valgt,segment_maps))
                batch_ji.append(get_ji(valgt,segment_maps))

            print("Val Loss: ",np.mean(batch_losses),"Mean f1: ",np.mean(batch_f1))
            print("Mean ji: ",np.mean(batch_ji),"HM: ",stats.harmonic_mean([np.mean(batch_f1),np.mean(batch_ji)]))

        print("-------------------------------------------------------------------------")
        seg.train()

    return epoch_losses

In [None]:
## setting up training

# Seed the RNGs the notebook draws from so runs are repeatable:
# - python `random`: flips / rotation in dataset.transforms
# - numpy: presumably used by albumentations (confirm for the installed version)
# - torch: weight initialisation and DataLoader shuffling
# GPU kernels may still be nondeterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
tor.manual_seed(SEED)

device = tor.device("cuda:0" if tor.cuda.is_available() else "cpu")
print("using: ",device)

nfeat = 64
seg = dist_dice(in_channels=3,nfeat = nfeat).to(device)

In [None]:
## Training Loop

epochs = 20
nfeat = 64
lr = 1e-3
reg = 1e-3
# `seg` was freshly initialised in the previous cell, so resetting it again is redundant.
# Train the global `seg` in place so that the weights saved afterwards are the trained ones.
# Note: re-running this cell continues training the same model; re-run the set-up cell to start over.
reset = False
dataloaders = [trainloader,valloader]
postproc_params = [7,0.5,28]
hyper_params = [lr,reg,nfeat,postproc_params]

_ = train(seg,epochs,dataloaders,hyper_params = hyper_params,reset = reset)

In [None]:
## Save the model after training
# Only the weights (state_dict) are stored; rebuild dist_dice with the same nfeat before loading.
MODEL_PATH = "model.pt"
model = seg.state_dict()
tor.save(model, MODEL_PATH)