In [74]:
import numpy as np
import os
from torch.utils.data import Dataset
import torch
from PIL import Image
import matplotlib.pyplot as plt
from albumentations.pytorch import ToTensorV2
import albumentations as A
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm

In [8]:
!pip install torchsummary 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [120]:
import torch
#model = torch.jit.script(torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True))
mobileNetv2 = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
#mobileNetv2.eval()

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


In [121]:
encoder = [ torch.nn.Identity(),mobileNetv2.features[:2],  mobileNetv2.features[2:4],mobileNetv2.features[4:7],mobileNetv2.features[7:14],mobileNetv2.features[14:19]]

In [122]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [123]:
from torch._C import NoneType
import torch
import torch.nn as nn

class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        #print("conv1 before:",x.shape)
        #x = self.conv1(x)
        #print("conv1 done:",x.shape)
        #x = self.conv2(x)
        #print("conv2 done:",x.shape)
        return self.conv(x)

class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        s = self.conv(x)
        p = self.pool(s)
        return s, p

class attention_gate(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.Wg = nn.Sequential(
            nn.Conv2d(in_c[0], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.Ws = nn.Sequential(
            nn.Conv2d(in_c[1], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, g, s):
        Wg = self.Wg(g)
        Ws = self.Ws(s)
        out = self.relu(Wg + Ws)
        out = self.output(out)
        return out * s

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        #self.ag = attention_gate(in_c, out_c)
        self.c1 = conv_block(in_c[0]+in_c[1], out_c)

    def forward(self, x, s=None):
        #print("before upsampling done:",x.shape,s.shape)
        x = self.up(x)
        #print("upsampling done:",x.shape,s.shape)
        #s = self.ag(x, s)
        #print(type(s))
        if s is not None:
          x = torch.cat([x, s], dim=1)
        #print("concatination done",x.shape)
        x = self.c1(x)
        #print("cov_block block done:",x.shape)
        return x

class mobileNetV2(nn.Module):
    def __init__(self):
        super().__init__()

        self.e1 = encoder[0]#encoder_block(3, 64) #   (3,16)
        self.e2 = encoder[1]#encoder_block(64, 128)  #(16,24)
        self.e3 = encoder[2]#encoder_block(128, 256)#(24,32)
        self.e4 = encoder[3]                        #(32,96)
        self.e5 = encoder[4]                        #(96,1280)
        
        self.b1 = encoder[5]  
        #self.b1 = conv_block(1280, 2560)

        self.d1 = decoder_block([1280, 96], 256)
        #self.d2 = decoder_block([1, 96], 96)
        self.d2 = decoder_block([256, 32], 128)
        self.d3 = decoder_block([128, 24], 64)
        self.d4 = decoder_block([64, 16], 32)
        #print("completed d4 decoder block")
        self.d5 = decoder_block([32, 0], 16)

        self.output = nn.Conv2d(16, 4, kernel_size=1, padding=0)
        

    def forward(self, x):
        s1 = self.e1(x)
        
        s2 = self.e2(s1)
        
        s3 = self.e3(s2)
        s4 = self.e4(s3)
        s5 = self.e5(s4)

        b1 = self.b1(s5)
        #print("b1:",b1.shape)
        #print("s5:",s5.shape)
        
        d1 = self.d1(b1, s5)
        
        d2 = self.d2(d1, s4)
        d3 = self.d3(d2, s3)
        d4 = self.d4(d3, s2)
        #print("d4 done",d4.shape)
        d5 = self.d5(d4, None)#torch.tensor(0))
        
        
        output = self.output(d5)
        #print("ouput:",output.shape)
        
        return output

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [124]:
import torchsummary as summary 
image = torch.rand((1,3,256,256))
model = mobileNetV2()
#summary(model,(3,256,256))
o = model(image)

In [8]:
!pip3 install unzip

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting unzip
  Downloading unzip-1.0.0.tar.gz (704 bytes)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: unzip
  Building wheel for unzip (setup.py) ... [?25l[?25hdone
  Created wheel for unzip: filename=unzip-1.0.0-py3-none-any.whl size=1279 sha256=d87cd4b324d3bc0f06996911bccea92bfe7dceff7f26d413f45f83077e48e7e4
  Stored in directory: /root/.cache/pip/wheels/80/dc/7a/f8af45bc239e7933509183f038ea8d46f3610aab82b35369f4
Successfully built unzip
Installing collected packages: unzip
Successfully installed unzip-1.0.0


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [57]:
!pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchmetrics
  Downloading torchmetrics-0.11.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torchmetrics
Successfully installed torchmetrics-0.11.4


In [110]:
!unzip /content/drive/MyDrive/Colab/full_labels_for_deep_learning/masks_as_128x128_patches.zip
!mkdir mask
!cp -r  *.png ./mask/
!rm *.png 

!unzip /content/drive/MyDrive/Colab/full_labels_for_deep_learning/images_as_128x128_patches.zip
!mkdir image
!cp -r  *.png ./image/
!rm *.png 

unzip:  cannot find or open /content/drive/MyDrive/Colab/full_labels_for_deep_learning/masks_as_128x128_patches.zip, /content/drive/MyDrive/Colab/full_labels_for_deep_learning/masks_as_128x128_patches.zip.zip or /content/drive/MyDrive/Colab/full_labels_for_deep_learning/masks_as_128x128_patches.zip.ZIP.
mkdir: cannot create directory ‘mask’: File exists
cp: cannot stat '*.png': No such file or directory
rm: cannot remove '*.png': No such file or directory
unzip:  cannot find or open /content/drive/MyDrive/Colab/full_labels_for_deep_learning/images_as_128x128_patches.zip, /content/drive/MyDrive/Colab/full_labels_for_deep_learning/images_as_128x128_patches.zip.zip or /content/drive/MyDrive/Colab/full_labels_for_deep_learning/images_as_128x128_patches.zip.ZIP.
mkdir: cannot create directory ‘image’: File exists
cp: cannot stat '*.png': No such file or directory
rm: cannot remove '*.png': No such file or directory


In [59]:
import os
from google.colab import drive
drive.mount('/content/gdrive')
drive_path = "/gdrive/MyDrive/Colab Notebooks"

#DATA_DIR = './data/CamVid/'blank
DATA_DIR = '/content/data/CamVid'

# load repo with data if it is not exists
if not os.path.exists(DATA_DIR):
    print('Loading data...')
    os.system('git clone https://github.com/alexgkendall/SegNet-Tutorial ./data')
    print('Done!')

Mounted at /content/gdrive
Loading data...
Done!


In [60]:
DATA_DIR = '/content/data/CamVid'
x_train_dir = os.path.join(DATA_DIR, "train")
y_train_dir = os.path.join(DATA_DIR, "trainannot")

x_valid_dir = os.path.join(DATA_DIR, "val")
y_valid_dir = os.path.join(DATA_DIR, "valannot")

x_test_dir = os.path.join(DATA_DIR, "test")
y_test_dir = os.path.join(DATA_DIR, "testannot")

In [61]:
print('the number of image/label in the train: ',len(os.listdir(x_train_dir)))
print('the number of image/label in the validation: ',len(os.listdir(x_valid_dir)))
print('the number of image/label in the test: ',len(os.listdir(x_test_dir)))

the number of image/label in the train:  367
the number of image/label in the validation:  101
the number of image/label in the test:  233


In [125]:
def to_categorical(y, num_classes=None, dtype="float32"):
    """Converts a class vector (integers) to binary class matrix.
    E.g. for use with `categorical_crossentropy`.
    Args:
        y: Array-like with class values to be converted into a matrix
            (integers from 0 to `num_classes - 1`).
        num_classes: Total number of classes. If `None`, this would be inferred
          as `max(y) + 1`.
        dtype: The data type expected by the input. Default: `'float32'`.
    Returns:
        A binary matrix representation of the input as a NumPy array. The class
        axis is placed last.
    Example:
    >>> a = tf.keras.utils.to_categorical([0, 1, 2, 3], num_classes=4)
    >>> print(a)
    [[1. 0. 0. 0.]
     [0. 1. 0. 0.]
     [0. 0. 1. 0.]
     [0. 0. 0. 1.]]
    >>> b = tf.constant([.9, .04, .03, .03,
    ...                  .3, .45, .15, .13,
    ...                  .04, .01, .94, .05,
    ...                  .12, .21, .5, .17],
    ...                 shape=[4, 4])
    >>> loss = tf.keras.backend.categorical_crossentropy(a, b)
    >>> print(np.around(loss, 5))
    [0.10536 0.82807 0.1011  1.77196]
    >>> loss = tf.keras.backend.categorical_crossentropy(a, a)
    >>> print(np.around(loss, 5))
    [0. 0. 0. 0.]
    """
    y = np.array(y, dtype="int")
    input_shape = y.shape

    # Shrink the last dimension if the shape is (..., 1).
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        input_shape = tuple(input_shape[:-1])

    y = y.reshape(-1)
    if not num_classes:
        num_classes = np.max(y) + 1
    n = y.shape[0]
    categorical = np.zeros((n, num_classes), dtype=dtype)
    categorical[np.arange(n), y] = 1
    output_shape = input_shape + (num_classes,)
    categorical = np.reshape(categorical, output_shape)
    return categorical

In [126]:
import numpy as np


def preprocess_input(x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs):

    if input_space == "BGR":
        x = x[..., ::-1].copy()

    if input_range is not None:
        if x.max() > 1 and input_range[1] == 1:
            x = x / 255.0

    if mean is not None:
        mean = np.array(mean)
        x = x - mean

    if std is not None:
        std = np.array(std)
        x = x / std

    return x.astype('float32')

In [127]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import cv2
class McroscopingDataset(Dataset):
    
    def __init__(self,image_dir,mask_dir,transform=None):
      super().__init__()
      self.image_dir = image_dir
      self.mask_dir = mask_dir
      self.transform = transform
      self.images = os.listdir(image_dir)
      classes = ['1', '2', '3', '4']
      self.CLASSES = ['1', '2', '3', '4']
      # convert str names to class values on masks
      self.class_values = [self.CLASSES.index(cls.lower())+1 for cls in classes]

    def __len__(self):
      return len(self.images)

    def __getitem__(self,index):
          img_path = os.path.join(self.image_dir,self.images[index])
          mask_path = os.path.join(self.mask_dir,self.images[index]).replace('images_as','masks_as') 
          #mask_path = os.path.join(self.mask_dir,self.images[index])
          image = cv2.resize(cv2.imread(img_path),(256,256))
          mask = cv2.resize(cv2.imread(mask_path,0),(256,256))
          mask = mask - 1
          mask = to_categorical(mask, num_classes=4)
           # extract certain classes from mask (e.g. cars)
          #masks = [(mask == v) for v in self.class_values]
          #mask = np.stack(masks, axis=-1).astype('float')

          if self.transform is not None:
              image = preprocess_input(image,mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],input_image='BGR',input_range=[0,1])
              image = self.transform(image)
              mask = self.transform(mask)

              
          return image,mask

In [128]:
from torchvision import transforms
preprocess = transforms.Compose([
    transforms.ToTensor(),
    #transforms.Normalize(mean=0, std=1),
])


In [None]:
image = cv2.imread('/content/image/images_as_128x128_patches-0.png')
image.shape

(128, 128, 3)

In [129]:
def get_images(image_dir,mask_dir,transform = None,batch_size=1,shuffle=True,pin_memory=True):
    data = McroscopingDataset(image_dir,mask_dir,transform = transform)
    #print(data)
    train_size = int(0.8 * data.__len__())
    test_size = data.__len__() - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(data, [train_size, test_size])
    train_batch = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)
    test_batch = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)
    return train_batch,test_batch

In [130]:
image_dir = '/content/image'
mask_dir = '/content/mask'


In [131]:
train_batch,test_batch= get_images(image_dir,mask_dir,transform = preprocess,batch_size=8,shuffle=True,pin_memory=True)
#train_batch,test_batch= get_images(x_train_dir,y_train_dir,transform = preprocess,batch_size=8,shuffle=True,pin_memory=True)

#val_batch = get_images(x_valid_dir,y_valid_dir,transform = preprocess,batch_size=8,shuffle=True,pin_memory=True)
#test_batch = get_images(x_test_dir,x_test_dir,transform = preprocess,batch_size=8,shuffle=True,pin_memory=True)

In [25]:
class encoding_block(nn.Module):
    def __init__(self,in_channels, out_channels):
        super(encoding_block,self).__init__()
        model = []
        model.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False))
        model.append(nn.BatchNorm2d(out_channels))
        model.append(nn.ReLU(inplace=True))
        model.append(nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False))
        model.append(nn.BatchNorm2d(out_channels))
        model.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*model)
    def forward(self, x):
        return self.conv(x)  

