<a href="https://colab.research.google.com/github/vikram71198/MSRF_Net/blob/master/msrf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
%cd '/content/drive/MyDrive/Colab Notebooks/'
!unzip -o cvc-final.zip -d /content/
%cd '/content/'

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/out/mask/479_3.png  
  inflating: /content/out/mask/479_4.png  
  inflating: /content/out/mask/479_5.png  
  inflating: /content/out/mask/479_6.png  
  inflating: /content/out/mask/479_7.png  
  inflating: /content/out/mask/479_8.png  
  inflating: /content/out/mask/479_9.png  
  inflating: /content/out/mask/47_1.png  
  inflating: /content/out/mask/47_10.png  
  inflating: /content/out/mask/47_11.png  
  inflating: /content/out/mask/47_12.png  
  inflating: /content/out/mask/47_13.png  
  inflating: /content/out/mask/47_14.png  
  inflating: /content/out/mask/47_15.png  
  inflating: /content/out/mask/47_16.png  
  inflating: /content/out/mask/47_17.png  
  inflating: /content/out/mask/47_18.png  
  inflating: /content/out/mask/47_19.png  
  inflating: /content/out/mask/47_2.png  
  inflating: /content/out/mask/47_20.png  
  inflating: /content/out/mask/47_21.png  
  inflating: /content/out/mask/47_

In [3]:
!pip install torchgeometry

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
from torchgeometry.losses.dice import DiceLoss
from torchvision.models.resnet import BasicBlock
import cv2



Defining Device

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Utils

In [5]:
import cv2
import glob
import numpy as np
from tqdm import tqdm as tqdm
import pickle as pkl
import os

'''
Reads the image specified by 'path' and returns it
param : path - path of image file
return : image as a numpy array
'''
def read_img(path):
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    image = np.clip(image - np.median(image)+127, 0, 255)
    image = image/255.0
    image = image.astype(np.float32)
    return image

def read_mask(path):
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    thresh, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    mask = mask/255.0
    mask = mask.astype(np.float32)
    #mask = np.expand_dims(mask, axis=-1)
    return mask

'''
Converts numpy img to tensor
param : img - numpy arr containing image data
return : t - torch tensor of shape [1, 3, H, W]
'''
def img_to_tensor(img):
    t = torch.from_numpy(img)
    t = t.view(-1, 3, t.shape[0], t.shape[1])
    return t

def mask_to_tensor(mask):
    t = torch.from_numpy(mask)
    t = t.view(-1, t.shape[0], t.shape[1])
    return t.long()

'''
t - tensor of shape [H, W]
'''
def tensor_to_mask(t):
    t = t.view(t.shape[0], t.shape[1])
    return t.numpy()

'''
Converts tensor back to numpy img
param : t - torch tensor of shape [1, 3, H, W]
return : img - numpy arr containing image data
'''
def tensor_to_img(t):
    t = t.view(t.shape[2], t.shape[3], 3)
    return t.numpy()

Deep Learning Architecture

