In [6]:

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler #don't need
#add other imports only as needed

In [7]:

"""
put the helper function rle_to_2d and 2d_to_rle over here
"""

'\nput the helper function rle_to_2d and 2d_to_rle over here\n'

In [None]:
def get_padsize(img, reduce, sz):
    """
    Work out how much padding the first two axes of `img` need so that each
    length becomes a multiple of reduce*sz.

    Args:
    img    --> any object with a .shape whose first two entries are the sizes to pad
    reduce --> multiplier applied to sz
    sz     --> base tile size
    Returns:
    (pad_x, pad_y) --> (before, after) pad widths for axis 0 and axis 1; when the
    total padding is odd the extra element goes on the "after" side
    """
    dims = img.shape
    print(dims)

    step = reduce * sz

    def split_padding(length):
        # total padding needed to reach the next multiple of step (0 if already aligned)
        total = (step - length % step) % step
        before = total // 2
        return (before, total - before)

    return split_padding(dims[0]), split_padding(dims[1])

def check_threshold(img_BGR, sat_threshold, pixcount_th):
    """
    Decide whether an image tile is worth keeping.

    Two conditions are checked in order:
    1. not black  --> the sum of all pixel values must be at least pixcount_th
    2. saturation --> the number of pixels whose saturation exceeds sat_threshold
                      must be at least pixcount_th
    Args:
    img_BGR       --> image array handed to cv2.cvtColor with COLOR_BGR2HSV
    sat_threshold --> saturation value a pixel has to exceed to be counted
    pixcount_th   --> threshold used for both conditions
    Returns:
    True if both conditions are met else False
    """
    # mostly-black tiles (typically the edge of an image) are rejected first,
    # before any colour conversion is attempted
    total_intensity = img_BGR.sum()
    if total_intensity < pixcount_th:
        return False

    # move to hue / saturation / value space; only saturation is used
    hsv_img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2HSV)
    _hue, saturation, _value = cv2.split(hsv_img)

    # too few saturated pixels: typically the gray background around the
    # biological object
    saturated_pixel_count = (saturation > sat_threshold).sum()
    if saturated_pixel_count < pixcount_th:
        return False

    return True