In [None]:
class unet_model(nn.Module):
    def __init__(self,out_channels=23,features=[64, 128, 256, 512]):
        super(unet_model,self).__init__()
        self.pool = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))
        self.conv1 = encoding_block(3,features[0])
        self.conv2 = encoding_block(features[0],features[1])
        self.conv3 = encoding_block(features[1],features[2])
        self.conv4 = encoding_block(features[2],features[3])
        self.conv5 = encoding_block(features[3]*2,features[3])
        self.conv6 = encoding_block(features[3],features[2])
        self.conv7 = encoding_block(features[2],features[1])
        self.conv8 = encoding_block(features[1],features[0])        
        self.tconv1 = nn.ConvTranspose2d(features[-1]*2, features[-1], kernel_size=2, stride=2)
        self.tconv2 = nn.ConvTranspose2d(features[-1], features[-2], kernel_size=2, stride=2)
        self.tconv3 = nn.ConvTranspose2d(features[-2], features[-3], kernel_size=2, stride=2)
        self.tconv4 = nn.ConvTranspose2d(features[-3], features[-4], kernel_size=2, stride=2)        
        self.bottleneck = encoding_block(features[3],features[3]*2)
        self.final_layer = nn.Conv2d(features[0],out_channels,kernel_size=1)
    def forward(self,x):
        skip_connections = []
        x = self.conv1(x)
        skip_connections.append(x)
        x = self.pool(x)
        x = self.conv2(x)
        skip_connections.append(x)
        x = self.pool(x)
        x = self.conv3(x)
        skip_connections.append(x)
        x = self.pool(x)
        x = self.conv4(x)
        skip_connections.append(x)
        x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]
        x = self.tconv1(x)
        x = torch.cat((skip_connections[0], x), dim=1)
        x = self.conv5(x)
        x = self.tconv2(x)
        x = torch.cat((skip_connections[1], x), dim=1)
        x = self.conv6(x)
        x = self.tconv3(x)
        x = torch.cat((skip_connections[2], x), dim=1)
        x = self.conv7(x)        
        x = self.tconv4(x)
        x = torch.cat((skip_connections[3], x), dim=1)
        x = self.conv8(x)
        x = self.final_layer(x)
        #softmax = nn.Softmax(dim=1)
        #x = softmax(x)
        return x