In [6]:
"""
Squeeze & Excitation Block
"""
class SE_Block(nn.Module):
    def __init__(self, in_ch, ratio = 16):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(in_ch, in_ch//ratio), nn.ReLU(), nn.Linear(in_ch//ratio, in_ch), nn.Sigmoid())
    
    def forward(self, x):
        y = x.mean((-2,-1))
        y = self.block(y).unsqueeze(-1).unsqueeze(-1)
        return x*y
"""
Encoder Block
"""
class Encoder(nn.Module):
    def __init__(self, in_ch, init_feat = 32):

        super().__init__()

        '''Instantiations of all subclasses of nn.Module will be callable objects because nn.Module has __call__() built-in (which is inherited by the subclass)
        which in turn calls forward(). So, if forward() is overridden in the subclass, the new forward() will be called!'''

        self.enc1 = nn.Sequential(nn.Conv2d(in_ch, init_feat, kernel_size = (3,3), stride = (1,1), padding = (1,1)),
        nn.ReLU(), nn.Conv2d(init_feat, init_feat, kernel_size = (3,3), stride = (1,1), padding = (1,1)), nn.ReLU(),
        nn.BatchNorm2d(init_feat), SE_Block(init_feat, ratio = init_feat // 2))

        self.enc2 = nn.Sequential(nn.MaxPool2d(kernel_size = (2,2)), nn.Dropout(0.2),
        nn.Conv2d(init_feat, init_feat*2, kernel_size = (3,3), stride = (1,1), padding = (1,1)),
        nn.ReLU(),
        nn.Conv2d(init_feat*2, init_feat*2, kernel_size = (3,3), stride = (1,1), padding = (1,1)),
        nn.ReLU(),
        nn.BatchNorm2d(init_feat*2), SE_Block(init_feat*2, ratio = init_feat // 2))

        self.enc3 = nn.Sequential(nn.MaxPool2d(kernel_size = (2,2)), nn.Dropout(0.2),
        nn.Conv2d(init_feat*2, init_feat*4, kernel_size = (3,3), stride = (1,1), padding = (1,1)),
        nn.ReLU(),
        nn.Conv2d(init_feat*4, init_feat*4, kernel_size = (3,3), stride = (1,1), padding = (1,1)),
        nn.ReLU(),
        nn.BatchNorm2d(init_feat*4), SE_Block(init_feat*4, ratio = init_feat // 2))

        self.enc4 = nn.Sequential(nn.MaxPool2d(kernel_size = (2,2)), nn.Dropout(0.2),
        nn.Conv2d(init_feat*4, init_feat*8, kernel_size = (3,3), stride = (1,1), padding = (1,1)),
        nn.ReLU(),
        nn.Conv2d(init_feat*8, init_feat*8, kernel_size = (3,3), stride = (1,1), padding = (1,1)),
        nn.ReLU(),
        nn.BatchNorm2d(init_feat*8))

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)
        return x1, x2, x3, x4
        
"""
DSDF (Dual Scale Dense Fusion) Block
"""
class DSDF(nn.Module):
    def __init__(self, in_ch_x, in_ch_y, nf1 = 128, nf2 = 256, gc = 64, bias = True):
        super().__init__()

        self.nx1 = nn.Sequential(nn.Conv2d(in_ch_x, gc, kernel_size = (3,3), stride = (1,1), padding = (1,1),bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        self.ny1 = nn.Sequential(nn.Conv2d(in_ch_y, gc, kernel_size = (3,3), stride = (1,1), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        #TO DO
        self.nx1c = nn.Sequential(nn.Conv2d(in_ch_x, gc, kernel_size = (4,4), stride = (2,2), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        #TO DO
        self.ny1t = nn.Sequential(nn.ConvTranspose2d(in_ch_y, gc, kernel_size = (4,4), stride = (2, 2), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        self.nx2 = nn.Sequential(nn.Conv2d(in_ch_x + gc + gc, gc, kernel_size = (3,3), stride = (1,1), padding = (1,1),bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        self.ny2 = nn.Sequential(nn.Conv2d(in_ch_y + gc + gc, gc, kernel_size = (3,3), stride = (1,1), padding = (1,1),bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        #TO DO
        self.nx2c = nn.Sequential(nn.Conv2d(gc, gc, kernel_size = (4,4), stride = (2,2), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        #TO DO
        self.ny2t = nn.Sequential(nn.ConvTranspose2d(gc, gc, kernel_size = (4,4), stride = (2,2), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        self.nx3 = nn.Sequential(nn.Conv2d(in_ch_x + gc + gc + gc, gc, kernel_size = (3,3), stride = (1,1), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        self.ny3 = nn.Sequential(nn.Conv2d(in_ch_y + gc + gc + gc, gc, kernel_size = (3,3), stride = (1,1), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        #TO DO
        self.nx3c = nn.Sequential(nn.Conv2d(gc, gc, kernel_size = (4,4), stride = (2,2), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        #TO DO
        self.ny3t = nn.Sequential(nn.ConvTranspose2d(gc, gc, kernel_size = (4,4), stride = (2,2), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        self.nx4 = nn.Sequential(nn.Conv2d(in_ch_x + gc + gc + gc + gc, gc, kernel_size = (3,3), stride = (1,1), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        self.ny4 = nn.Sequential(nn.Conv2d(in_ch_y + gc + gc + gc + gc, gc, kernel_size = (3,3), stride = (1,1), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        #TO DO
        self.nx4c = nn.Sequential(nn.Conv2d(gc, gc, kernel_size = (4,4), stride = (2,2), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        #TO DO
        self.ny4t = nn.Sequential(nn.ConvTranspose2d(gc, gc, kernel_size = (4,4), stride = (2,2), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        self.nx5 = nn.Sequential(nn.Conv2d(in_ch_x + gc + gc + gc + gc + gc, nf1, kernel_size = (3,3), stride = (1,1), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))

        self.ny5 = nn.Sequential(nn.Conv2d(in_ch_y + gc + gc + gc + gc + gc, nf2, kernel_size = (3,3), stride = (1,1), padding = (1,1), bias = bias),
        nn.LeakyReLU(negative_slope = 0.25))


    def forward(self, x, y):
        x1 = self.nx1(x)
        y1 = self.ny1(y)

        x1c = self.nx1c(x)
        y1t = self.ny1t(y)

        x2_input = torch.cat([x, x1, y1t], dim = 1)
        x2 = self.nx2(x2_input)

        y2_input = torch.cat([y, y1, x1c], dim = 1)
        y2 = self.ny2(y2_input)

        x2c = self.nx2c(x1)
        y2t = self.ny2t(y1)

        x3_input = torch.cat([x, x1, x2, y2t], dim = 1)
        x3 = self.nx3(x3_input)

        y3_input = torch.cat([y, y1, y2, x2c], dim = 1)
        y3 = self.ny3(y3_input)

        x3c = self.nx3c(x2)
        y3t = self.ny3t(y2)

        x4_input = torch.cat([x, x1, x2, x3, y3t], dim = 1)
        x4 = self.nx4(x4_input)

        y4_input = torch.cat([y, y1, y2, y3, x3c], dim = 1)
        y4 = self.ny4(y4_input)

        x4c = self.nx4c(x3)
        y4t = self.ny4t(y3)

        x5_input = torch.cat([x, x1, x2, x3, x4, y4t], dim = 1)
        x5 = self.nx5(x5_input)

        y5_input = torch.cat([y, y1, y2, y3, y4, x4c], dim = 1)
        y5 = self.ny5(y5_input)
        
        x5 *= 0.4
        y5 *= 0.4

        return x5 + x, y5 + y
"""
MSRF Sub-Network implementing Multi-Scale Fusion using DSDF Blocks
"""
class MSRF_SubNet(nn.Module):
    def __init__(self, init_feat):
        super().__init__() 
        self.dsfs_1  = DSDF(init_feat, init_feat*2, nf1=init_feat, nf2=init_feat*2, gc=init_feat//2)
        self.dsfs_2  = DSDF(init_feat*4, init_feat*8, nf1=init_feat*4, nf2=init_feat*8, gc=init_feat*4//2)
        self.dsfs_3  = DSDF(init_feat, init_feat*2, nf1=init_feat, nf2=init_feat*2, gc=init_feat//2)
        self.dsfs_4  = DSDF(init_feat*4, init_feat*8, nf1=init_feat*4, nf2=init_feat*8, gc=init_feat*4//2)
        self.dsfs_5  = DSDF(init_feat*2, init_feat*4, nf1=init_feat*2, nf2=init_feat*4, gc=init_feat*2//2)
        self.dsfs_6  = DSDF(init_feat, init_feat*2, nf1=init_feat, nf2=init_feat*2, gc=init_feat//2)
        self.dsfs_7  = DSDF(init_feat*4, init_feat*8, nf1=init_feat*4, nf2=init_feat*8, gc=init_feat*4//2)
        self.dsfs_8  = DSDF(init_feat*2, init_feat*4, nf1=init_feat*2, nf2=init_feat*4, gc=init_feat*2//2)
        self.dsfs_9  = DSDF(init_feat, init_feat*2, nf1=init_feat, nf2=init_feat*2, gc=init_feat//2)
        self.dsfs_10 = DSDF(init_feat*4, init_feat*8, nf1=init_feat*4, nf2=init_feat*8, gc=init_feat*4//2)

    def forward(self, x11, x21, x31, x41):
        x12, x22 = self.dsfs_1(x11, x21)
        x32, x42 = self.dsfs_2(x31, x41)
        x12, x22 = self.dsfs_3(x12, x22)
        x32, x42 = self.dsfs_4(x32, x42)
        x22, x32 = self.dsfs_5(x22, x32)
        x13, x23 = self.dsfs_6(x12, x22)
        x33, x43 = self.dsfs_7(x32, x42)
        x23, x33 = self.dsfs_8(x23, x33)
        x13, x23 = self.dsfs_9(x13, x23)
        x33, x43 = self.dsfs_10(x33, x43)

        x13 = (x13*0.4) + x11
        x23 = (x23*0.4) + x21
        x33 = (x33*0.4) + x31
        x43 = (x43*0.4) + x41

        return x13, x23, x33, x43

"""
Gated Convolutions
"""
class GatedConv(nn.Conv2d):
    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels, out_channels, 1, bias=False)
        self.attention = nn.Sequential(
            nn.BatchNorm2d(in_channels + 1),
            nn.Conv2d(in_channels + 1, in_channels + 1, 1),
            nn.ReLU(),
            nn.Conv2d(in_channels + 1, 1, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

    def forward(self, feat, gate):
        attention = self.attention(torch.cat((feat, gate), dim=1))
        out = F.conv2d(feat * (attention + 1), self.weight)
        return out
"""
Shape Stream
"""
class ShapeStream(nn.Module):
    def __init__(self, init_feat):
        super().__init__()
        self.res2_conv = nn.Conv2d(init_feat * 2, 1, 1)
        self.res3_conv = nn.Conv2d(init_feat * 4, 1, 1)
        self.res4_conv = nn.Conv2d(init_feat * 8, 1, 1)
        self.res1 = BasicBlock(init_feat, init_feat, 1)
        self.res2 = BasicBlock(32, 32, 1)
        self.res3 = BasicBlock(16, 16, 1)
        self.res1_pre = nn.Conv2d(init_feat, 32, 1)
        self.res2_pre = nn.Conv2d(32, 16, 1)
        self.res3_pre = nn.Conv2d(16, 8, 1)
        self.gate1 = GatedConv(32, 32)
        self.gate2 = GatedConv(16, 16)
        self.gate3 = GatedConv(8, 8)
        self.gate = nn.Conv2d(8, 1, 1, bias=False)
        self.fuse = nn.Conv2d(2, 1, 1, bias=False)
    
    def forward(self, x, res2, res3, res4, grad):
        size = grad.shape[-2:]
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)
        res2 = F.interpolate(self.res2_conv(res2), size, mode='bilinear', align_corners=True)
        res3 = F.interpolate(self.res3_conv(res3), size, mode='bilinear', align_corners=True)
        res4 = F.interpolate(self.res4_conv(res4), size, mode='bilinear', align_corners=True)
        gate1 = self.gate1(self.res1_pre(self.res1(x)), res2)
        gate2 = self.gate2(self.res2_pre(self.res2(gate1)), res3)
        gate3 = self.gate3(self.res3_pre(self.res3(gate2)), res4)
        gate = torch.sigmoid(self.gate(gate3))
        feat = torch.sigmoid(self.fuse(torch.cat((gate, grad), dim=1)))
        return gate, feat

class AttentionBlock(nn.Module):
    def __init__(self, in_ch_x, in_ch_g, med_ch):
        super().__init__()
        self.theta = nn.Conv2d(in_ch_x, med_ch, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True)
        self.phi = nn.Conv2d(in_ch_g, med_ch, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
        self.block = nn.Sequential(nn.ReLU(), nn.Conv2d(med_ch, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
                                   nn.Sigmoid(), nn.ConvTranspose2d(1, 1, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True))
        self.batchnorm = nn.BatchNorm2d(in_ch_x)

    def forward(self, x, g):
        theta = self.theta(x) + self.phi(g)
        out = self.batchnorm(self.block(theta) * x)
        return out

class UpBlock(nn.Module):
    def __init__(self, inp1_ch, inp2_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(inp2_ch, inp1_ch, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True)
    
    def forward(self, input_1, input_2):
        x = torch.cat([self.up(input_2), input_1], dim=1)
        return x

class SpatialATTBlock(nn.Module):
    def __init__(self, in_ch, med_ch):
        super().__init__()
        self.block = nn.Sequential(nn.Conv2d(in_ch, med_ch, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
                                   nn.BatchNorm2d(med_ch),
                                   nn.ReLU(),
                                   nn.Conv2d(med_ch, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
                                   nn.Sigmoid())
    def forward(self, x):
        x = self.block(x)
        return x

class DualATTBlock(nn.Module):
    def __init__(self, skip_in_ch, prev_in_ch, out_ch):
        super().__init__()
        self.prev_block = nn.Sequential(nn.ConvTranspose2d(prev_in_ch, out_ch, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True),
                                        nn.BatchNorm2d(out_ch),
                                        nn.ReLU())
        self.block = nn.Sequential(nn.Conv2d(skip_in_ch+out_ch, out_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
                                   nn.BatchNorm2d(out_ch),
                                   nn.ReLU())
        self.se_block = SE_Block(out_ch, ratio=16)
        self.spatial_att = SpatialATTBlock(out_ch, out_ch)
    
    def forward(self, skip, prev):
        prev = self.prev_block(prev)
        x = torch.cat([skip, prev], dim=1)
        inpt_layer = self.block(x)
        se_out = self.se_block(inpt_layer)
        sab = self.spatial_att(inpt_layer) + 1

        return sab * se_out

class Decoder(nn.Module):
    def __init__(self, init_feat, n_classes):
        super().__init__()
        # Stage 1
        self.att_1 = AttentionBlock(init_feat*4, init_feat*8, init_feat*8)
        self.up_1 = UpBlock(init_feat*4, init_feat*8)
        self.dualatt_1 = DualATTBlock(init_feat*4, init_feat*8, init_feat*4)
        self.n34_t = nn.Conv2d(init_feat * 4 + init_feat * 8, init_feat * 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.dec_block_1 = nn.Sequential(nn.BatchNorm2d(init_feat*4),
                                         nn.ReLU(),
                                         nn.Conv2d(init_feat*4, init_feat*4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                         nn.BatchNorm2d(init_feat*4),
                                         nn.ReLU(),
                                         nn.Conv2d(init_feat*4, init_feat*4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                                         )
        self.head_dec_1 = nn.Sequential(nn.Conv2d(init_feat*4, n_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
                                        #nn.Sigmoid(), #TODO : Inform spanish guy!
                                        nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True))

        # Stage 2
        self.att_2 = AttentionBlock(init_feat * 2, init_feat * 4, init_feat * 2)
        self.up_2 = UpBlock(init_feat * 2, init_feat * 4)
        self.dualatt_2 = DualATTBlock(init_feat * 2, init_feat * 4, init_feat * 2)
        self.n24_t = nn.Conv2d(init_feat * 2 + init_feat * 4, init_feat * 2, kernel_size=(1, 1), stride=(1, 1), padding=(0,0))
        self.dec_block_2 = nn.Sequential(nn.BatchNorm2d(init_feat * 2),
                                         nn.ReLU(),
                                         nn.Conv2d(init_feat * 2, init_feat * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                         nn.BatchNorm2d(init_feat * 2),
                                         nn.ReLU(),
                                         nn.Conv2d(init_feat*2, init_feat * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                                         )
        self.head_dec_2 = nn.Sequential(nn.Conv2d(init_feat * 2, n_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
                                        #nn.Sigmoid(), #TODO : Inform spanish guy!
                                        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))

        # Stage 3
        self.up_3 = nn.ConvTranspose2d(init_feat * 2, init_feat, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        self.n14_input = nn.Sequential(nn.Conv2d(init_feat + init_feat + 1, init_feat, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
                                       nn.ReLU()) #TODO : This ain't in the paper!
        self.dec_block_3 = nn.Sequential(nn.Conv2d(init_feat, init_feat, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), #TODO : Missing 1x1 Conv, ReLu before this
                                         nn.ReLU(),
                                         nn.BatchNorm2d(init_feat))

        self.head_dec_3 = nn.Sequential(nn.Conv2d(init_feat, init_feat, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                        nn.ReLU(),
                                        nn.Conv2d(init_feat, n_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)))
                                        #nn.Sigmoid()) #TODO : Inform spanish guy!
        
    def forward(self, x13, x33, x43, x23, canny_feat):

        # Stage 1
        x34_preinput = self.att_1(x33, x43)
        x34 = self.up_1(x34_preinput, x43)
        x34_t = self.dualatt_1(x33, x43)
        x34_t = torch.cat([x34, x34_t], dim=1)
        x34_t = self.n34_t(x34_t)
        x34 = self.dec_block_1(x34_t) + x34_t
        pred_1 = self.head_dec_1(x34)

        # Stage 2
        x24_preinput = self.att_2(x23, x34)
        x24 = self.up_2(x24_preinput, x34)
        x24_t = self.dualatt_2(x23, x34)
        x24_t = torch.cat([x24, x24_t], dim=1)
        x24_t = self.n24_t(x24_t)
        x24 = self.dec_block_2(x24_t) + x24_t
        pred_2 = self.head_dec_2(x24)

        # Stage 3
        x14_preinput = self.up_3(x24)
        x14_input = torch.cat([x14_preinput, x13, canny_feat], dim=1)
        x14_input = self.n14_input(x14_input)
        x14 = self.dec_block_3(x14_input)
        x14 = x14 + x14_input
        pred_3 = self.head_dec_3(x14)

        return pred_1, pred_2, pred_3

class MSRF(nn.Module):
    def __init__(self, in_ch, n_classes, init_feat = 32):
        super().__init__()
        self.encoder1 = Encoder(in_ch, init_feat)
        self.msrf_subnet = MSRF_SubNet(init_feat)
        self.shape_stream = ShapeStream(init_feat)
        self.decoder = Decoder(init_feat, n_classes)

    def forward(self, x, canny):
        e1, e2, e3, e4 = self.encoder1(x)
        x13, x23, x33, x43 = self.msrf_subnet(e1, e2, e3, e4)
        canny_gate, canny_feat = self.shape_stream(x13, x23, x33, x43, canny)
        pred_1, pred_2, pred_3 = self.decoder(x13, x33, x43, x23, canny_feat)
        return pred_1, pred_2, pred_3, canny_gate

Data Pre-processing

In [7]:
img_list = [] #sorted(glob.glob("out/image/*"))
mask_list = [] #sorted(glob.glob("out/mask/*"))

for i in range(1, 5):
    img_list += sorted(glob.glob("out/image/*_" + str(i) + ".png"))
    mask_list += sorted(glob.glob("out/mask/*_" + str(i) + ".png"))

print(len(img_list))
print(len(mask_list))

img_data = list(zip(img_list,mask_list))

data_len = len(img_data)
print(data_len)

2448
2448
2448


Hyperparameters

In [8]:
learning_rate = 1e-4
num_epochs = 50
batch_size = 3

Loading Train-Test-Validation Splits

In [9]:
import math

print("Using same train-val-test split as Double U-Net")
with open("/content/drive/MyDrive/Colab Notebooks/train_set.pkl", 'rb') as f:
    train_set = pkl.load(f)
with open("/content/drive/MyDrive/Colab Notebooks/val_set.pkl", 'rb') as f:
    val_set = pkl.load(f)
with open("/content/drive/MyDrive/Colab Notebooks/test_set.pkl", 'rb') as f:
    test_set = pkl.load(f)

num_batches = math.ceil(len(train_set)/batch_size)

Using same train-val-test split as Double U-Net


Loading data batchwise

In [10]:
#Divide Train Data Into List of Batches for Training Loop
train_loader_x = []
train_loader_y = []

for idx in range(0, len(train_set), batch_size):
  if idx + batch_size > len(train_set):
    x_tup, y_tup = list(zip(*(list(train_set)[idx:])))
  else:
    x_tup, y_tup = list(zip(*(list(train_set)[idx:idx + batch_size])))
  train_loader_x.append(x_tup)
  train_loader_y.append(y_tup)

Defining Combined Losses

In [11]:
class CombinedLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.ce_loss   = nn.CrossEntropyLoss()
        self.dice_loss = DiceLoss()
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, pred_1, pred_2, pred_3, pred_canny, msk, canny_label):
        loss_pred_1 = self.ce_loss(pred_1, msk) + self.dice_loss(pred_1, msk)
        loss_pred_2 = self.ce_loss(pred_2, msk) + self.dice_loss(pred_2, msk)
        loss_pred_3 = self.ce_loss(pred_3, msk) + self.dice_loss(pred_3, msk)
        loss_canny = self.bce_loss(pred_canny, canny_label)
        loss = loss_pred_3 + loss_pred_1 + loss_pred_2 + loss_canny

        return loss

Define Optimizer, Loss function

In [12]:
#tmp = torch.ones(batch_size, 3, 288, 384)
print(torch.cuda.memory_allocated() / (1024 * 1024))
with torch.no_grad():
  msrf_net = MSRF(in_ch = 3, n_classes = 2, init_feat = 32).to(device)

#Loading partially trained model
#msrf_net.load_state_dict(torch.load("/content/drive/MyDrive/Colab Notebooks/double_unet_cvc-clinic_10_33.27252979576588.pt"))

print(torch.cuda.memory_allocated() / (1024 * 1024))
# for parameter in double_u_net.parameters():
#    print(f"Parameter = {parameter}")

optimizer = optim.Adam(msrf_net.parameters(), lr = learning_rate, betas=(0.9, 0.999), eps=1e-8)

criterion  = CombinedLoss()

!nvidia-smi

0.0
86.33837890625
Fri Apr 29 16:58:41 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   69C    P0    71W / 149W |    573MiB / 11441MiB |      9%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+--------------------------------------------------------------------

In [13]:
for epochs in tqdm(range(num_epochs)):
    running_loss = 0
    for idx in tqdm(range(num_batches)):
        # Array of CV2 images
        cv2_img_data = [read_img(ele) for ele in train_loader_x[idx]]

        img_data = [img_to_tensor(cv2_img) for cv2_img in cv2_img_data]
        img_data = torch.cat(img_data, dim = 0).to(device)
        
        canny_data = [torch.FloatTensor(np.asarray(cv2.Canny(cv2.imread(ele), 10, 100), np.float32)/255.0) for ele in train_loader_x[idx]]
        canny_data = [t.view(-1, t.shape[0], t.shape[1]) for t in canny_data]
        canny_data = torch.stack(canny_data, dim = 0).to(device)

        mask_data = [mask_to_tensor(read_mask(ele)) for ele in train_loader_y[idx]]
        #mask_data = torch.stack(mask_data, dim = 0).to(device)
        mask_data = torch.cat(mask_data, dim = 0).to(device)

        # print(f'Image shape = {img_data.shape}')
        # print(f'Canny shape = {canny_data.shape}')
        # print(f'Mask shape = {mask_data.shape}')

        # print('Before forward pass ->')
        # print(torch.cuda.memory_allocated() / (1024 * 1024))
        
        pred_1, pred_2, pred_3, pred_canny = msrf_net.forward(img_data.float(), canny_data)
        # print(pred_1.shape)
        # print(pred_2.shape)
        # print(pred_3.shape)
        # print(pred_canny.shape)

        # print('After forward pass ->')
        # print(torch.cuda.memory_allocated() / (1024 * 1024))

        del img_data
        loss = criterion(pred_1, pred_2, pred_3, pred_canny, mask_data, canny_data)
        del pred_1, pred_2, pred_3, pred_canny, mask_data, canny_data
        optimizer.zero_grad()
        loss.backward()

        # print('After backward pass ->')
        # print(torch.cuda.memory_allocated() / (1024 * 1024))

        optimizer.step()
        running_loss += float(loss.detach())
        del loss
        #losses.append(running_loss)
        #torch.cuda.empty_cache()
        #!nvidia-smi

        print('*************')

    print(f"For epoch {epochs + 1}, MSRF Loss is {running_loss}")
    if (epochs + 1) % 5 == 0:
        torch.save(msrf_net.state_dict(), f'/content/drive/MyDrive/Colab Notebooks/msrf_cvc-clinic_{epochs+1}_{running_loss}.pt')

# Save PyTorch model to disk
torch.save(msrf_net.state_dict(), '/content/drive/MyDrive/Colab Notebooks/msrf_cvc-clinic.pt')

  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/653 [00:00<?, ?it/s][A
  0%|          | 1/653 [00:04<49:17,  4.54s/it][A

*************



  0%|          | 2/653 [00:08<46:31,  4.29s/it][A

*************



  0%|          | 3/653 [00:12<43:51,  4.05s/it][A

*************



  1%|          | 4/653 [00:16<42:31,  3.93s/it][A

*************



  1%|          | 5/653 [00:19<41:49,  3.87s/it][A

*************



  1%|          | 6/653 [00:23<41:19,  3.83s/it][A

*************



  1%|          | 7/653 [00:27<41:02,  3.81s/it][A

*************



  1%|          | 8/653 [00:31<40:48,  3.80s/it][A

*************



  1%|▏         | 9/653 [00:34<40:35,  3.78s/it][A

*************



  2%|▏         | 10/653 [00:38<40:27,  3.77s/it][A

*************



  2%|▏         | 11/653 [00:42<40:18,  3.77s/it][A

*************



  2%|▏         | 12/653 [00:46<40:12,  3.76s/it][A

*************



  2%|▏         | 13/653 [00:49<40:06,  3.76s/it][A

*************



  2%|▏         | 14/653 [00:53<40:00,  3.76s/it][A

*************



  2%|▏         | 15/653 [00:57<39:55,  3.75s/it][A

*************



  2%|▏         | 16/653 [01:01<39:51,  3.75s/it][A

*************



  3%|▎         | 17/653 [01:04<39:46,  3.75s/it][A

*************



  3%|▎         | 18/653 [01:08<39:42,  3.75s/it][A

*************



  3%|▎         | 19/653 [01:12<39:38,  3.75s/it][A

*************



  3%|▎         | 20/653 [01:16<39:33,  3.75s/it][A

*************



  3%|▎         | 21/653 [01:20<39:32,  3.75s/it][A

*************



  3%|▎         | 22/653 [01:23<39:29,  3.75s/it][A

*************



  4%|▎         | 23/653 [01:27<39:25,  3.75s/it][A

*************



  4%|▎         | 24/653 [01:31<39:20,  3.75s/it][A

*************



  4%|▍         | 25/653 [01:35<39:19,  3.76s/it][A

*************



  4%|▍         | 26/653 [01:38<39:13,  3.75s/it][A

*************



  4%|▍         | 27/653 [01:42<39:07,  3.75s/it][A

*************



  4%|▍         | 28/653 [01:46<39:06,  3.75s/it][A

*************



  4%|▍         | 29/653 [01:50<39:01,  3.75s/it][A

*************



  5%|▍         | 30/653 [01:53<38:59,  3.75s/it][A

*************



  5%|▍         | 31/653 [01:57<38:54,  3.75s/it][A

*************



  5%|▍         | 32/653 [02:01<38:49,  3.75s/it][A

*************



  5%|▌         | 33/653 [02:05<38:45,  3.75s/it][A

*************



  5%|▌         | 34/653 [02:08<38:43,  3.75s/it][A

*************



  5%|▌         | 35/653 [02:12<38:37,  3.75s/it][A

*************



  6%|▌         | 36/653 [02:16<38:36,  3.75s/it][A

*************



  6%|▌         | 37/653 [02:20<38:30,  3.75s/it][A

*************



  6%|▌         | 38/653 [02:23<38:26,  3.75s/it][A

*************



  6%|▌         | 39/653 [02:27<38:22,  3.75s/it][A

*************



  6%|▌         | 40/653 [02:31<38:18,  3.75s/it][A

*************



  6%|▋         | 41/653 [02:35<38:13,  3.75s/it][A

*************



  6%|▋         | 42/653 [02:38<38:10,  3.75s/it][A

*************



  7%|▋         | 43/653 [02:42<38:07,  3.75s/it][A

*************



  7%|▋         | 44/653 [02:46<38:05,  3.75s/it][A

*************



  7%|▋         | 45/653 [02:50<38:01,  3.75s/it][A

*************



  7%|▋         | 46/653 [02:53<37:57,  3.75s/it][A

*************



  7%|▋         | 47/653 [02:57<37:52,  3.75s/it][A

*************



  7%|▋         | 48/653 [03:01<37:49,  3.75s/it][A

*************



  8%|▊         | 49/653 [03:05<37:46,  3.75s/it][A

*************



  8%|▊         | 50/653 [03:08<37:42,  3.75s/it][A

*************



  8%|▊         | 51/653 [03:12<37:39,  3.75s/it][A

*************



  8%|▊         | 52/653 [03:16<37:33,  3.75s/it][A

*************



  8%|▊         | 53/653 [03:20<37:31,  3.75s/it][A

*************



  8%|▊         | 54/653 [03:23<37:26,  3.75s/it][A

*************



  8%|▊         | 55/653 [03:27<37:22,  3.75s/it][A

*************



  9%|▊         | 56/653 [03:31<37:16,  3.75s/it][A

*************



  9%|▊         | 57/653 [03:35<37:12,  3.75s/it][A

*************



  9%|▉         | 58/653 [03:38<37:09,  3.75s/it][A

*************



  9%|▉         | 59/653 [03:42<37:05,  3.75s/it][A

*************



  9%|▉         | 60/653 [03:46<37:02,  3.75s/it][A

*************



  9%|▉         | 61/653 [03:50<36:59,  3.75s/it][A

*************



  9%|▉         | 62/653 [03:53<36:54,  3.75s/it][A

*************



 10%|▉         | 63/653 [03:57<36:52,  3.75s/it][A

*************



 10%|▉         | 64/653 [04:01<36:45,  3.75s/it][A

*************



 10%|▉         | 65/653 [04:05<36:41,  3.74s/it][A

*************



 10%|█         | 66/653 [04:08<36:39,  3.75s/it][A

*************



 10%|█         | 67/653 [04:12<36:35,  3.75s/it][A

*************



 10%|█         | 68/653 [04:16<36:27,  3.74s/it][A

*************



 11%|█         | 69/653 [04:19<36:25,  3.74s/it][A

*************



 11%|█         | 70/653 [04:23<36:25,  3.75s/it][A

*************



 11%|█         | 71/653 [04:27<36:18,  3.74s/it][A

*************



 11%|█         | 72/653 [04:31<36:15,  3.74s/it][A

*************



 11%|█         | 73/653 [04:34<36:11,  3.74s/it][A

*************



 11%|█▏        | 74/653 [04:38<36:07,  3.74s/it][A

*************



 11%|█▏        | 75/653 [04:42<36:06,  3.75s/it][A

*************



 12%|█▏        | 76/653 [04:46<36:02,  3.75s/it][A

*************



 12%|█▏        | 77/653 [04:49<35:56,  3.74s/it][A

*************



 12%|█▏        | 78/653 [04:53<35:52,  3.74s/it][A

*************



 12%|█▏        | 79/653 [04:57<35:47,  3.74s/it][A

*************



 12%|█▏        | 80/653 [05:01<35:45,  3.74s/it][A

*************



 12%|█▏        | 81/653 [05:04<35:40,  3.74s/it][A

*************



 13%|█▎        | 82/653 [05:08<35:35,  3.74s/it][A

*************



 13%|█▎        | 83/653 [05:12<35:31,  3.74s/it][A

*************



 13%|█▎        | 84/653 [05:16<35:30,  3.74s/it][A

*************



 13%|█▎        | 85/653 [05:19<35:27,  3.75s/it][A

*************



 13%|█▎        | 86/653 [05:23<35:22,  3.74s/it][A

*************



 13%|█▎        | 87/653 [05:27<35:22,  3.75s/it][A

*************



 13%|█▎        | 88/653 [05:31<35:15,  3.75s/it][A

*************



 14%|█▎        | 89/653 [05:34<35:12,  3.75s/it][A

*************



 14%|█▍        | 90/653 [05:38<35:10,  3.75s/it][A

*************



 14%|█▍        | 91/653 [05:42<35:04,  3.74s/it][A

*************



 14%|█▍        | 92/653 [05:46<35:01,  3.75s/it][A

*************



 14%|█▍        | 93/653 [05:49<34:58,  3.75s/it][A

*************



 14%|█▍        | 94/653 [05:53<34:53,  3.75s/it][A

*************



 15%|█▍        | 95/653 [05:57<34:52,  3.75s/it][A

*************



 15%|█▍        | 96/653 [06:01<34:48,  3.75s/it][A

*************



 15%|█▍        | 97/653 [06:04<34:42,  3.74s/it][A

*************



 15%|█▌        | 98/653 [06:08<34:41,  3.75s/it][A

*************



 15%|█▌        | 99/653 [06:12<34:37,  3.75s/it][A

*************



 15%|█▌        | 100/653 [06:16<34:29,  3.74s/it][A

*************



 15%|█▌        | 101/653 [06:19<34:27,  3.74s/it][A

*************



 16%|█▌        | 102/653 [06:23<34:24,  3.75s/it][A

*************



 16%|█▌        | 103/653 [06:27<34:19,  3.75s/it][A

*************



 16%|█▌        | 104/653 [06:31<34:15,  3.74s/it][A

*************



 16%|█▌        | 105/653 [06:34<34:13,  3.75s/it][A

*************



 16%|█▌        | 106/653 [06:38<34:09,  3.75s/it][A

*************



 16%|█▋        | 107/653 [06:42<34:04,  3.74s/it][A

*************



 17%|█▋        | 108/653 [06:46<33:59,  3.74s/it][A

*************



 17%|█▋        | 109/653 [06:49<33:56,  3.74s/it][A

*************



 17%|█▋        | 110/653 [06:53<34:01,  3.76s/it]
  0%|          | 0/50 [06:53<?, ?it/s]

*************





KeyboardInterrupt: ignored

In [None]:
print(mask_to_tensor(read_mask(train_loader_y[0][0])).shape)

In [None]:
from matplotlib import pyplot as plt

image_path = train_loader_y[145][1]
print(image_path)
mask = read_mask(image_path)
#Show the image with matplotlib
plt.imshow(mask, cmap='Greys_r')
plt.show()


In [None]:
a = torch.randn((3,3))
a = a.long()
print(a.type())