In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler


# Base DataLoader

In [None]:
class BaseDataLoader(DataLoader):
    """DataLoader over a dataset, optionally holding out a validation subset.

    Args:
        dataset: map-style dataset supporting ``len``.
        batch_size: samples per batch.
        shuffle: shuffle the training data. When a validation split is made
            this is forced to False, because the SubsetRandomSampler already
            randomises order and DataLoader forbids ``shuffle`` with ``sampler``.
        num_workers: number of loading worker processes.
        split_ratio: 0 for no split; an ``int`` is an absolute number of
            validation samples; otherwise the fraction of the dataset held out.
    """

    def __init__(self, dataset, batch_size, shuffle, num_workers, split_ratio):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.N_tot = len(dataset)
        self.train_samples = len(dataset)
        # May reset self.shuffle and self.train_samples when a split is made.
        self.train_sampler, self.val_sampler = self._split_sampler(split_ratio)

        # Shared by the training loader and val_split(); the attribute name
        # (including its spelling) is kept for backward compatibility.
        self.init_kawrgs = {
            'dataset': dataset, 'batch_size': batch_size, 'shuffle': self.shuffle,
            'num_workers': num_workers,
        }
        super().__init__(sampler=self.train_sampler, **self.init_kawrgs)

    def _split_sampler(self, split_ratio):
        """Build (train_sampler, val_sampler); (None, None) when split_ratio == 0."""
        if split_ratio == 0:
            return None, None

        # A private generator seeded with 0 yields the same permutation as
        # np.random.seed(0) + np.random.shuffle, without reseeding the global
        # NumPy RNG for the rest of the program.
        idx_full = np.arange(self.N_tot)
        np.random.RandomState(0).shuffle(idx_full)

        if isinstance(split_ratio, int):
            assert split_ratio > 0, "validation set size must be positive."
            assert split_ratio < self.N_tot, \
                "validation set size is configured to be larger than entire dataset."
            val_len = split_ratio
        else:
            val_len = int(split_ratio * self.N_tot)

        val_idx = idx_full[:val_len]
        train_idx = idx_full[val_len:]

        # A sampler and shuffle=True are mutually exclusive in DataLoader.
        self.shuffle = False
        self.train_samples = len(train_idx)

        return SubsetRandomSampler(train_idx), SubsetRandomSampler(val_idx)

    def val_split(self):
        """Return a DataLoader over the validation subset, or None if no split was made."""
        if self.val_sampler is None:
            return None
        return DataLoader(sampler=self.val_sampler, **self.init_kawrgs)

# Image Preprocessing

In [None]:
import torchvision.transforms as transforms


def mask_rle_to_2d(rle_mask, dx, dy):
    """
    Decode a run-length-encoded mask string into a 2-D uint8 array.

    rle_mask is a whitespace-separated "start length start length ..."
    string with 1-based start positions; pixels are counted down the first
    axis of the output. A trailing unpaired token is ignored.

    Returns:
    uint8 array of shape (dx, dy): 1 inside the runs, 0 elsewhere
    """
    flat = np.zeros(dx * dy, dtype=np.uint8)
    tokens = rle_mask.split()
    for start_tok, length_tok in zip(tokens[0::2], tokens[1::2]):
        first = int(start_tok) - 1  # 1-based -> 0-based
        flat[first:first + int(length_tok)] = 1
    # Filling (dy, dx) row-wise and transposing yields column-major order.
    return flat.reshape(dy, dx).T


def mask_2d_to_rle(mask_2d):
    """
    Takes a 2D mask and returns its run-length encoding.

    Pixels are numbered column by column (the mask is transposed before
    flattening) starting at 1, which is the convention mask_rle_to_2d in this
    file decodes (it subtracts 1 from every start). Any non-zero value counts
    as foreground.

    Returns:
    1-D int array [start_1, length_1, start_2, length_2, ...]; empty if the
    mask has no foreground
    """
    flat = (np.asarray(mask_2d).T.reshape(-1) != 0).astype(np.int8)
    # Zero padding guarantees every run has both a rising and a falling edge.
    padded = np.pad(flat, 1)
    # Positions where the value changes; +1 turns them into 1-based pixel numbers
    # (rising edge -> first pixel of a run, falling edge -> one past its end).
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    starts = boundaries[0::2]
    ends = boundaries[1::2]

    rle = np.empty(2 * len(starts), dtype=int)
    rle[0::2] = starts
    rle[1::2] = ends - starts
    return rle


