# Packages

In [142]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import tifffile as tiff
import os
from tqdm import tqdm
import pandas as pd

import cv2

# Utility Functions

In [10]:
def mask_rle_to_2d(self):
        """
        converts mask from run length encoding to 2D numpy array
        """
        dx = self.dx
        dy = self.dy

    
        mask = np.zeros(dx*dy, dtype=np.uint8)
        s = self.mask_rle.split()  # split the rle encoding
        for i in range(len(s)//2):
            start = int(s[2*i])-1
            length = int(s[2*i+1])
            mask[start:start+length] = 1
        self.mask_2d = mask.reshape(dy, dx).T
        
        self.mask_2d = utils.mask_rle_to_2d(self.mask_rle, dx, dy)
        
    
def mask_2d_to_rle(self, mask_2d):
        """
        Takes a 2D mask of 0/1 and returns the run length encoded form
        """

        mask = mask_2d.T.reshape(-1)  # order by columns and flatten to 1D
        mask_padded = np.pad(mask, 1)  # pad zero on both sides
        #find the start positions of the 1's
        starts = np.where((mask_padded[:-1] == 0) & (mask_padded[1:] == 1))[0]
        #find the end positions of 1's for each run
        ends = np.where((mask_padded[:-1] == 1) & (mask_padded[1:] == 0))[0]

        rle = np.zeros(2*len(starts))
        
        rle[::2] = starts
        #length of each run = end position - start position
        rle[1::2] = ends - starts
        rle = rle.astype(int)
        return rle

In [11]:
def get_padsize(img, reduce, sz):

    shape = img.shape
    print(shape)

    pad0 = (reduce*sz - shape[0] % (reduce*sz)) % (reduce*sz)
    pad1 = (reduce*sz - shape[1] % (reduce*sz)) % (reduce*sz)
    pad_x = (pad0//2, pad0-pad0//2)
    pad_y = (pad1//2, pad1-pad1//2)

    return pad_x, pad_y


def check_threshold(img_BGR, sat_threshold, pixcount_th):

    """
    checks if an input image passes the threshold conditions:
    conditions:
    not black--> sum of pixels exceeed a threshold = pixcount_th
    saturation --> number of pixels with saturation > sat_threshold exceeds pixcount_th
    Returns:
    True if both conditions are met else False
    """
    #if most of the pixels are black, return False
    #edge of each image is typically black
    if img_BGR.sum() < pixcount_th:
        return False

    #convert to hue, saturation, Value in openCV
    hsv = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(hsv)
    # if less than prefined number of values are above a saturation threshold, return False
    #this is typically the gray background around the biological object
    if (s > sat_threshold).sum() < pixcount_th:
        return False

    return True

In [132]:
class Image():
    def __init__(self, img, img_name =None):
        self.img = img
        self.shape = img.shape
        self.name = img_name

        self.image_reshape()
        self.dx = self.shape[0]
        self.dy = self.shape[1]

        self.tile_size = None

        self.pad_x = None
        self.pad_y = None
        self.tiled_img = None

        self.mask_rle = None
        self.mask_2d = None
        self.tiled_mask = None
        
    
    def image_reshape(self):
        
        if len(self.shape) == 5:
            self.img = np.transpose(self.img.squeeze(), (1, 2, 0))
            self.shape = self.img.shape
            
    
    def split_image_mask_into_tiles(self, reduce=1, sz=512):
     
        self.tile_size = sz

        self.pad_x, self.pad_y = get_padsize(self.img, reduce, sz)
        print(self.pad_x, self.pad_y)
        #Create padded Image and padded mask2D
        img_padded  = np.pad(self.img, [self.pad_x, self.pad_y, (0, 0)], constant_values=0)
        mask_padded = np.pad(self.mask_2d, [self.pad_x, self.pad_y], constant_values = 0)

        print("shape of image after padding:: ", img_padded.shape,
            img_padded.shape[0]//sz, img_padded.shape[1]//sz)

        print("shape of mask after padding:: ", mask_padded.shape,
              mask_padded.shape[0]//sz, mask_padded.shape[1]//sz)

        #tile the padded image
        img_reshaped = img_padded.reshape(
            img_padded.shape[0]//sz, sz, img_padded.shape[1]//sz, sz, 3)
        img_reshaped = img_reshaped.transpose(0, 2, 1, 3, 4).reshape(-1, sz, sz, 3)

        #tile the padded mask2D
        mask_reshaped = mask_padded.reshape(
            mask_padded.shape[0]//sz, sz, mask_padded.shape[1]//sz, sz)
        mask_reshaped = mask_reshaped.transpose(
            0, 2, 1, 3).reshape(-1, sz, sz)

        self.tiled_img = img_reshaped
        self.tiled_mask = mask_reshaped
            
            
    def save_thresholded_image(self, tiled_threshold_img_dir, mask_tile_dict, sat_threshold=40, pixcount_th=200):
        """
        instead of saving, check thresholding of image
        if it passes threshold then do an inference else predict a mask of all zeros
        """
        
        n = self.tiled_img.shape[0] 

        valid_img_count = 0
        valid_idx = []
        print(f"Original tiled image count = {n}")

        for i in range(n):
            img_BGR = self.tiled_img[i, :, :, :]
            if utils.check_threshold(img_BGR, sat_threshold, pixcount_th):
                valid_img_count += 1
                valid_idx.append(i)
                
                #create an id for the image tile
                img_tile_id = f"{self.name}_{str(self.tile_size)}_{str(valid_img_count)}_{str(i)}"
                ###img_name = img_tile_id+'.png'  # name of the saved image tile

                mask_for_tile = self.tiled_mask[i, :, :]  # get the mask for the tile
                #convert the mask for the tile to rle
                mask_rle = self.mask_2d_to_rle(mask_for_tile)
                #save the rle mask to a dict, key = name of the corresponding image tile
                mask_tile_dict[img_tile_id] = mask_rle
            else:
                mask_for_tile = np.zeros(mask_for_tile)
                
                
                ###if valid_img_count == 1001:
                ###cv2.imwrite(os.path.join(tiled_threshold_img_dir, img_name), img_BGR)

        print(f"Image count after thresholding = {valid_img_count}")

# Model

In [124]:
class BaseModel(nn.Module):
    
    def __init__(self):
        super(BaseModel, self).__init__()

    def forward(self):
        raise NotImplementedError

    def __str__(self):
        model_params = filter(lambda x: x.requires_grad, self.parameters())

        return super(BaseModel, self).__str__()
    
    
class Conv2x(nn.Module):
    '''
    preserves the the size of the image
    '''
    def __init__(self, in_ch, out_ch, inner_ch=None):
        super(Conv2x, self).__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.inner_ch = out_ch//2 if inner_ch is None else inner_ch

        self.conv2d_1 = nn.Conv2d(self.in_ch, self.inner_ch,
                                  kernel_size=3, padding=1, bias=False)
        self.conv2d_2 = nn.Conv2d(self.inner_ch, self.out_ch,
                                  kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inner_ch)
        self.bn2 = nn.BatchNorm2d(self.out_ch)

    def forward(self, x):
        x = self.conv2d_1(x)
        x = self.bn1(x)
        x = F.relu(x)

        x = self.conv2d_2(x)
        x = self.bn2(x)
        x = F.relu(x)

        return x
    
    
class encoder(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(encoder, self).__init__()
        self.conv2x = Conv2x(in_ch, out_ch)
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        x = self.conv2x(x)
        x = self.pool(x)
        return x
    
    
class decoder(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(decoder, self).__init__()
        self.transposeconv = nn.ConvTranspose2d(
            in_ch, in_ch//2, kernel_size=2, stride=2)
        self.conv2x = Conv2x(in_ch, out_ch)


    def forward(self, x_down, x_up, interpolate=True):

        x_up = self.transposeconv(x_up)

        #check for matching dims before concatenating

        if (x_up.size(2) != x_up.size(2)) or (x_up.size(3) != x_up.size(3)):
            if interpolate:
                x_up = F.interpolate(x_up, size=(x_down.size(2), x_down.size(3)),
                mode="bilinear", align_corners=True)
        
        #Concat features from down conv channel and current up-conv
        #along channel dim =1
        x_up = torch.cat([x_up, x_down], dim=1) 
        x_up = self.conv2x(x_up)

        return x_up
    


class UNet(BaseModel):

    def __init__(self, in_ch=3, conv_channels=[16, 32, 64, 128, 256]):
        super(UNet, self).__init__()

        self.conv_channels = conv_channels
        self.conv_start = Conv2x(in_ch, conv_channels[0]) #output_size = input_size
        self.down1 = encoder(conv_channels[0], conv_channels[1])   #output_size = input_size/2
        self.down2 = encoder(conv_channels[1], conv_channels[2])   #output_size = input_size/2
        self.down3 = encoder(conv_channels[2], conv_channels[3])   #output_size = input_size/2
        self.down4 = encoder(conv_channels[3], conv_channels[4])   #output_size = input_size/2

        self.conv_middle = Conv2x(conv_channels[4], conv_channels[4]) #output_size = input_size

        self.up4 = decoder(conv_channels[4], conv_channels[3]) #output_size = input_size * 2
        self.up3 = decoder(conv_channels[3], conv_channels[2]) #output_size = input_size * 2
        self.up2 = decoder(conv_channels[2], conv_channels[1]) #output_size = input_size * 2
        self.up1 = decoder(conv_channels[1], conv_channels[0]) #output_size = input_size * 2

        self.final_conv = nn.Conv2d(self.conv_channels[0], 1, kernel_size=1)

        self.init_params()
    
    
    def init_params(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()

            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                
    
    
    def forward(self, x):
        # size of x = [B, _, nx, ny]
        
        x1 = self.conv_start(x)  # size of x = [B, self.conv_channels[0], nx, ny]
        x2 = self.down1(x1)  # size of x = [B, self.conv_channels[1], nx/2, ny/2]
        x3 = self.down2(x2)  # size of x = [B, self.conv_channels[2], nx/4, ny/4]
        x4 = self.down3(x3)  # size of x = [B, self.conv_channels[3], nx/8, ny/8]
        x5 = self.down4(x4)  # size of x = [B, self.conv_channels[4], nx/16, ny/16]

        x = self.conv_middle(x5)  # size of x = [B, self.conv_channels[4], nx/16, ny/16]

        x = self.up4(x4, x)       # size of x = [B, self.conv_channels[3], nx/8, ny/8]
        x = self.up3(x3, x)       # size of x = [B, self.conv_channels[2], nx/4, ny/4]
        x = self.up2(x2, x)       # size of x = [B, self.conv_channels[1], nx/2, ny/2]
        x = self.up1(x1, x)       # size of x = [B, self.conv_channels[0], nx, ny]

        x = self.final_conv(x)

        return x

# Metric 

In [15]:
def metric_dice_iou(output, target, smooth = 0.005):
        tp = (output * target).sum(axis=(1,2)) #intersection
        fp = (output * (1.0 - target)).sum(axis=(1,2)) #false positives
        fn = ((1.0 - output) * target).sum(axis=(1,2)) #false negatives
        dice = np.mean((2.0 * tp + smooth) / (2 * tp + fp + fn + smooth))
        iou = np.mean((tp + smooth) / (tp + fp + fn + smooth))

        return dice, iou

In [54]:
# Trainer, checkpoint configs

class BaseTrainer:

    def __init__(self, model, loss, config, train_loader, val_loader):
        self.model = model
        self.loss = loss
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader

        self.start_epoch = 1

        self.device = self.get_device()

        self.model.to(self.device)

        #Training configs
        cfg_train = config['trainer']
        self.epochs = cfg_train['epochs']
        self.save_period = cfg_train['save_period']

        #Checkpoint configs
        cur_dir = os.curdir
        self.checkpoint_dir = os.path.join(cur_dir, cfg_train["save_dir"])
        print(self.checkpoint_dir)
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

# Perform inference

In [140]:
def inference_single_image(img):
    '''
    if self.val_loader is None:
        print(f"No val loader exists")
        return {}
    '''
    model = UNet() 

    masks = '/Users/Ethan/Documents/Documents/Documents - Ethan’s MacBook Pro/python/kidney_files/hubmap-kidney-segmentation/train.csv'
    df_masks = pd.read_csv(masks).set_index('id')
    mask_image = df_masks.loc['aaa6a05cc']

    # save checkpoint
    checkpoint_path = '/Users/Ethan/Documents/Documents/Documents - Ethan’s MacBook Pro/python/kidney_files/hubmap-kidney-segmentation/model_checkpoints.pth'
    store_model_checkpoint = torch.save(model.state_dict(),checkpoint_path)
    
    # load state
    model.load_state_dict(torch.load(checkpoint_path)) 

    #resume checkpoint
    #_resume_checkpoint(checkpoint_path)
    
    model.eval()
    #model._reset_metrics()
    #tbar = tqdm(self.val_loader, ncols=100)

    
    #read an image
    datadir_train = '/Users/Ethan/Documents/Documents/Documents - Ethan’s MacBook Pro/python/kidney_files/hubmap-kidney-segmentation/train'
    img_index = 6
    raw_image_files = [f for f in os.listdir(datadir_train) if "tiff" in f]
    f = raw_image_files[img_index]
    raw_file_name = f.split('.')[0]
    img_raw = tiff.imread(os.path.join(datadir_train, f))


    #instantiate image class
    raw_img = Image(img_raw,img_name=raw_file_name)
    #call Image.split_image_mask_into_tiles(self, reduce=1, sz=512) to create
    raw_img.split_image_mask_into_tiles(reduce=1, sz=512)

    image_1 = raw_img.tiled_img

    with torch.no_grad(): 
        tile_count = image_1.shape[0]

        total_tiles = 0
        tiles_idx = []

        accumulate_predictions = np.empty((tile_count,sz,sz))

        #loop over the tiles
        for i in range(tile_count):
            img_BGR = tiled_img[i, :, :, :]
            if utils.check_threshold(img_BGR, sat_threshold, pixcount_th):
                total_tiles += 1
                tiles_idx.append(i)
                #for i, sample_batch in enumerate(tbar):
                            #for each tile
                            #check thresolhold
                            #if threshold passes, then do inference
                            ###img_tile_id = raw_img_name+'_'+str(sz)+'_'+str(total_tiles)
                img = sample_batch['image']
                #mask = sample_batch['mask'].float()

                batch_size = img.shape[0]
                img = img.to(self.device)
                #mask = mask.to(self.device)

                out = torch.squeeze(self.model(img), 1)
                        #out = 
                        # convert to 0 or 1
            else:
                out = np.zeros(img)

        #accumulate the predictions into a big np array
        accumulate_predictions[i] = inference_single_image(image)
        #prediction size = [total_tiles, sz, sz]

        reconstructed_image = raw_img.reconstruct_original_from_padded_tiled_image(tiled_img)
        # call above to get the predicted mask as the same size as the original image / mask

        #calculate metrics
        metric_dice_iou(tiled_img, raw_img)
        #calculate 2d_to_rle
        mask_pred_2d_to_rle = mask_2d_to_rle(reconstructed_image)

        return mask_pred_2d_to_rle

#call this function for all the prediction images
inference = inference_single_image(image_1)
inference

(18484, 13013, 3)
(230, 230) (149, 150)


TypeError: pad() missing 1 required positional argument: 'mode'

In [129]:
def _save_checkpoints( epoch):
        #create a state dict for saving

        model_state = {
            'epoch':epoch,
            'state_dict':model.state_dict(),
            'optimizer':optimizer.state_dict(),
            'config':config
        }
        savetime = datetime.datetime.now().strftime('%m_%d_%H_%M')
        filename = f"{self.config['name']}_{epoch}_{savetime}"
        filename = os.path.join(checkpoint_dir, f'{filename}.pth')
        torch.save(model_state, filename)

def _resume_checkpoint(self):
    pass

def _train_epoch(self):
        #implement this in Trainer (sub class of BaseTrainer) 
    raise NotImplementedError

def _val_epoch(self, epoch):
    raise NotImplementedError

In [114]:
def _resume_checkpoint(checkpoint_path):
        checkpoint = torch.load(checkpoint_path) #need this line

        last_epoch = checkpoint['epoch']
        model_name = checkpoint['config']['name']

        if model_name == config['name']: # self was here
            load_model = model.load_state_dict(checkpoint['state_dict']) #need this line
            start_epoch = last_epoch + 1
            #self removed from top two lines
            return True
        else:
            print("current model name doesn't match with previously saved model name !!")
            print("Current model name: {} , previous model name : {}".format(config['name'], model_name)) #originally self.config 
            return False

In [137]:
# read corresponding mask to compare to reconstructed  img
def comparison(mask,image_reconstructed):
    raw_img.mask_rle = df_masks.loc[raw_file_name, 'encoding']
    actual_mask_rle = raw_img.mask_rle
    if mask_pred_2d_to_rle == actual_mask_rle:
        return True
    else:
        print("incorrect encoding")
    
comparison(mask_image,reconstructed_image)

NameError: name 'reconstructed_image' is not defined

In [42]:
def reconstruct_original_from_padded_tiled_image(self, tiled_image):
        n = tiled_image.shape[0]
        tile_size = self.tile_size
        (pad_x_l, pad_x_r) = self.pad_x
        (pad_y_l, pad_y_r) = self.pad_y

        dx_padded = self.dx + pad_x_l + pad_x_r
        dy_padded = self.dy + pad_y_l + pad_y_r

        n_x = dx_padded //tile_size
        n_y = dy_padded//tile_size

        assert (n == n_x*n_y), "dimensions don't match"

        image_untiled = tiled_image.reshape(n_x, n_y, tile_size, tile_size, 3)
        image_untiled = image_untiled.transpose(0,2,1,3,4)
        image_padded = image_untiled.reshape(n_x*tile_size, n_y*tile_size, 3)

        image_unpadded = image_padded[pad_x_l: - pad_x_r, pad_y_l: -pad_y_r, :]

        assert (self.dx == image_unpadded.shape[0]), \
            "shape of original image doesn't match with unpadded image along dim = 0"
        assert (self.dy == image_unpadded.shape[1]), \
            "shape of original image doesn't match with unpadded image along dim = 1"

        return image_unpadded

In [143]:
# dont really need this cell

from skimage import io, transform
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import pickle

import matplotlib.pyplot as plt


class Dataset_Image_mask(Dataset):

    def __init__(self, data_dir, mean, std_dev, transform=None):
        super(Dataset_Image_mask, self).__init__()
        self.root_dir = data_dir
        self.transform = transform
        self.mask_dict = self.get_mask_dict()
        self.img_name_list = list(self.mask_dict.keys())
        self.len = self.__len__()
        self.normalize = transforms.Normalize(mean, std_dev)

        

    def get_mask_dict(self):
        '''
        open the pickled file containing the dict of mask in rle format
        returns
        dict: key same as imgae file name
        value: numpy array in rle format
        '''
        img_and_mask_files = os.listdir(self.root_dir)
        mask_rle_file = [x for x in img_and_mask_files if "mask" in x][0]
        mask_rle_dict = pickle.load(
            open(os.path.join(self.root_dir, mask_rle_file), 'rb'))
        return mask_rle_dict

   
    def __len__(self):
        return len(self.img_name_list)

    
    def __getitem__(self, idx):
        img_name = self.img_name_list[idx]
        img_file_name = os.path.join(self.root_dir, f"{img_name}.png")
        img = io.imread(img_file_name)
        #img = torch.tensor(img.transpose((2,0,1))).float()
        #img = img.transpose(2, 0, 1)
        
        mask_rle_np = self.mask_dict[img_name]
        #print("mask rle shape :: ", mask_rle_np.shape)
        #print(f"**\n{mask_rle_np}")
        mask_rle = " ".join(str(x) for x in mask_rle_np)
        
        mask_2d = utils.mask_rle_to_2d(mask_rle, 512, 512)
        #augment image for training
        if self.transform is not None:
            img = self.transform(img)
            mask_2d = self.transform(mask_2d)
        
        img = self.normalize(transforms.ToTensor()(img))
        mask = torch.from_numpy(mask_2d).long()

        sample = {'image': img, 'mask': mask, 'idx': img_name}
        #sample = {'image': img}
        return sample




In [94]:
# dont really need this cell

#from base.base_trainer import BaseTrainer
#from utils import dataloader, loss, metrics
#from models import unet

class Trainer(BaseTrainer):

    def __init__(self, model, loss_fn, config, train_loader, val_loader=None):
        super(Trainer, self).__init__(
            model, loss_fn, config, train_loader, val_loader)

        self.optimizer = loss.use_optimizer(model, config)
        self.bce_dice_ratio = config['loss']["bce_dice_ratio"]

    def _train_epoch(self, epoch):
        '''
        train the model for one epoch
        '''
        self.model.train()
        self._reset_metrics()
        tbar = tqdm(self.train_loader, ncols = 100, miniters=50)

        for i, sample_batch in enumerate(tbar):
            self.optimizer.zero_grad()

            img = sample_batch['image']
            mask = sample_batch['mask'].float()

            batch_size = img.shape[0]

            img = img.to(self.device)
            mask = mask.to(self.device)
            #img = img.transpose((0,3,1,2))
            out = torch.squeeze(self.model(img), 1)
            
            train_loss = self.loss(out, mask, self.bce_dice_ratio)
            self.total_loss.update(train_loss.item(), batch_size)

            train_loss.backward() #perform backprop
            self.optimizer.step() #update parameters

            tbar.set_description( f"Train: Epoch: {epoch}, Avg Loss: {self.total_loss.avg:.5f}" )
            #if (i % 50 == 0):
            #    print(f"epoch: {epoch}, batch : {i}, train loss: {train_loss.item(): .5f}, train average loss: {self.total_loss.avg: .5f}")
            #    break
        return self.total_loss.avg


    def _val_epoch(self, epoch):
            
            if self.val_loader is None:
                print(f"No val loader exists")
                return {}

            self.model.eval()
            self._reset_metrics()
            tbar = tqdm(self.val_loader, ncols=100)
            with torch.no_grad():
                for i, sample_batch in enumerate(tbar):
                    img = sample_batch['image']
                    mask = sample_batch['mask'].float()

                    batch_size = img.shape[0]
                    img = img.to(self.device)
                    mask = mask.to(self.device)

                    out = torch.squeeze(self.model(img), 1)
                    val_loss = self.loss(out, mask, self.bce_dice_ratio)
                    self.total_loss.update(val_loss.item(), batch_size)
                    #if (i%10 == 0):
                    #    print(f"epoch: {epoch}, batch : {i}, val loss: {val_loss.item()}, val average loss: {self.total_loss.avg}")
                    tbar.set_description(f"Val: Epoch: {epoch}, Avg Loss: {self.total_loss.avg:.5f}")

            return self.total_loss.avg

   
    
    def _reset_metrics(self):

        self.total_loss = metrics.AverageMeter()
        
'''        
    def main():
        root_dir = 'C:\Scripts\hubmap\code'

        data_dir = 'C:\Scripts\hubmap\\train\\tiled_thresholded_512'

        mean = [0.68912, 0.47454, 0.6486]
        std_dev = [0.13275, 0.23647, 0.15536]

        #full dataset with training images and masks
        dataset = dataloader.Dataset_Image_mask(data_dir, mean, std_dev)

        n_tot = dataset.len

        #SplitS full dataset into train set and test set
        train_test_split = 0.8
        train_count = int(train_test_split * n_tot)

        test_count = dataset.len - train_count

        train_idx = list(np.random.choice(
            range(n_tot), train_count, replace=False))
        test_idx = list(set(range(n_tot)) - set(train_idx))

        print(len(train_idx), len(test_idx), n_tot - len(train_idx) - len(test_idx))

        train_ds = torch.utils.data.Subset(dataset, train_idx)
        test_ds = torch.utils.data.Subset(dataset, test_idx)

        model = unet.UNet()

        config = json.load(open('config.json'))
        b_size = config["train_loader"]["args"]["batch_size"]
        train_loader = DataLoader(
            train_ds, batch_size=b_size, shuffle=True, num_workers=0)
        b_size = config["val_loader"]["args"]["batch_size"]
        val_loader = DataLoader(
            test_ds, batch_size=b_size, shuffle=True, num_workers=0)

        trainer = Trainer(model, loss.loss_fn, config, train_loader, val_loader)
        print(f"Trainining on device: {trainer.device}")

        trainer.train()

if __name__ == "__main__":
    main()
'''

'        \n    def main():\n        root_dir = \'C:\\Scripts\\hubmap\\code\'\n\n        data_dir = \'C:\\Scripts\\hubmap\\train\\tiled_thresholded_512\'\n\n        mean = [0.68912, 0.47454, 0.6486]\n        std_dev = [0.13275, 0.23647, 0.15536]\n\n        #full dataset with training images and masks\n        dataset = dataloader.Dataset_Image_mask(data_dir, mean, std_dev)\n\n        n_tot = dataset.len\n\n        #SplitS full dataset into train set and test set\n        train_test_split = 0.8\n        train_count = int(train_test_split * n_tot)\n\n        test_count = dataset.len - train_count\n\n        train_idx = list(np.random.choice(\n            range(n_tot), train_count, replace=False))\n        test_idx = list(set(range(n_tot)) - set(train_idx))\n\n        print(len(train_idx), len(test_idx), n_tot - len(train_idx) - len(test_idx))\n\n        train_ds = torch.utils.data.Subset(dataset, train_idx)\n        test_ds = torch.utils.data.Subset(dataset, test_idx)\n\n        model = u