In [None]:
import torch
import torch.nn as nn

class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        s = self.conv(x)
        p = self.pool(s)
        return s, p

class attention_gate(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.Wg = nn.Sequential(
            nn.Conv2d(in_c[0], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.Ws = nn.Sequential(
            nn.Conv2d(in_c[1], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, g, s):
        Wg = self.Wg(g)
        Ws = self.Ws(s)
        out = self.relu(Wg + Ws)
        out = self.output(out)
        return out * s

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        #self.ag = attention_gate(in_c, out_c)
        self.c1 = conv_block(in_c[0]+out_c, out_c)

    def forward(self, x, s):
        x = self.up(x)
        #s = self.ag(x, s)
        x = torch.cat([x, s], axis=1)
        x = self.c1(x)
        return x

class attention_unet(nn.Module):
    def __init__(self):
        super().__init__()

        self.e1 = encoder_block(3, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)

        self.b1 = conv_block(256, 512)

        self.d1 = decoder_block([512, 256], 256)
        self.d2 = decoder_block([256, 128], 128)
        self.d3 = decoder_block([128, 64], 64)

        self.output = nn.Conv2d(64, 4, kernel_size=1, padding=0)

    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)

        b1 = self.b1(p3)

        d1 = self.d1(b1, s3)
        d2 = self.d2(d1, s2)
        d3 = self.d3(d2, s1)
        #print("d3:",d3.shape)
        output = self.output(d3)
        print("ouput:",output.shape)
        return output

In [None]:
image = torch.rand((1,3,256,256))
model = attention_unet()
model(image)

ouput: torch.Size([1, 4, 256, 256])


tensor([[[[-0.0768,  0.3738,  0.3264,  ...,  0.4023,  0.2890, -0.0585],
          [ 0.2856,  0.1785,  0.0740,  ...,  0.2228,  0.4839,  0.1202],
          [ 0.2566,  0.3757,  0.5598,  ...,  0.5833,  0.2586,  0.0506],
          ...,
          [-0.0509,  0.1006, -0.1939,  ...,  0.0501,  0.6223,  0.0922],
          [-0.1830, -0.2596, -0.0103,  ...,  0.0263,  0.1431, -0.1935],
          [ 0.2600,  0.0939,  0.2187,  ..., -0.1900,  0.0161, -0.2828]],

         [[ 0.4994,  0.1433,  0.3556,  ...,  0.5614,  0.3713,  0.1657],
          [ 0.2018, -0.2782, -0.4411,  ..., -0.0317, -0.1271,  0.0212],
          [-0.0582, -0.2623, -0.3229,  ..., -0.2204, -0.3812,  0.0982],
          ...,
          [ 0.3731,  0.0873, -0.6851,  ..., -0.4523, -0.3305, -0.3096],
          [ 0.2292,  0.0422, -0.3152,  ..., -0.5026, -0.4382, -0.1774],
          [-0.0382,  0.3076,  0.4489,  ..., -0.2053, -0.3266, -0.0319]],

         [[ 0.5566,  0.5385,  0.3309,  ...,  0.2497, -0.0563,  0.3072],
          [ 0.0614,  0.4427,  

In [None]:
import torch
import torch.nn as nn

class batchnorm_relu(nn.Module):
    def __init__(self, in_c):
        super().__init__()

        self.bn = nn.BatchNorm2d(in_c)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.bn(inputs)
        x = self.relu(x)
        return x

class residual_block(nn.Module):
    def __init__(self, in_c, out_c, stride=1):
        super().__init__()

        """ Convolutional layer """
        self.b1 = batchnorm_relu(in_c)
        self.c1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, stride=stride)
        self.b2 = batchnorm_relu(out_c)
        self.c2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, stride=1)

        """ Shortcut Connection (Identity Mapping) """
        self.s = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0, stride=stride)

    def forward(self, inputs):
        x = self.b1(inputs)
        x = self.c1(x)
        x = self.b2(x)
        x = self.c2(x)
        s = self.s(inputs)

        skip = x + s
        return skip

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.r = residual_block(in_c+out_c, out_c)

    def forward(self, inputs, skip):
        x = self.upsample(inputs)
        x = torch.cat([x, skip], axis=1)
        x = self.r(x)
        return x

class build_resunet(nn.Module):
    def __init__(self):
        super().__init__()

        """ Encoder 1 """
        self.c11 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.br1 = batchnorm_relu(64)
        self.c12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.c13 = nn.Conv2d(3, 64, kernel_size=1, padding=0)

        """ Encoder 2 and 3 """
        self.r2 = residual_block(64, 128, stride=2)
        self.r3 = residual_block(128, 256, stride=2)

        """ Bridge """
        self.r4 = residual_block(256, 512, stride=2)

        """ Decoder """
        self.d1 = decoder_block(512, 256)
        self.d2 = decoder_block(256, 128)
        self.d3 = decoder_block(128, 64)

        """ Output """
        self.output = nn.Conv2d(64, 4, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """ Encoder 1 """
        x = self.c11(inputs)
        x = self.br1(x)
        x = self.c12(x)
        s = self.c13(inputs)
        skip1 = x + s

        """ Encoder 2 and 3 """
        skip2 = self.r2(skip1)
        skip3 = self.r3(skip2)

        """ Bridge """
        b = self.r4(skip3)

        """ Decoder """
        d1 = self.d1(b, skip3)
        d2 = self.d2(d1, skip2)
        d3 = self.d3(d2, skip1)

        """ output """
        output = self.output(d3)
        #output = self.sigmoid(output)

        return output


In [132]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


In [133]:
model = mobileNetV2().to(DEVICE)#unet_model(out_channels=4).to(DEVICE)#attention_unet().to(DEVICE) #unet_model(out_channels=4).to(DEVICE)

In [None]:
from torchsummary import summary
summary(model, (3, 256, 256))

In [134]:
LEARNING_RATE = 1e-4
num_epochs = 40

In [135]:
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
scaler = torch.cuda.amp.GradScaler()

In [136]:
def diceLoss(pred, y_):
    eps = 1e-7
    beta = 1
    #print("pred: ",torch.unique(pred))
    #print("org:",torch.unique(y_))
    softmax = nn.Softmax(dim=1)
    pred = softmax(pred)
    #print("pred: ",torch.unique(pred))
    tp = torch.sum(pred*y_)
    fp = torch.sum(pred) - tp
    fn = torch.sum(y_) - tp

    score = (2*tp + eps)/(2*tp+beta**2*fn+fp+eps)
    return (1 - score)


In [137]:
from torchmetrics import JaccardIndex

def check_accuracy(loader, model):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()
    meanIOU = 0

    with torch.no_grad():
        for x, y in loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            #print(y)
            #return y
            softmax = nn.Softmax(dim=1)
            preds = torch.argmax(softmax(model(x)),axis=1)
            y = torch.argmax(softmax(y),axis=1)
            #print(preds.shape)
            #print(y.shape)
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
            jaccard = JaccardIndex(task='multiclass',num_classes=4).to(DEVICE)
            IOU = jaccard(preds, y)
            meanIOU = meanIOU + IOU

    print(f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}")
    print(f"Dice score: {dice_score/len(loader)}")
    print(f"Mean IOU: {meanIOU/len(loader)}")
    model.train()
    return (meanIOU/len(loader))

In [140]:
mean_IOU = 0
for epoch in range(num_epochs):
    print("epoch:", epoch)  
    loop = tqdm(enumerate(train_batch),total=len(train_batch))
    count = 0

    for batch_idx, (data, targets) in loop:
        count = count  + 1 
        data = data.to(DEVICE)
        targets = targets.to(DEVICE)
        #targets = targets.type(torch.long)
        #print("data:",data.shape)
        #print("target shape:", targets.shape)
        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            #targets = targets.type(torch.long)
            #print("predictions:",predictions)
            #print("predictions:",predictions.shape)
            #print("target:",targets.shape)
            loss = diceLoss(predictions, targets)
            #loss = Variable(loss, requires_grad = True)
        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # update tqdm loop
        loop.set_postfix(loss=loss.item())   
    IOU = check_accuracy(test_batch,model)
    if mean_IOU < IOU:
        mean_IOU = IOU
        #torch.save(model, 'best_epoch.pth')
        print("modelSaved")
    

epoch: 0


100%|██████████| 200/200 [00:20<00:00,  9.53it/s, loss=0.0252]


Got 20567209/20971520 with acc 98.07
Dice score: 1.9882233142852783
Mean IOU: 0.9194796681404114
modelSaved
epoch: 1


100%|██████████| 200/200 [00:19<00:00, 10.14it/s, loss=0.0265]


Got 20572248/20971520 with acc 98.10
Dice score: 1.9877277612686157
Mean IOU: 0.9215928316116333
modelSaved
epoch: 2


100%|██████████| 200/200 [00:19<00:00, 10.04it/s, loss=0.0217]


Got 20586934/20971520 with acc 98.17
Dice score: 1.9888336658477783
Mean IOU: 0.9225634932518005
modelSaved
epoch: 3


100%|██████████| 200/200 [00:20<00:00,  9.75it/s, loss=0.0201]


Got 20594539/20971520 with acc 98.20
Dice score: 1.9886738061904907
Mean IOU: 0.9253796935081482
modelSaved
epoch: 4


100%|██████████| 200/200 [00:19<00:00, 10.13it/s, loss=0.0232]


Got 20594623/20971520 with acc 98.20
Dice score: 1.9886187314987183
Mean IOU: 0.9234496355056763
epoch: 5


100%|██████████| 200/200 [00:20<00:00,  9.94it/s, loss=0.0223]


Got 20589722/20971520 with acc 98.18
Dice score: 1.9889326095581055
Mean IOU: 0.9244787096977234
epoch: 6


100%|██████████| 200/200 [00:20<00:00,  9.84it/s, loss=0.0232]


Got 20586764/20971520 with acc 98.17
Dice score: 1.987653374671936
Mean IOU: 0.9232603907585144
epoch: 7


100%|██████████| 200/200 [00:19<00:00, 10.20it/s, loss=0.0224]


Got 20605386/20971520 with acc 98.25
Dice score: 1.9879435300827026
Mean IOU: 0.9263378977775574
modelSaved
epoch: 8


100%|██████████| 200/200 [00:20<00:00,  9.88it/s, loss=0.0191]


Got 20608063/20971520 with acc 98.27
Dice score: 1.9890769720077515
Mean IOU: 0.9245777130126953
epoch: 9


100%|██████████| 200/200 [00:19<00:00, 10.29it/s, loss=0.0218]


Got 20615115/20971520 with acc 98.30
Dice score: 1.989118218421936
Mean IOU: 0.9295960664749146
modelSaved
epoch: 10


100%|██████████| 200/200 [00:19<00:00, 10.13it/s, loss=0.0222]


Got 20612146/20971520 with acc 98.29
Dice score: 1.9885462522506714
Mean IOU: 0.9271628260612488
epoch: 11


100%|██████████| 200/200 [00:20<00:00,  9.88it/s, loss=0.0188]


Got 20629673/20971520 with acc 98.37
Dice score: 1.989823579788208
Mean IOU: 0.9320128560066223
modelSaved
epoch: 12


100%|██████████| 200/200 [00:19<00:00, 10.22it/s, loss=0.0201]


Got 20616719/20971520 with acc 98.31
Dice score: 1.9880250692367554
Mean IOU: 0.9300462007522583
epoch: 13


100%|██████████| 200/200 [00:20<00:00,  9.81it/s, loss=0.0164]


Got 20621771/20971520 with acc 98.33
Dice score: 1.9883953332901
Mean IOU: 0.9233679175376892
epoch: 14


100%|██████████| 200/200 [00:20<00:00,  9.79it/s, loss=0.0168]


Got 20635874/20971520 with acc 98.40
Dice score: 1.9893203973770142
Mean IOU: 0.9318580031394958
epoch: 15


100%|██████████| 200/200 [00:20<00:00, 10.00it/s, loss=0.0233]


Got 20626068/20971520 with acc 98.35
Dice score: 1.9882458448410034
Mean IOU: 0.9301913380622864
epoch: 16


100%|██████████| 200/200 [00:20<00:00,  9.81it/s, loss=0.0179]


Got 20634331/20971520 with acc 98.39
Dice score: 1.9890104532241821
Mean IOU: 0.9327576756477356
modelSaved
epoch: 17


100%|██████████| 200/200 [00:19<00:00, 10.10it/s, loss=0.021]


Got 20609368/20971520 with acc 98.27
Dice score: 1.9862768650054932
Mean IOU: 0.9261856079101562
epoch: 18


100%|██████████| 200/200 [00:19<00:00, 10.12it/s, loss=0.016]


Got 20658183/20971520 with acc 98.51
Dice score: 1.9896490573883057
Mean IOU: 0.9362396597862244
modelSaved
epoch: 19


100%|██████████| 200/200 [00:20<00:00,  9.77it/s, loss=0.0161]


Got 20656286/20971520 with acc 98.50
Dice score: 1.9904016256332397
Mean IOU: 0.9348114132881165
epoch: 20


100%|██████████| 200/200 [00:19<00:00, 10.15it/s, loss=0.0166]


Got 20647284/20971520 with acc 98.45
Dice score: 1.98846435546875
Mean IOU: 0.9354033470153809
epoch: 21


100%|██████████| 200/200 [00:20<00:00,  9.86it/s, loss=0.02]


Got 20651870/20971520 with acc 98.48
Dice score: 1.9886261224746704
Mean IOU: 0.9352491497993469
epoch: 22


100%|██████████| 200/200 [00:20<00:00,  9.93it/s, loss=0.0174]


Got 20660412/20971520 with acc 98.52
Dice score: 1.9888525009155273
Mean IOU: 0.937405526638031
modelSaved
epoch: 23


100%|██████████| 200/200 [00:19<00:00, 10.24it/s, loss=0.0188]


Got 20656888/20971520 with acc 98.50
Dice score: 1.9889707565307617
Mean IOU: 0.9375545382499695
modelSaved
epoch: 24


100%|██████████| 200/200 [00:20<00:00,  9.76it/s, loss=0.0205]


Got 20666283/20971520 with acc 98.54
Dice score: 1.989636778831482
Mean IOU: 0.9398710131645203
modelSaved
epoch: 25


100%|██████████| 200/200 [00:19<00:00, 10.14it/s, loss=0.0128]


Got 20672250/20971520 with acc 98.57
Dice score: 1.9895180463790894
Mean IOU: 0.9401832818984985
modelSaved
epoch: 26


100%|██████████| 200/200 [00:19<00:00, 10.07it/s, loss=0.0154]


Got 20683075/20971520 with acc 98.62
Dice score: 1.9896513223648071
Mean IOU: 0.941541314125061
modelSaved
epoch: 27


100%|██████████| 200/200 [00:20<00:00,  9.79it/s, loss=0.0301]


Got 20605543/20971520 with acc 98.25
Dice score: 1.9884998798370361
Mean IOU: 0.9299700856208801
epoch: 28


100%|██████████| 200/200 [00:19<00:00, 10.15it/s, loss=0.0137]


Got 20679449/20971520 with acc 98.61
Dice score: 1.9895365238189697
Mean IOU: 0.9420532584190369
modelSaved
epoch: 29


100%|██████████| 200/200 [00:20<00:00,  9.92it/s, loss=0.0137]


Got 20679822/20971520 with acc 98.61
Dice score: 1.9905149936676025
Mean IOU: 0.9409249424934387
epoch: 30


100%|██████████| 200/200 [00:20<00:00,  9.99it/s, loss=0.0126]


Got 20694055/20971520 with acc 98.68
Dice score: 1.9903262853622437
Mean IOU: 0.9447709918022156
modelSaved
epoch: 31


100%|██████████| 200/200 [00:19<00:00, 10.17it/s, loss=0.015]


Got 20657953/20971520 with acc 98.50
Dice score: 1.9891138076782227
Mean IOU: 0.9392223358154297
epoch: 32


100%|██████████| 200/200 [00:20<00:00,  9.83it/s, loss=0.0147]


Got 20671318/20971520 with acc 98.57
Dice score: 1.9878818988800049
Mean IOU: 0.9393518567085266
epoch: 33


100%|██████████| 200/200 [00:19<00:00, 10.14it/s, loss=0.015]


Got 20710884/20971520 with acc 98.76
Dice score: 1.9903799295425415
Mean IOU: 0.9468361139297485
modelSaved
epoch: 34


100%|██████████| 200/200 [00:19<00:00, 10.02it/s, loss=0.0164]


Got 20683311/20971520 with acc 98.63
Dice score: 1.9885977506637573
Mean IOU: 0.941594123840332
epoch: 35


100%|██████████| 200/200 [00:20<00:00,  9.89it/s, loss=0.0132]


Got 20703298/20971520 with acc 98.72
Dice score: 1.990272879600525
Mean IOU: 0.9469969868659973
modelSaved
epoch: 36


100%|██████████| 200/200 [00:19<00:00, 10.16it/s, loss=0.0138]


Got 20717193/20971520 with acc 98.79
Dice score: 1.9900633096694946
Mean IOU: 0.9480493664741516
modelSaved
epoch: 37


100%|██████████| 200/200 [00:20<00:00,  9.87it/s, loss=0.0121]


Got 20715913/20971520 with acc 98.78
Dice score: 1.9896351099014282
Mean IOU: 0.9479711651802063
epoch: 38


100%|██████████| 200/200 [00:19<00:00, 10.01it/s, loss=0.0152]


Got 20718736/20971520 with acc 98.79
Dice score: 1.9915412664413452
Mean IOU: 0.9482424855232239
modelSaved
epoch: 39


100%|██████████| 200/200 [00:19<00:00, 10.21it/s, loss=0.0133]


Got 20703842/20971520 with acc 98.72
Dice score: 1.9888943433761597
Mean IOU: 0.9459635615348816


In [None]:
def _take_channels(*xs,ignore_channels=None):
  if ignore_channels is None:
    return xs
  else:
    channels = [channel for channel in range(xs[0].shape[1]) if channel not in ignore_channels]  
    xs = [torch.index_select(x,dim=1,index=torch.tensor(channels).to(x)) for x in xs]
    return xs

def _threshold(x ,threshold=None):
  if threshold is not None:
    return (x>threshold).type(x.dtype)
  else:
    return x  


def IOU_(pr,gt,eps=1e-7,threshold=None,ignore_channels=None):
  pr = _threshold(pr,threshold=threshold)
  pr,gt = _take_channels(pr,gt,ignore_channels=ignore_channels)

  intersection = torch.sum(gt*pr)
  union = torch.sum(gt) + torch.sum(pr) - intersection +eps

  return (intersection+eps)/union

In [None]:
from torchmetrics import JaccardIndex

def check_accuracy_(loader, model):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()
    meanIOU = 0
    iou_c0 = 0
    iou_c1 = 0
    iou_c2 = 0
    iou_c3 = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            #print(y)
            #return y

            softmax = nn.Softmax(dim=1)
            #individual IOU
            pr = softmax(model(x))
            gt = x
            
            iou_c0 += IOU_(pr,gt,threshold=0.5,ignore_channels=[1,2,3])
            iou_c1 += IOU_(pr,gt,threshold=0.5,ignore_channels=[0,2,3])
            iou_c2 += IOU_(pr,gt,threshold=0.5,ignore_channels=[0,1,3])
            iou_c3 += IOU_(pr,gt,threshold=0.5,ignore_channels=[0,1,2])

            preds = torch.argmax(softmax(model(x)),axis=1)
            y = torch.argmax(softmax(y),axis=1)
            #print(preds.shape)
            #print(y.shape)
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
            jaccard = JaccardIndex(task='multiclass',num_classes=4).to(DEVICE)
            max = torch.max(y)
            IOU = jaccard(preds, y)
            meanIOU = meanIOU + IOU

    print(f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}")
    print(f"Dice score: {dice_score/len(loader)}")
    print(f"Mean IOU: {meanIOU/len(loader)}")
    print(f"Mean IOU: {iou_c0/len(loader)}")
    print(f"Mean IOU: {iou_c1/len(loader)}")
    print(f"Mean IOU: {iou_c2/len(loader)}")
    print(f"Mean IOU: {iou_c3/len(loader)}")
    model.train()

In [None]:
!pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchmetrics
  Downloading torchmetrics-0.11.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torchmetrics
Successfully installed torchmetrics-0.11.4


In [None]:
check_accuracy(test_batch, model) #with attension unet

Got 5103744/5242880 with acc 97.35
Dice score: 1.9914945363998413
Mean IOU: 0.8921363949775696


tensor(0.8921, device='cuda:0')

In [None]:
check_accuracy(test_batch, model)

Got 5147024/5242880 with acc 98.17
Dice score: 1.9948867559432983
Mean IOU: 0.9246343970298767


tensor(0.9246, device='cuda:0')

In [None]:
check_accuracy(test_batch, model) #pytorch updated vanial unet 

Got 5144017/5242880 with acc 98.11
Dice score: 1.9938873052597046
Mean IOU: 0.9229429364204407


tensor(0.9229, device='cuda:0')

In [None]:
check_accuracy(test_batch, model) #pytorch updated resunet unet 

Got 5140636/5242880 with acc 98.05
Dice score: 1.9953851699829102
Mean IOU: 0.9211367964744568


tensor(0.9211, device='cuda:0')

In [None]:
check_accuracy(test_batch, model) #pytorch updated plan vanila unet with normlization 

Got 5146201/5242880 with acc 98.16
Dice score: 1.9953806400299072
Mean IOU: 0.9248954653739929


tensor(0.9249, device='cuda:0')

In [None]:
check_accuracy(test_batch, model) #pytorch updated plan vanila unet with normlization 

Got 5134853/5242880 with acc 97.94
Dice score: 1.9941147565841675
Mean IOU: 0.9131608009338379


tensor(0.9132, device='cuda:0')

In [139]:
check_accuracy(test_batch, model) 

Got 20529201/20971520 with acc 97.89
Dice score: 1.9880688190460205
Mean IOU: 0.9131507873535156


tensor(0.9132, device='cuda:0')

In [None]:
softmax = nn.Softmax(dim=1)
torch.argmax(softmax(y),axis=1).shape
#preds = torch.argmax(softmax(model(x)),axis=1)
#y = torch.argmax(softmax(y),axis=1)


In [None]:
for x,y in train_batch:
    x = x.to(DEVICE)
    fig , ax =  plt.subplots(3, 3, figsize=(18, 18))
    softmax = nn.Softmax(dim=1)
    preds = torch.argmax(softmax(model(x)),axis=1).to('cpu')
    y = torch.argmax(softmax(y),axis=1)
    img1 = np.transpose(np.array(x[0,:,:,:].to('cpu')),(1,2,0))
    preds1 = np.array(preds[0,:,:])
    mask1 = np.array(y[0,:,:])
    img2 = np.transpose(np.array(x[1,:,:,:].to('cpu')),(1,2,0))
    preds2 = np.array(preds[1,:,:])
    mask2 = np.array(y[1,:,:])
    img3 = np.transpose(np.array(x[2,:,:,:].to('cpu')),(1,2,0))
    preds3 = np.array(preds[2,:,:])
    mask3 = np.array(y[2,:,:])
    ax[0,0].set_title('Image')
    ax[0,1].set_title('Prediction')
    ax[0,2].set_title('Mask')
    ax[1,0].set_title('Image')
    ax[1,1].set_title('Prediction')
    ax[1,2].set_title('Mask')
    ax[2,0].set_title('Image')
    ax[2,1].set_title('Prediction')
    ax[2,2].set_title('Mask')
    ax[0][0].axis("off")
    ax[1][0].axis("off")
    ax[2][0].axis("off")
    ax[0][1].axis("off")
    ax[1][1].axis("off")
    ax[2][1].axis("off")
    ax[0][2].axis("off")
    ax[1][2].axis("off")
    ax[2][2].axis("off")
    ax[0][0].imshow(img1)
    ax[0][1].imshow(preds1)
    ax[0][2].imshow(mask1)
    ax[1][0].imshow(img2)
    ax[1][1].imshow(preds2)
    ax[1][2].imshow(mask2)
    ax[2][0].imshow(img3)
    ax[2][1].imshow(preds3)
    ax[2][2].imshow(mask3)   
    break