def get_padsize(img, reduce, sz):
    """
    Compute the zero padding that makes the first two axes of img multiples of reduce*sz.

    Prints the input shape as a debug trace.

    Returns:
    (pad_x, pad_y): (before, after) pairs for axis 0 and axis 1; when the
    total is odd the extra pixel goes after
    """
    shape = img.shape
    print(shape)

    tile_extent = reduce * sz

    def _before_after(length):
        # Amount needed to reach the next multiple (0 when already aligned).
        total = (tile_extent - length % tile_extent) % tile_extent
        half = total // 2
        return (half, total - half)

    return _before_after(shape[0]), _before_after(shape[1])


def check_threshold(img_BGR, sat_threshold, pixcount_th):
    """
    Decide whether a tile carries enough content to keep.

    A tile is rejected when either
    * the sum of all its pixel values is below pixcount_th (nearly black,
      as at the padded slide edge), or
    * fewer than pixcount_th pixels have HSV saturation above sat_threshold
      (mostly grey background around the tissue).

    Returns:
    True if the tile passes both checks, otherwise False
    """
    if img_BGR.sum() < pixcount_th:
        return False

    # Channel 1 of OpenCV's HSV image is saturation.
    saturation = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2HSV)[:, :, 1]
    return not ((saturation > sat_threshold).sum() < pixcount_th)


def image_reshape(img):
    '''
    Normalise an image to channels-last layout [x, y, c].

    Some raw reads come back as [1, 1, 3, x, y]; those are squeezed and
    transposed to [x, y, 3]. Any other rank is returned untouched.

    Returns:
    (shape, image) where shape is the shape of the returned image
    '''
    if len(img.shape) != 5:
        return img.shape, img

    # [1, 1, c, x, y] -> [c, x, y] -> [x, y, c]
    channels_last = np.transpose(img.squeeze(), (1, 2, 0))
    return channels_last.shape, channels_last