In [2]:
class Image():
    """
    Holds one image together with its (optional) segmentation mask and the
    tiled versions of both.

    Attributes:
    img, shape, name       --> the image array, its shape and an optional name
    dx, dy                 --> shape[0] and shape[1] of the (reshaped) image
    tile_size              --> tile edge length used by the last tiling call
    pad_x, pad_y           --> (before, after) padding applied to axis 0 / axis 1
    tiled_img, tiled_mask  --> arrays of tiles, [n_tiles, sz, sz(, channels)]
    mask_rle, mask_2d      --> mask as run length encoding / as 2D array; both are
                               set by the caller or by mask_rle_to_2d
    """

    def __init__(self, img, img_name=None):
        self.img = img
        self.shape = img.shape
        self.name = img_name

        # may replace self.img / self.shape, so dx / dy are read afterwards
        self.image_reshape()
        self.dx = self.shape[0]
        self.dy = self.shape[1]

        self.tile_size = None

        self.pad_x = None
        self.pad_y = None
        self.tiled_img = None

        self.mask_rle = None
        self.mask_2d = None
        self.tiled_mask = None

    def image_reshape(self):
        """
        Converts a 5-dimensional image into a 3D array by squeezing out the
        singleton axes and moving the first remaining axis to the end.
        Images of any other rank are left untouched.
        """
        if len(self.shape) == 5:
            # presumably (1, 1, C, H, W) -> (H, W, C); verify against the image reader
            self.img = np.transpose(self.img.squeeze(), (1, 2, 0))
            self.shape = self.img.shape

    def split_image_mask_into_tiles(self, reduce=1, sz=512):
        """
        Zero-pads the image (and the 2D mask, when one is set) so both axes are
        multiples of reduce*sz, then cuts them into sz x sz tiles.

        Results are stored in self.tiled_img ([n_tiles, sz, sz, channels]) and
        self.tiled_mask ([n_tiles, sz, sz], or None when self.mask_2d is None,
        e.g. at inference time when no label exists).
        """
        self.tile_size = sz

        # get_padsize is the notebook-level helper defined in an earlier cell;
        # no `utils` module is imported in this notebook
        self.pad_x, self.pad_y = get_padsize(self.img, reduce, sz)
        print(self.pad_x, self.pad_y)

        # pad the two spatial axes only, never the channel axis
        img_padded = np.pad(self.img, [self.pad_x, self.pad_y, (0, 0)], constant_values=0)
        print("shape of image after padding:: ", img_padded.shape,
              img_padded.shape[0]//sz, img_padded.shape[1]//sz)

        # tile the padded image: split each spatial axis into (n_tiles, sz),
        # bring the two tile-index axes to the front and merge them
        n_channels = img_padded.shape[2]
        img_reshaped = img_padded.reshape(
            img_padded.shape[0]//sz, sz, img_padded.shape[1]//sz, sz, n_channels)
        self.tiled_img = img_reshaped.transpose(0, 2, 1, 3, 4).reshape(-1, sz, sz, n_channels)

        if self.mask_2d is None:
            # nothing to tile; np.pad would fail on a missing mask
            self.tiled_mask = None
            return

        mask_padded = np.pad(self.mask_2d, [self.pad_x, self.pad_y], constant_values=0)
        print("shape of mask after padding:: ", mask_padded.shape,
              mask_padded.shape[0]//sz, mask_padded.shape[1]//sz)

        # tile the padded mask the same way as the image
        mask_reshaped = mask_padded.reshape(
            mask_padded.shape[0]//sz, sz, mask_padded.shape[1]//sz, sz)
        self.tiled_mask = mask_reshaped.transpose(0, 2, 1, 3).reshape(-1, sz, sz)

    def save_thresholded_image(self, tiled_threshold_img_dir, mask_tile_dict, sat_threshold=40, pixcount_th=200):
        """
        Writes every tile that passes check_threshold to tiled_threshold_img_dir
        as a png and records the tile's mask (run length encoded) in
        mask_tile_dict under the tile id.

        TODO (from the original note): for inference, instead of saving, run the
        model on tiles that pass the threshold and predict an all-zero mask for
        the others.

        Raises:
        ValueError if the image / mask have not been tiled yet
        """
        if self.tiled_img is None or self.tiled_mask is None:
            raise ValueError(
                "tiles are missing: set mask_2d and call split_image_mask_into_tiles first")

        n = self.tiled_img.shape[0]

        valid_img_count = 0
        print(f"Original tiled image count = {n}")

        for i in range(n):
            img_BGR = self.tiled_img[i, :, :, :]
            # check_threshold is the notebook-level helper defined in an earlier cell
            if check_threshold(img_BGR, sat_threshold, pixcount_th):
                valid_img_count += 1

                #create an id for the image tile
                img_tile_id = f"{self.name}_{str(self.tile_size)}_{str(valid_img_count)}_{str(i)}"
                img_name = img_tile_id+'.png'  # name of the saved image tile

                mask_for_tile = self.tiled_mask[i, :, :]  # get the mask for the tile
                #save the rle mask to a dict, key = name of the corresponding image tile
                mask_tile_dict[img_tile_id] = self.mask_2d_to_rle(mask_for_tile)

                # NOTE(review): cv2 and os are not imported in the cells shown;
                # they must be imported before this method is called
                cv2.imwrite(os.path.join(tiled_threshold_img_dir, img_name), img_BGR)

        print(f"Image count after thresholding = {valid_img_count}")

    def mask_rle_to_2d(self):
        """
        converts mask from run length encoding to 2D numpy array

        self.mask_rle is a whitespace separated string of "start length" pairs
        where start is 1-based and pixels are counted column by column.
        The result, of shape (dx, dy), is stored in self.mask_2d.

        Raises:
        ValueError if self.mask_rle has not been set
        """
        if self.mask_rle is None:
            raise ValueError("mask_rle must be set before calling mask_rle_to_2d")

        dx = self.dx
        dy = self.dy

        mask = np.zeros(dx*dy, dtype=np.uint8)
        tokens = self.mask_rle.split()  # split the rle encoding
        # a trailing unpaired token is ignored
        for start_token, length_token in zip(tokens[0::2], tokens[1::2]):
            start = int(start_token) - 1  # rle starts are 1-based
            mask[start:start + int(length_token)] = 1

        # the flat mask is ordered by columns, hence reshape to (dy, dx) and transpose
        self.mask_2d = mask.reshape(dy, dx).T

    def mask_2d_to_rle(self, mask_2d):
        """
        Takes a 2D mask of 0/1 and returns the run length encoded form

        Returns:
        1D integer numpy array [start_1, length_1, start_2, length_2, ...] with
        1-based starts and column-wise pixel order, i.e. the same convention
        that mask_rle_to_2d decodes
        """
        mask = mask_2d.T.reshape(-1)  # order by columns and flatten to 1D
        mask_padded = np.pad(mask, 1)  # pad zero on both sides
        #find the start positions of the 1's (0-based index into the flat mask)
        starts = np.where((mask_padded[:-1] == 0) & (mask_padded[1:] == 1))[0]
        #find the end positions of 1's for each run (0-based, exclusive)
        ends = np.where((mask_padded[:-1] == 1) & (mask_padded[1:] == 0))[0]

        rle = np.zeros(2*len(starts), dtype=int)
        # +1 converts to the 1-based starts expected by mask_rle_to_2d
        rle[::2] = starts + 1
        #length of each run = end position - start position
        rle[1::2] = ends - starts
        return rle



In [4]:
## upload model folder - Rudra

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn as nn
import numpy as np


class BaseModel(nn.Module):
    """
    Common base class for the models in this notebook.

    Subclasses must override forward(); calling a model that does not
    raises NotImplementedError.
    """

    def __init__(self):
        super(BaseModel, self).__init__()

    def forward(self, *args, **kwargs):
        # accept any inputs so that calling an un-overridden model fails with
        # NotImplementedError instead of a TypeError about argument counts
        raise NotImplementedError

    def __str__(self):
        # same text as the default nn.Module representation
        return super(BaseModel, self).__str__()


class Conv2x(nn.Module):
    '''
    Two consecutive (3x3 convolution -> batch norm -> ReLU) stages.

    Both convolutions use kernel_size=3 with padding=1, so the spatial size
    of the input is preserved; only the channel count changes:
    in_ch -> inner_ch -> out_ch.
    inner_ch defaults to out_ch//2 when it is not given.
    '''
    def __init__(self, in_ch, out_ch, inner_ch=None):
        super().__init__()
        if inner_ch is None:
            inner_ch = out_ch // 2
        self.in_ch, self.out_ch, self.inner_ch = in_ch, out_ch, inner_ch

        # bias is dropped because each convolution is followed by batch norm
        conv_options = dict(kernel_size=3, padding=1, bias=False)
        self.conv2d_1 = nn.Conv2d(in_ch, inner_ch, **conv_options)
        self.conv2d_2 = nn.Conv2d(inner_ch, out_ch, **conv_options)
        self.bn1 = nn.BatchNorm2d(inner_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv2d_1(x)))
        return F.relu(self.bn2(self.conv2d_2(hidden)))


class encoder(nn.Module):
    """
    Down-sampling stage: a Conv2x block (in_ch -> out_ch) followed by a
    2x2 max pooling. ceil_mode=True means an odd spatial size n is pooled
    to (n+1)//2 instead of dropping the last row / column.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv2x = Conv2x(in_ch, out_ch)
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        return self.pool(self.conv2x(x))

class decoder(nn.Module):
    """
    Up-sampling stage: a stride-2 transposed convolution (in_ch -> in_ch//2)
    on the feature map coming from below, concatenation with the skip
    connection from the encoder path, then a Conv2x block (in_ch -> out_ch).

    Since Conv2x expects in_ch input channels, the skip connection has to
    carry in_ch - in_ch//2 channels.
    """

    def __init__(self, in_ch, out_ch):
        super(decoder, self).__init__()
        self.transposeconv = nn.ConvTranspose2d(
            in_ch, in_ch//2, kernel_size=2, stride=2)
        self.conv2x = Conv2x(in_ch, out_ch)

    def forward(self, x_down, x_up, interpolate=True):
        """
        Args:
        x_down      --> skip-connection features from the encoder path
        x_up        --> features from the previous (lower) stage
        interpolate --> when the up-sampled spatial size differs from x_down's,
                        resize it bilinearly to match; with False a mismatch
                        is left as is and the concatenation raises
        """
        x_up = self.transposeconv(x_up)

        # The encoder pools with ceil_mode=True, so for odd sizes the up-sampled
        # map is one pixel larger than the skip connection. Compare against
        # x_down (not x_up with itself) before concatenating.
        if x_up.shape[2:] != x_down.shape[2:]:
            if interpolate:
                x_up = F.interpolate(x_up, size=(x_down.size(2), x_down.size(3)),
                                     mode="bilinear", align_corners=True)

        #Concat features from down conv channel and current up-conv
        #along channel dim =1
        x_up = torch.cat([x_up, x_down], dim=1)
        x_up = self.conv2x(x_up)

        return x_up

class UNet(BaseModel):
    """
    U-Net with four down-sampling and four up-sampling stages and a single
    output channel.

    Args:
    in_ch         --> number of channels of the input image
    conv_channels --> channel counts of the five resolution levels; the first
                      five entries are used
    """

    def __init__(self, in_ch=3, conv_channels=[16, 32, 64, 128, 256]):
        super(UNet, self).__init__()

        # copy so that the shared default list (or the caller's list) is never
        # aliased by the instance
        conv_channels = list(conv_channels)

        self.conv_channels = conv_channels
        self.conv_start = Conv2x(in_ch, conv_channels[0]) #output_size = input_size
        self.down1 = encoder(conv_channels[0], conv_channels[1])   #output_size = input_size/2
        self.down2 = encoder(conv_channels[1], conv_channels[2])   #output_size = input_size/2
        self.down3 = encoder(conv_channels[2], conv_channels[3])   #output_size = input_size/2
        self.down4 = encoder(conv_channels[3], conv_channels[4])   #output_size = input_size/2

        self.conv_middle = Conv2x(conv_channels[4], conv_channels[4]) #output_size = input_size

        self.up4 = decoder(conv_channels[4], conv_channels[3]) #output_size = input_size * 2
        self.up3 = decoder(conv_channels[3], conv_channels[2]) #output_size = input_size * 2
        self.up2 = decoder(conv_channels[2], conv_channels[1]) #output_size = input_size * 2
        self.up1 = decoder(conv_channels[1], conv_channels[0]) #output_size = input_size * 2

        # 1x1 convolution down to a single output channel
        self.final_conv = nn.Conv2d(self.conv_channels[0], 1, kernel_size=1)

        self.init_params()

    def init_params(self):
        """
        Kaiming-normal weights and zero bias for Conv2d / Linear layers,
        weight 1 and bias 0 for BatchNorm2d layers.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()

            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        """
        Args:
        x --> input of size [B, in_ch, nx, ny]
        Returns:
        output of final_conv, size [B, 1, nx, ny]; no activation is applied here
        """
        x1 = self.conv_start(x)  # size of x = [B, self.conv_channels[0], nx, ny]
        x2 = self.down1(x1)  # size of x = [B, self.conv_channels[1], nx/2, ny/2]
        x3 = self.down2(x2)  # size of x = [B, self.conv_channels[2], nx/4, ny/4]
        x4 = self.down3(x3)  # size of x = [B, self.conv_channels[3], nx/8, ny/8]
        x5 = self.down4(x4)  # size of x = [B, self.conv_channels[4], nx/16, ny/16]

        x = self.conv_middle(x5)  # size of x = [B, self.conv_channels[4], nx/16, ny/16]

        # each decoder stage receives the matching encoder output as skip connection
        x = self.up4(x4, x)       # size of x = [B, self.conv_channels[3], nx/8, ny/8]
        x = self.up3(x3, x)       # size of x = [B, self.conv_channels[2], nx/4, ny/4]
        x = self.up2(x2, x)       # size of x = [B, self.conv_channels[1], nx/2, ny/2]
        x = self.up1(x1, x)       # size of x = [B, self.conv_channels[0], nx, ny]

        x = self.final_conv(x)

        return x

In [8]:
# Copy the BaseModel class and the classes from unet.py into the cell above. Do not change anything. - Ethan

In [None]:
#upload the checkpoint - Rudra

In [None]:
def metric_dice_iou(output, target, smooth = 0.005):
    """
    Mean Dice coefficient and mean IoU between predictions and targets.

    Sums are taken over axes 1 and 2, giving one score per entry along
    axis 0; the two scores are then averaged with np.mean.
    (presumably output / target are [batch, height, width] arrays of 0/1
    values or probabilities -- confirm against the caller)

    Args:
    smooth --> constant added to numerator and denominator of both scores
    Returns:
    (dice, iou)
    """
    reduce_axes = (1, 2)

    intersection = (output * target).sum(axis=reduce_axes)
    false_pos = (output * (1.0 - target)).sum(axis=reduce_axes)
    false_neg = ((1.0 - output) * target).sum(axis=reduce_axes)

    dice_scores = (2.0 * intersection + smooth) / (2 * intersection + false_pos + false_neg + smooth)
    iou_scores = (intersection + smooth) / (intersection + false_pos + false_neg + smooth)

    return np.mean(dice_scores), np.mean(iou_scores)

In [None]:
def ineference_single_image(image):
        # Work-in-progress sketch of single-image inference: tile the image, run
        # the model on each tile, stitch the tile predictions back together and
        # return the predicted mask in run length encoding.
        #
        # NOTE(review): this function cannot run as written:
        #   * the `for i in range(tile_count)` line below has no trailing colon,
        #     so the cell does not parse
        #   * `self` is used throughout but this is a free function whose only
        #     parameter is `image` (which is never used)
        #   * unet, tqdm, tiff, os, datadir_train, f, tile_count, sample_batch and
        #     mask_pred_2d_to_rle are not defined here or in the cells shown
        #   * looks adapted from a trainer's validation loop -- TODO confirm and
        #     replace the loader-based parts with the tile loop described below
        if self.val_loader is None:
            print(f"No val loader exists")
            return {}

        # NOTE(review): this local `model` is never used; the code below calls
        # self.model instead
        model = unet.UNet()
        
        # TODO: load the trained weights from the checkpoint into the model
        # TODO: decide where the model checkpoint is stored
        
        self.model.eval()
        self._reset_metrics()
        tbar = tqdm(self.val_loader, ncols=100)
        
        # read an image
        img_raw = tiff.imread(os.path.join(datadir_train, f))
        
        # instantiate image class
        # NOTE(review): the instance is discarded; it needs to be kept in a
        # variable for the tiling call below
        Image(img_raw)
        # TODO: call Image.split_image_mask_into_tiles(self, reduce=1, sz=512) to create the tiles
        
        with torch.no_grad():
            # loop over the tiles
            for i in range(tile_count)
            #for i, sample_batch in enumerate(tbar):
                # TODO for each tile:
                #   check the threshold
                #   if the threshold passes, then do inference
                img = sample_batch['image']
                #mask = sample_batch['mask'].float()

                batch_size = img.shape[0]
                img = img.to(self.device)
                #mask = mask.to(self.device)

                # drop dim 1 of the model output (the UNet in this notebook
                # ends in a 1-channel convolution)
                out = torch.squeeze(self.model(img), 1)
                
        
                # TODO: accumulate the predictions into a big np array
                # prediction size = [total_tiles, sz, sz]
                
        # TODO: call Image.reconstruct_original_from_padded_tiled_image(self, tiled_predicted_mask)
        # to get the predicted mask at the same size as the original image / mask
        # NOTE(review): the Image class shown in this notebook has no such method yet
        
        # TODO: calculate metrics
        # TODO: convert the predicted 2D mask to rle (2d_to_rle)
        
            
        

        return mask_pred_2d_to_rle

#call  this function for all the prediction images

In [None]:
def _resume_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path) #need this line

        last_epoch = checkpoint['epoch']
        model_name = checkpoint['config']['name']

        if model_name == self.config['name']:
            self.model.load_state_dict(checkpoint['state_dict']) #need this line
            self.start_epoch = last_epoch + 1
            return True
        else:
            print("current model name doesn't match with previously saved model name !!")
            print("Currennt model name: {} , previous model name : {}".format(self.config['name'], model_name))
            return False

In [None]:
# test on one image we trained on
# compare the prediction to the mask label
# check the prediction dimensions

In [None]:
# Git tutorials: https://www.atlassian.com/git/tutorials
    