def split_image_into_tiles(img, mask, reduce=4, sz=256):
    """
    Zero-pad an image and its 2-D mask, then cut both into sz x sz tiles.

    The image is first normalised with image_reshape (5-D reads become
    [dx, dy, c]). Both axes are padded symmetrically (odd extra pixel after)
    up to the next multiple of reduce*sz. That is also a multiple of sz, so
    the tiling below is exact.

    Args:
    img: image of shape [dx, dy, c] (or a 5-D array accepted by image_reshape)
    mask: 2-D mask of shape [dx, dy]
    reduce: padding granularity is reduce*sz
    sz: tile side length

    Returns:
    (tiles, mask_tiles) with shapes [n, sz, sz, c] and [n, sz, sz]; tiles are
    ordered row by row across the tile grid
    """
    shape, img = image_reshape(img)

    block = reduce * sz
    pad0 = (block - shape[0] % block) % block
    pad1 = (block - shape[1] % block) % block
    pad_x = (pad0 // 2, pad0 - pad0 // 2)
    pad_y = (pad1 // 2, pad1 - pad1 // 2)

    img_padded = np.pad(img, [pad_x, pad_y, (0, 0)], constant_values=0)
    print("shape of image after padding:: ",img_padded.shape, img_padded.shape[0]//sz, img_padded.shape[1]//sz)

    mask_padded = np.pad(mask, [pad_x, pad_y], constant_values=0)
    print("shape of mask padded ", mask_padded.shape)

    n_rows = img_padded.shape[0] // sz
    n_cols = img_padded.shape[1] // sz
    # Previously hard-coded to 3; taking it from the data also supports
    # single-channel or RGBA inputs.
    n_channels = img_padded.shape[2]

    # (rows, sz, cols, sz, c) -> (rows, cols, sz, sz, c) -> (n, sz, sz, c)
    img_reshaped = (img_padded.reshape(n_rows, sz, n_cols, sz, n_channels)
                    .transpose(0, 2, 1, 3, 4)
                    .reshape(-1, sz, sz, n_channels))

    mask_tiled = (mask_padded.reshape(n_rows, sz, n_cols, sz)
                  .transpose(0, 2, 1, 3)
                  .reshape(-1, sz, sz))

    print(f"shape final tile: {img_reshaped.shape}, shape final mask: {mask_tiled.shape}, number of tiles and mask = {img_reshaped.shape[0]}, {mask_tiled.shape[0]}")

    return img_reshaped, mask_tiled


def save_thresholded_image(img, mask, raw_img_name, sz, output_dir, mask_tile_dict, sat_threshold=40, pixcount_th=200):
    """
    Write every tile that passes check_threshold as a PNG and record its RLE mask.

    Args:
    img: tile stack indexed as img[i, :, :, :]
    mask: mask-tile stack indexed as mask[i, :, :]
    raw_img_name: prefix for the tile ids
    sz: tile size, used only in the tile id
    output_dir: directory the PNGs are written to
    mask_tile_dict: updated in place, tile id -> mask_2d_to_rle(tile mask)
    sat_threshold, pixcount_th: forwarded to check_threshold (previously
        overwritten with 40/200 inside the function, so caller values were ignored)

    Returns:
    list of indices into img of the tiles that were kept
    """
    n = img.shape[0]
    valid_idx = []
    print(f"Original tiled image count = {n}")

    for i in range(n):
        img_BGR = img[i, :, :, :]
        if not check_threshold(img_BGR, sat_threshold, pixcount_th):
            continue

        valid_idx.append(i)
        # Ids use the 1-based running count of kept tiles.
        img_tile_id = raw_img_name + '_' + str(sz) + '_' + str(len(valid_idx))
        img_name = img_tile_id + '.png'

        mask_tile_dict[img_tile_id] = mask_2d_to_rle(mask[i, :, :])
        cv2.imwrite(os.path.join(output_dir, img_name), img_BGR)

    print(f"Image count after thresholding = {len(valid_idx)}")
    return valid_idx
    
    
def image_transform(mode='train'):
    """
    Return the augmentation pipeline for the given mode.

    'train' -> Compose of a random vertical flip and a random horizontal flip;
    any other mode -> None (no augmentation).
    """
    flips = [transforms.RandomVerticalFlip(), transforms.RandomHorizontalFlip()]
    return transforms.Compose(flips) if mode == 'train' else None


def main():
    """
    Debug driver: tile every raw training TIFF together with its 2-D mask.

    Reads the RLE masks from train.csv, decodes each slide's mask, and splits
    slide and mask into image_size x image_size tiles. Saving the thresholded
    tiles is currently disabled (see the commented call in the loop).
    """
    # This cell calls main() as soon as it runs, before the later cell that
    # imports os / pandas / tifffile at module level, so import them here.
    import os
    import pandas as pd
    import tifffile as tiff

    datadir_train = '/media/bony/Ganga_HDD_3TB/Ganges_Backup/Machine_Learning/HuBMAP_Hacking_Kidney/hubmap-kidney-segmentation/train'
    os.chdir(datadir_train)
    masks_train = '/media/bony/Ganga_HDD_3TB/Ganges_Backup/Machine_Learning/HuBMAP_Hacking_Kidney/hubmap-kidney-segmentation/train.csv'
    image_size = 512
    output_dir = os.path.join(datadir_train, 'tiled_'+str(image_size))
    print(output_dir)
    raw_image_files = [f for f in os.listdir(datadir_train) if "tiff" in f]

    df_train_masks = pd.read_csv(masks_train).set_index('id')
    img_train_list = list(df_train_masks.index)
    print(img_train_list)
    mask_tile_dict = {}

    # Split each raw slide into tiles; thresholding/saving happens in
    # save_thresholded_image (disabled below).
    for f in raw_image_files:
        raw_file_name = f.split('.')[0]
        print(raw_file_name)

        mask_rle = df_train_masks.loc[raw_file_name, 'encoding']

        img_raw = tiff.imread(os.path.join(datadir_train, f))
        [dx, dy, c], img_raw = image_reshape(img_raw)

        mask_2d = mask_rle_to_2d(mask_rle, dx, dy)
        print("shape of unpadded 2D mask ", mask_2d.shape)
        tiled_img, tiled_mask = split_image_into_tiles(img_raw, mask_2d, reduce=1, sz=image_size)

        print(tiled_img.shape)
        # Enable to write the kept tiles and collect their RLE masks:
        # save_thresholded_image(tiled_img, tiled_mask, raw_file_name, image_size, output_dir, mask_tile_dict)


if __name__ == "__main__":
    main()

In [None]:
import pandas as pd 
import os, sys
import tifffile as tiff
import zipfile
import json
import pickle
import preprocessing_utils as utils
import cv2


class Image():
    """A raw slide image plus its segmentation mask, with tiling helpers.

    Attributes (set as the pipeline progresses):
    img / shape: image data normalised to (dx, dy, c) and its shape
    dx, dy: size of the first two image axes
    mask_rle: RLE mask string (1-based, column-major)
    mask_2d: decoded (dx, dy) uint8 mask
    tile_size, pad_x, pad_y, tiled_img, tiled_mask: results of tiling
    """

    def __init__(self, img, img_name=None):
        self.img = img
        self.shape = img.shape
        self.name = img_name

        # Normalise 5-D (1, 1, c, x, y) reads before reading dx / dy.
        self.image_reshape()
        self.dx = self.shape[0]
        self.dy = self.shape[1]

        self.tile_size = None

        self.pad_x = None
        self.pad_y = None
        self.tiled_img = None

        self.mask_rle = None
        self.mask_2d = None
        self.tiled_mask = None

    def image_reshape(self):
        """Convert a 5-D (1, 1, c, x, y) image to (x, y, c) in place; other ranks are left as-is."""
        if len(self.shape) == 5:
            self.img = np.transpose(self.img.squeeze(), (1, 2, 0))
            self.shape = self.img.shape

    def split_image_mask_into_tiles(self, reduce=1, sz=512):
        """Zero-pad image and mask to multiples of reduce*sz and cut both into sz x sz tiles.

        Sets self.tiled_img to (n, sz, sz, c) and self.tiled_mask to
        (n, sz, sz), tiles ordered row by row across the tile grid.
        Requires self.mask_2d (call mask_rle_to_2d first).
        """
        if self.mask_2d is None:
            raise ValueError("mask_2d is not set; call mask_rle_to_2d() first")

        self.tile_size = sz

        self.pad_x, self.pad_y = utils.get_padsize(self.img, reduce, sz)
        print(self.pad_x, self.pad_y)
        img_padded = np.pad(self.img, [self.pad_x, self.pad_y, (0, 0)], constant_values=0)
        mask_padded = np.pad(self.mask_2d, [self.pad_x, self.pad_y], constant_values=0)

        print("shape of image after padding:: ", img_padded.shape,
            img_padded.shape[0]//sz, img_padded.shape[1]//sz)

        print("shape of mask after padding:: ", mask_padded.shape,
              mask_padded.shape[0]//sz, mask_padded.shape[1]//sz)

        # Channel count comes from the data instead of being hard-coded to 3.
        n_channels = img_padded.shape[2]
        self.tiled_img = (img_padded.reshape(img_padded.shape[0] // sz, sz,
                                             img_padded.shape[1] // sz, sz, n_channels)
                          .transpose(0, 2, 1, 3, 4)
                          .reshape(-1, sz, sz, n_channels))
        self.tiled_mask = (mask_padded.reshape(mask_padded.shape[0] // sz, sz,
                                               mask_padded.shape[1] // sz, sz)
                           .transpose(0, 2, 1, 3)
                           .reshape(-1, sz, sz))

    def save_thresholded_image(self, tiled_threshold_img_dir, mask_tile_dict, sat_threshold=40, pixcount_th=200):
        """Write tiles passing utils.check_threshold as PNGs and store their RLE masks.

        mask_tile_dict is updated in place (tile id -> RLE array).
        Returns the indices of the kept tiles.
        """
        n = self.tiled_img.shape[0]

        valid_img_count = 0
        valid_idx = []
        print(f"Original tiled image count = {n}")

        for i in range(n):
            img_BGR = self.tiled_img[i, :, :, :]
            if not utils.check_threshold(img_BGR, sat_threshold, pixcount_th):
                continue
            valid_img_count += 1
            valid_idx.append(i)

            # The id encodes slide name, tile size, running count and tile index.
            img_tile_id = f"{self.name}_{self.tile_size}_{valid_img_count}_{i}"
            img_name = img_tile_id + '.png'

            mask_tile_dict[img_tile_id] = self.mask_2d_to_rle(self.tiled_mask[i, :, :])
            cv2.imwrite(os.path.join(tiled_threshold_img_dir, img_name), img_BGR)

        print(f"Image count after thresholding = {valid_img_count}")
        return valid_idx

    def mask_rle_to_2d(self):
        """Decode self.mask_rle (1-based start/length pairs) into self.mask_2d of shape (dx, dy)."""
        dx = self.dx
        dy = self.dy

        mask = np.zeros(dx * dy, dtype=np.uint8)
        s = self.mask_rle.split()
        for i in range(len(s) // 2):
            start = int(s[2 * i]) - 1  # 1-based -> 0-based
            length = int(s[2 * i + 1])
            mask[start:start + length] = 1
        # Column-major pixel order: fill (dy, dx) row-wise, then transpose.
        # (The previous version decoded a second time via utils; once suffices.)
        self.mask_2d = mask.reshape(dy, dx).T

    def mask_2d_to_rle(self, mask_2d):
        """Encode a 2-D mask as [start, length, ...] with 1-based, column-major starts.

        Starts were previously 0-based, which mask_rle_to_2d (and the decoder
        used when loading tiles) shift by one pixel and which drops a run
        beginning at the very first pixel. Any non-zero value counts as
        foreground.
        """
        flat = (np.asarray(mask_2d).T.reshape(-1) != 0).astype(np.int8)
        padded = np.pad(flat, 1)  # ensures every run has both edges
        # +1 turns change positions into 1-based pixel numbers.
        boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
        starts = boundaries[0::2]
        ends = boundaries[1::2]

        rle = np.empty(2 * len(starts), dtype=int)
        rle[0::2] = starts
        rle[1::2] = ends - starts
        return rle


def main():
    """Tile every raw training slide, keep informative tiles, and pickle their RLE masks.

    main and its loop were previously indented inside the class (the loop ran
    at class-definition time), and the guard below called a different main.
    """
    # Parameters for tiling and thresholding the raw images.
    image_size_reduced = 512
    sat_threshold = 40
    pixcount_th = 200

    # Input raw images and output directory for the tiles.
    datadir_train = '/media/bony/Ganga_HDD_3TB/Ganges_Backup/Machine_Learning/HuBMAP_Hacking_Kidney/hubmap-kidney-segmentation/train'
    masks_train = '/media/bony/Ganga_HDD_3TB/Ganges_Backup/Machine_Learning/HuBMAP_Hacking_Kidney/hubmap-kidney-segmentation/train.csv'
    tiled_threshold_img_dir = os.path.join(datadir_train, 'tiled_thresholded_'+str(image_size_reduced))
    os.makedirs(tiled_threshold_img_dir, exist_ok=True)

    # RLE masks of the raw images, indexed by image id.
    df_train_masks = pd.read_csv(masks_train).set_index('id')

    os.chdir(datadir_train)
    raw_image_files = [f for f in os.listdir(datadir_train) if "tiff" in f]

    mask_tile_dict = {}
    for f in raw_image_files:
        raw_file_name = f.split('.')[0]
        print(raw_file_name)

        img_raw = tiff.imread(os.path.join(datadir_train, f))
        raw_img = Image(img_raw, img_name=raw_file_name)
        raw_img.mask_rle = df_train_masks.loc[raw_file_name, 'encoding']
        raw_img.mask_rle_to_2d()

        raw_img.split_image_mask_into_tiles(sz=image_size_reduced)
        print(f"shape of tiled image: {raw_img.tiled_img.shape}, of tiled mask 2D : {raw_img.tiled_mask.shape}")

        raw_img.save_thresholded_image(tiled_threshold_img_dir, mask_tile_dict, sat_threshold=sat_threshold, pixcount_th=pixcount_th)

    # Persist tile id -> RLE mask for the dataset loader.
    with open(os.path.join(tiled_threshold_img_dir, 'tiled_mask_rle'), 'wb') as tiled_mask_rle_file:
        pickle.dump(mask_tile_dict, tiled_mask_rle_file)


if __name__ == "__main__":
    main()

# Process Inference Data 

In [None]:
from skimage import io, transform
import torch 

from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import os
import pickle

import matplotlib.pyplot as plt
from utils import preprocessing_utils as utils


class Dataset_Image_mask(Dataset):
    """Dataset of PNG tiles and their masks read from one directory.

    The directory holds the tile PNGs plus one pickled dict (the first file
    whose name contains "mask") mapping tile name -> RLE mask array. Each
    sample is {'image': normalised (C, H, W) float tensor,
    'mask': (H, W) long tensor, 'idx': tile name}.

    transform, if given, is applied to a (C + 1, H, W) tensor holding the
    image channels with the mask stacked as the last channel. One random
    draw therefore moves image and mask together, so it should be a
    geometric transform (e.g. the flips from image_transform()).
    """

    def __init__(self, data_dir, mean, std_dev, transform=None):
        super(Dataset_Image_mask, self).__init__()
        self.root_dir = data_dir
        self.transform = transform
        self.mask_dict = self.get_mask_dict()
        self.img_name_list = list(self.mask_dict.keys())
        self.len = self.__len__()
        self.normalize = transforms.Normalize(mean, std_dev)

    def get_mask_dict(self):
        '''
        Open the pickled file containing the dict of masks in RLE format.

        Returns:
        dict: key = image file name (without extension), value = RLE array
        '''
        img_and_mask_files = os.listdir(self.root_dir)
        mask_rle_file = [x for x in img_and_mask_files if "mask" in x][0]
        # NOTE: pickle can execute arbitrary code; only load files produced
        # by the tiling script.
        with open(os.path.join(self.root_dir, mask_rle_file), 'rb') as fh:
            mask_rle_dict = pickle.load(fh)
        return mask_rle_dict

    def __len__(self):
        return len(self.img_name_list)

    def __getitem__(self, idx):
        img_name = self.img_name_list[idx]
        img_file_name = os.path.join(self.root_dir, f"{img_name}.png")
        img = io.imread(img_file_name)

        # int() guards against RLE arrays stored as floats ("1.0" would not
        # parse in mask_rle_to_2d).
        mask_rle = " ".join(str(int(x)) for x in self.mask_dict[img_name])
        # Decode to the tile's own size rather than assuming 512 x 512.
        mask_2d = utils.mask_rle_to_2d(mask_rle, img.shape[0], img.shape[1])

        img = transforms.ToTensor()(img)
        mask = torch.from_numpy(mask_2d).long()

        if self.transform is not None:
            # Previously image and mask were transformed separately, so random
            # flips were drawn independently and the mask no longer matched the
            # image. Stack them so a single draw applies to both.
            stacked = torch.cat([img, mask.unsqueeze(0).to(img.dtype)], dim=0)
            stacked = self.transform(stacked)
            img = stacked[:-1]
            mask = stacked[-1].round().long()

        img = self.normalize(img)

        sample = {'image': img, 'mask': mask, 'idx': img_name}
        return sample


def main():
    """Smoke-test Dataset_Image_mask: random 80/20 index split, then iterate the held-out loader."""
    data_dir = '/media/bony/Ganga_HDD_3TB/Ganges_Backup/Machine_Learning/HuBMAP_Hacking_Kidney/hubmap-kidney-segmentation/train/tiled_thresholded_512'
    mean = [0.68912, 0.47454, 0.6486]
    std_dev = [0.13275, 0.23647, 0.15536]

    dataset = Dataset_Image_mask(data_dir, mean, std_dev)
    total = dataset.len
    n_train = int(0.8 * total)

    train_idx = list(np.random.choice(range(total), n_train, replace=False))
    chosen = set(train_idx)
    test_idx = list(set(range(total)) - chosen)

    # Third value should be 0: every index lands in exactly one split.
    print(len(train_idx), len(test_idx), total - len(train_idx) - len(test_idx))

    test_ds = torch.utils.data.Subset(dataset, test_idx)
    test_loader = DataLoader(test_ds, batch_size=4, shuffle=True, num_workers=0)

    for batch_no, batch in enumerate(test_loader):
        print(batch_no, batch['image'].shape, batch['mask'].shape)


if __name__ == '__main__':
    main()

# Model Prep

In [None]:
from base.base_trainer import BaseTrainer
from utils import dataloader, loss, metrics
from models import unet

class Trainer(BaseTrainer):
    """Trainer for the tile segmentation model: one optimizer step per batch, average-loss tracking.

    NOTE(review): self.model, self.device, self.loss, self.train_loader and
    self.val_loader are not assigned here. Presumably BaseTrainer.__init__
    sets them from the constructor arguments; confirm.
    """

    def __init__(self, model, loss_fn, config, train_loader, val_loader=None):
        super(Trainer, self).__init__(
            model, loss_fn, config, train_loader, val_loader)

        self.optimizer = loss.use_optimizer(model, config)
        # Passed as the third argument of every loss call below.
        self.bce_dice_ratio = config['loss']["bce_dice_ratio"]

    def _train_epoch(self, epoch):
        '''
        Train the model for one epoch.

        Returns:
        the average training loss over the epoch as tracked by metrics.AverageMeter
        '''
        self.model.train()
        self._reset_metrics()
        # NOTE(review): tqdm is not imported in any visible cell; confirm it is
        # in scope (e.g. `from tqdm import tqdm`).
        tbar = tqdm(self.train_loader, ncols=100, miniters=50)

        for sample_batch in tbar:
            self.optimizer.zero_grad()

            img = sample_batch['image'].to(self.device)
            mask = sample_batch['mask'].float().to(self.device)
            batch_size = img.shape[0]

            # Drop axis 1 so the output lines up with the (B, H, W) mask;
            # presumably the model emits (B, 1, H, W). TODO confirm.
            out = torch.squeeze(self.model(img), 1)

            train_loss = self.loss(out, mask, self.bce_dice_ratio)
            self.total_loss.update(train_loss.item(), batch_size)

            train_loss.backward()
            self.optimizer.step()

            tbar.set_description(f"Train: Epoch: {epoch}, Avg Loss: {self.total_loss.avg:.5f}")
        return self.total_loss.avg

    def _val_epoch(self, epoch):
        '''
        Evaluate on the validation loader without gradient tracking.

        Returns:
        the average validation loss, or {} when no validation loader exists
        '''
        if self.val_loader is None:
            print("No val loader exists")
            return {}

        self.model.eval()
        self._reset_metrics()
        tbar = tqdm(self.val_loader, ncols=100)
        with torch.no_grad():
            for sample_batch in tbar:
                img = sample_batch['image'].to(self.device)
                mask = sample_batch['mask'].float().to(self.device)
                batch_size = img.shape[0]

                out = torch.squeeze(self.model(img), 1)
                val_loss = self.loss(out, mask, self.bce_dice_ratio)
                self.total_loss.update(val_loss.item(), batch_size)
                tbar.set_description(f"Val: Epoch: {epoch}, Avg Loss: {self.total_loss.avg:.5f}")

        return self.total_loss.avg

    def _reset_metrics(self):
        """Start a fresh running-average loss meter."""
        self.total_loss = metrics.AverageMeter()


def main():
    """
    Train a U-Net on a random 80/20 split of the tiled dataset.

    Loader batch sizes and trainer settings come from config.json in the
    current working directory. main was previously indented inside Trainer,
    so the guard below ran a different cell's main instead of training.
    """
    # Raw string: same value as before, without relying on invalid escape
    # sequences such as "\S" (a SyntaxWarning on Python 3.12+).
    data_dir = r'C:\Scripts\hubmap\train\tiled_thresholded_512'

    mean = [0.68912, 0.47454, 0.6486]
    std_dev = [0.13275, 0.23647, 0.15536]

    # Full dataset with training images and masks.
    dataset = dataloader.Dataset_Image_mask(data_dir, mean, std_dev)
    n_tot = dataset.len

    # Split the full dataset into a train set and a held-out set.
    train_test_split = 0.8
    train_count = int(train_test_split * n_tot)
    train_idx = list(np.random.choice(range(n_tot), train_count, replace=False))
    test_idx = list(set(range(n_tot)) - set(train_idx))

    print(len(train_idx), len(test_idx), n_tot - len(train_idx) - len(test_idx))

    train_ds = torch.utils.data.Subset(dataset, train_idx)
    test_ds = torch.utils.data.Subset(dataset, test_idx)

    model = unet.UNet()

    with open('config.json') as config_file:
        config = json.load(config_file)

    train_loader = DataLoader(
        train_ds, batch_size=config["train_loader"]["args"]["batch_size"],
        shuffle=True, num_workers=0)
    val_loader = DataLoader(
        test_ds, batch_size=config["val_loader"]["args"]["batch_size"],
        shuffle=True, num_workers=0)

    trainer = Trainer(model, loss.loss_fn, config, train_loader, val_loader)
    print(f"Trainining on device: {trainer.device}")

    trainer.train()


if __name__ == "__main__":
    main()