# Preprocessing of images

In [None]:
import numpy as np
import pandas as pd
import datetime
import os
import shutil
import re
import math
import matplotlib.pyplot as plt
import skimage
from skimage.io import imread, imsave
from skimage.transform import resize
from skimage import img_as_float32, img_as_ubyte

import cv2

from albumentations import (
    Compose, HorizontalFlip, ShiftScaleRotate, ElasticTransform,
    RandomBrightness, RandomContrast, RandomGamma
)

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras

# Ask TensorFlow to grow GPU memory on demand for every visible GPU.
gpus = tf.config.experimental.list_physical_devices("GPU")
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  except RuntimeError as e:
    # The error is printed, not re-raised, so the notebook keeps running.
    print(e)
%matplotlib inline
# Dataset root; masks live in the "masks" sub-folder, images are resolved
# relative to root_folder when the masks are loaded.
root_folder = "/kaggle/input/paper-dataset/covid19-segmentation-paper-main/4_images"
masks_folder = os.path.join(root_folder, "masks")

# Side length (pixels) used when reshaping the data and sizing the model input.
img_size = 300

In [None]:
# Load every mask together with the image it annotates.
#   images / labels : all images and their masks, index-aligned
#   v7labs_images   : images whose file name carries the source "Cohen"
#   manual_images   : images from every other source
images = []
labels = []
manual_images = []
v7labs_images = []

# os.listdir() returns entries in arbitrary, filesystem-dependent order.
# Sorting keeps the sample order stable, so the seeded shuffle/split applied
# to these lists later gives the same partition on every run and machine.
for mask_path in sorted(os.listdir(masks_folder)):
  mask = imread(os.path.join(masks_folder, mask_path))
  # Rescale to [0, 1]; assumes 8-bit masks (0..255) -- TODO confirm.
  mask = np.float32(mask / 255)
  labels.append(mask)
  # The mask name must split on "_" and "." into exactly six fields:
  #   <target>_<source>_<pathogen>_<pid>_<offset>.<extension>
  # Any other file name raises ValueError on unpacking.
  target, source, pathogen, pid, offset, _ = re.split("[_.]", mask_path)
  img_path = "%s_%s_%s_%s.png" % (source, pathogen, pid, offset)
  img = imread(os.path.join(root_folder, target, pathogen, img_path))
  img = img_as_float32(img)

  if source == "Cohen":
    v7labs_images.append(img)
  else:
    manual_images.append(img)

  images.append(img)

len(labels)

In [None]:
# Show some of our own CXR (left column) next to their masks (right column)
f = plt.figure()
for row, sample_idx in enumerate((50, 20, 30)):
  # Subplot slots are 1-based and filled row by row in a 3 x 2 grid.
  f.add_subplot(3, 2, 2 * row + 1)
  plt.imshow(images[sample_idx], cmap = "gray")
  f.add_subplot(3, 2, 2 * row + 2)
  plt.imshow(labels[sample_idx], cmap = "gray")

In [None]:
# Stack the images into one (n_images, img_size, img_size, channels) array.
# Uses the shared img_size setting instead of a hard-coded 300 so this preview
# stays in step with the reshape used for the train/validation/test split.
X = np.array(images).reshape((len(images), img_size, img_size, -1))
print(X.shape)

In [None]:
# Stack images and masks into 4-D arrays; the trailing -1 lets numpy infer the
# channel dimension from the data.
X = np.array(images).reshape((len(images), img_size,img_size, -1))
Y = np.array(labels).reshape((len(labels), img_size, img_size, -1))
# Shuffle images and masks together (same permutation) with a fixed seed.
X, Y = shuffle(X, Y, random_state = 1234)

# Hold out 5% of all samples for testing, then 5% of the remainder for validation.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.05, random_state = 1234)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size = 0.05, random_state = 1234)

print(X_train.shape)
print(X_val.shape)
print(X_test.shape)
print(X.shape)

# X_train = np.array(X_train).reshape(len(X_train),dim,dim,-1)
# y_train = np.array(y_train).reshape(len(y_train),dim,dim,-1)
# X_test = np.array(X_test).reshape(len(X_test),dim,dim,-1)
# y_test = np.array(y_test).reshape(len(y_test),dim,dim,-1)
# assert X_train.shape == y_train.shape
# assert X_test.shape == y_test.shape
# images = np.concatenate((X_train,X_test),axis=0)
# mask  = np.concatenate((y_train,y_test),axis=0)

In [None]:
# Bucket the indices of the test images by the dataset they came from.
# Each test image (channel 0) is compared pixel-for-pixel against the source
# lists; the first list containing an identical array wins.
# NOTE(review): shenzhen_images, montgomery_images and jsrt_images are not
# defined anywhere in the visible part of this notebook (only manual_images and
# v7labs_images are built by the loading cell) -- confirm they exist before
# running, otherwise this cell raises NameError.
shenzhen_test_ids = []
jsrt_test_ids = []
montgomery_test_ids = []
v7labs_test_ids = []
other_test_ids = []

nimages = X_test.shape[0]
for idx in range(nimages):
  test_image = X_test[idx,:,:,0]
  if any(np.array_equal(test_image, x) for x in shenzhen_images):
    shenzhen_test_ids.append(idx)
  elif any(np.array_equal(test_image, x) for x in montgomery_images):
    montgomery_test_ids.append(idx)
  elif any(np.array_equal(test_image, x) for x in jsrt_images):
    jsrt_test_ids.append(idx)
  elif any(np.array_equal(test_image, x) for x in v7labs_images):
    v7labs_test_ids.append(idx)
  else:
    # No exact match in any of the source lists.
    other_test_ids.append(idx)

In [None]:
class AugmentationSequence(keras.utils.Sequence):
  """Keras ``Sequence`` that serves augmented (image, mask) batches.

  Args:
    x_set: indexable collection of images; sliced along the first axis and
      expected to yield 4-D batches (the batch is indexed as [i, :, :, :]).
    y_set: masks aligned index-for-index with ``x_set``.
    batch_size: number of samples per batch; the last batch may be smaller.
    augmentations: callable invoked as ``augmentations(image=..., mask=...)``
      and returning a mapping with ``"image"`` and ``"mask"`` entries.
  """

  def __init__(self, x_set, y_set, batch_size, augmentations):
    # Initialise the Keras base class as well as our own state.
    super().__init__()
    self.x, self.y = x_set, y_set
    self.batch_size = batch_size
    self.augment = augmentations

  def __len__(self):
    """Return the number of batches per epoch (a final partial batch counts)."""
    return int(np.ceil(len(self.x) / float(self.batch_size)))

  def __getitem__(self, idx):
    """Return batch number ``idx`` as a pair of freshly augmented arrays."""
    start = idx * self.batch_size
    stop = start + self.batch_size
    batch_x = self.x[start:stop]
    batch_y = self.y[start:stop]

    # Allocate the outputs in the input's float precision. np.zeros(shape)
    # would silently upcast float32 batches to float64 and double the memory
    # per batch. Integer inputs are promoted to a float type so augmented
    # values are never truncated.
    aug_x = np.zeros(batch_x.shape, dtype = np.result_type(batch_x.dtype, np.float32))
    aug_y = np.zeros(batch_y.shape, dtype = np.result_type(batch_y.dtype, np.float32))

    # A separate loop variable keeps the batch index ``idx`` intact.
    for sample in range(batch_x.shape[0]):
      aug = self.augment(image = batch_x[sample,:,:,:], mask = batch_y[sample,:,:,:])
      aug_x[sample,:,:,:] = aug["image"]
      aug_y[sample,:,:,:] = aug["mask"]

    return aug_x, aug_y

# Augmentation pipeline handed to AugmentationSequence; each transform is used
# with its library defaults unless an argument is given below.
augment = Compose([
  HorizontalFlip(),
  ShiftScaleRotate(rotate_limit = 45, border_mode = cv2.BORDER_CONSTANT),
  ElasticTransform(border_mode = cv2.BORDER_CONSTANT),
  RandomBrightness(),
  RandomContrast(),
  RandomGamma()
])

batch_size = 16
train_generator = AugmentationSequence(X_train, Y_train, batch_size, augment)
# Number of batches needed to cover the training set once (last one may be partial).
steps_per_epoch = math.ceil(X_train.shape[0] / batch_size)


In [None]:
# Pull one augmented batch and show its first four image/mask pairs,
# images in the left column and masks in the right column.
X_aug, Y_aug = train_generator[20]

f = plt.figure()
for row in range(4):
  f.add_subplot(4, 2, 2 * row + 1)
  plt.imshow(X_aug[row,:,:,0], cmap = "gray")
  f.add_subplot(4, 2, 2 * row + 2)
  plt.imshow(Y_aug[row,:,:,0], cmap = "gray")

In [None]:
# LOSS Functions
def jaccard_distance_loss(y_true, y_pred, smooth = 100):
    """Smoothed Jaccard (intersection-over-union) distance.

    Sums are taken over the last axis only, so the result keeps every leading
    dimension of the inputs. ``smooth`` is added to numerator and denominator
    and also scales the returned value: ``(1 - jaccard) * smooth``.
    """
    K = keras.backend
    overlap = K.sum(K.abs(y_true * y_pred), axis = -1)
    total = K.sum(K.abs(y_true) + K.abs(y_pred), axis = -1)
    jaccard = (overlap + smooth) / (total - overlap + smooth)
    return (1 - jaccard) * smooth

def dice_coef(y_true, y_pred, smooth = 1):
    """Smoothed Dice coefficient, reduced over the last axis only.

    Returns ``(2 * |true * pred| + smooth) / (|true| + |pred| + smooth)`` where
    each ``|.|`` is a sum of absolute values along the last axis.
    """
    K = keras.backend
    overlap = K.sum(K.abs(y_true * y_pred), axis = -1)
    combined = K.sum(K.abs(y_true), -1) + K.sum(K.abs(y_pred), -1)
    return (2. * overlap + smooth) / (combined + smooth)

In [None]:
import pdb
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
class AxialBlock_dynamic(nn.Module):
  """Residual bottleneck block built around two axial-attention layers.

  Data flow: 1x1 conv down to ``width`` channels -> BN -> ReLU -> attention
  block ``hight_block`` -> attention block ``width_block`` (constructed with
  ``width=True`` and the block's ``stride``) -> ReLU -> 1x1 conv up to
  ``planes * expansion`` channels -> BN -> residual add -> ReLU.

  ``conv1x1`` and ``AxialAttention_dynamic`` are not defined in this block and
  must already be in scope when the class is instantiated.
  NOTE(review): ``dilation`` is accepted but never used by this constructor.
  """
  # Output channels are planes * expansion.
  expansion = 2

  def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                base_width=64, dilation=1, norm_layer=None, kernel_size=56):
      super(AxialBlock_dynamic, self).__init__()
      if norm_layer is None:
          norm_layer = nn.BatchNorm2d
      # Internal channel count; equals planes when base_width is 64.
      width = int(planes * (base_width / 64.))
      # Both self.conv2 and self.downsample layers downsample the input when stride != 1
      self.conv_down = conv1x1(inplanes, width)
      self.bn1 = norm_layer(width)
      self.hight_block = AxialAttention_dynamic(width, width, groups=groups, kernel_size=kernel_size)
      self.width_block = AxialAttention_dynamic(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True)
      self.conv_up = conv1x1(width, planes * self.expansion)
      self.bn2 = norm_layer(planes * self.expansion)
      self.relu = nn.ReLU(inplace=True)
      # Optional module applied to the input so the shortcut matches ``out``.
      self.downsample = downsample
      self.stride = stride

  def forward(self, x):
      """Run the bottleneck and add the (optionally downsampled) input back."""
      identity = x

      out = self.conv_down(x)
      out = self.bn1(out)
      out = self.relu(out)

      out = self.hight_block(out)
      out = self.width_block(out)
      out = self.relu(out)

      out = self.conv_up(out)
      out = self.bn2(out)

      if self.downsample is not None:
          identity = self.downsample(x)
      #print(out.shape)
      #print(identity.shape)
      # In-place residual add; identity must have the same shape as out.
      out += identity
      out = self.relu(out)

      return out
class ResAxialAttentionUNet(nn.Module):
  """Encoder/decoder segmentation network with an axial-attention encoder.

  Encoder: a three-convolution stem (the first convolution has stride 2)
  followed by four stages built by ``_make_layer``; stages 2-4 are built with
  stride 2. Decoder: five 3x3 convolutions, each followed by 2x bilinear
  upsampling and ReLU, with the encoder stage outputs added back as skip
  connections, then a 1x1 convolution to ``num_classes`` channels.

  Channel widths are the base widths (64, 128, ...) scaled by ``s``.
  ``conv1x1`` must be in scope when the model is constructed.
  NOTE(review): ``zero_init_residual`` is accepted but unused, and
  ``self.soft`` is created but never applied -- ``_forward_impl`` returns the
  raw output of ``self.adjust``.
  """

  def __init__(self, block, layers, num_classes=2, zero_init_residual=True,
                groups=8, width_per_group=64, replace_stride_with_dilation=None,
                norm_layer=None, s=0.125, img_size = 128,imgchan = 3):
      super(ResAxialAttentionUNet, self).__init__()
      if norm_layer is None:
          norm_layer = nn.BatchNorm2d
      self._norm_layer = norm_layer

      # Running input-channel count; updated by _make_layer after each stage.
      self.inplanes = int(64 * s)
      self.dilation = 1
      if replace_stride_with_dilation is None:
          replace_stride_with_dilation = [False, False, False]
      if len(replace_stride_with_dilation) != 3:
          raise ValueError("replace_stride_with_dilation should be None "
                            "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
      self.groups = groups
      self.base_width = width_per_group
      # Stem: imgchan -> inplanes (stride 2) -> 128 -> inplanes.
      self.conv1 = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
                              bias=False)
      self.conv2 = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1, bias=False)
      self.conv3 = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
      self.bn1 = norm_layer(self.inplanes)
      self.bn2 = norm_layer(128)
      self.bn3 = norm_layer(self.inplanes)
      self.relu = nn.ReLU(inplace=True)
      # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
      # Encoder stages; kernel_size passed to each stage is derived from img_size.
      self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size= (img_size//2))
      self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2, kernel_size=(img_size//2),
                                      dilate=replace_stride_with_dilation[0])
      self.layer3 = self._make_layer(block, int(512 * s), layers[2], stride=2, kernel_size=(img_size//4),
                                      dilate=replace_stride_with_dilation[1])
      self.layer4 = self._make_layer(block, int(1024 * s), layers[3], stride=2, kernel_size=(img_size//8),
                                      dilate=replace_stride_with_dilation[2])

      # Decoder
      self.decoder1 = nn.Conv2d(int(1024 *2*s)      ,        int(1024*2*s), kernel_size=3, stride=2, padding=1)
      self.decoder2 = nn.Conv2d(int(1024  *2*s)     , int(1024*s), kernel_size=3, stride=1, padding=1)
      self.decoder3 = nn.Conv2d(int(1024*s),  int(512*s), kernel_size=3, stride=1, padding=1)
      self.decoder4 = nn.Conv2d(int(512*s) ,  int(256*s), kernel_size=3, stride=1, padding=1)
      self.decoder5 = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1)
      # 1x1 projection to the class channels.
      self.adjust   = nn.Conv2d(int(128*s) , num_classes, kernel_size=1, stride=1, padding=0)
      self.soft     = nn.Softmax(dim=1)


  def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
      """Build one encoder stage of ``blocks`` consecutive ``block`` modules.

      The first block receives ``stride`` and, when the stride or channel
      count changes, a 1x1-conv + norm shortcut. ``self.inplanes`` is updated
      to the stage's output channels (``planes * block.expansion``).
      """
      norm_layer = self._norm_layer
      downsample = None
      previous_dilation = self.dilation
      if dilate:
          # Trade the stride for dilation.
          self.dilation *= stride
          stride = 1
      if stride != 1 or self.inplanes != planes * block.expansion:
          downsample = nn.Sequential(
              conv1x1(self.inplanes, planes * block.expansion, stride),
              norm_layer(planes * block.expansion),
          )

      layers = []
      layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups,
                          base_width=self.base_width, dilation=previous_dilation, 
                          norm_layer=norm_layer, kernel_size=kernel_size))
      self.inplanes = planes * block.expansion
      if stride != 1:
          # The remaining blocks of a strided stage get half the kernel_size.
          kernel_size = kernel_size // 2

      for _ in range(1, blocks):
          layers.append(block(self.inplanes, planes, groups=self.groups,
                              base_width=self.base_width, dilation=self.dilation,
                              norm_layer=norm_layer, kernel_size=kernel_size))

      return nn.Sequential(*layers)

  def _forward_impl(self, x):
      """Encode ``x``, decode with additive skips, and return the class map."""

      # AxialAttention Encoder
      # pdb.set_trace()
      x = self.conv1(x)
      x = self.bn1(x)
      x = self.relu(x)
      x = self.conv2(x)
      x = self.bn2(x)
      x = self.relu(x)
      x = self.conv3(x)
      x = self.bn3(x)
      x = self.relu(x)
      #print(x.shape)
      x1 = self.layer1(x)
      #print(x.shape)
      x2 = self.layer2(x1)
      # print(x2.shape)
      x3 = self.layer3(x2)
      # print(x3.shape)
      x4 = self.layer4(x3)

      # decoder1 has stride 2, so the 2x upsampling that follows brings the
      # tensor back to a size that can be added to x4.
      x = F.relu(F.interpolate(self.decoder1(x4), scale_factor=(2,2), mode ='bilinear'))
      x = torch.add(x, x4)
      x = F.relu(F.interpolate(self.decoder2(x) , scale_factor=(2,2), mode ='bilinear'))
      x = torch.add(x, x3)
      x = F.relu(F.interpolate(self.decoder3(x) , scale_factor=(2,2), mode ='bilinear'))
      x = torch.add(x, x2)
      x = F.relu(F.interpolate(self.decoder4(x) , scale_factor=(2,2), mode ='bilinear'))
      x = torch.add(x, x1)
      x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear'))
      x = self.adjust(F.relu(x))
      # pdb.set_trace()
      return x

  def forward(self, x):
      """Forward pass; delegates to ``_forward_impl``."""
      return self._forward_impl(x)

In [None]:
def gated(pretrained=False, **kwargs):
  """Build a ResAxialAttentionUNet from AxialBlock_dynamic blocks.

  Uses stage depths [1, 2, 4, 1] and width scale s=0.125; any extra keyword
  arguments are forwarded to the ResAxialAttentionUNet constructor.
  ``pretrained`` is accepted but has no effect here.
  """
  stage_depths = [1, 2, 4, 1]
  return ResAxialAttentionUNet(AxialBlock_dynamic, stage_depths, s= 0.125, **kwargs)

In [None]:
# Smoke test: build the network with its defaults and run one forward pass on
# a random 3-channel 128x128 input.
net = gated()
x_test = torch.rand(1, 3, 128, 128)
# The trailing semicolon suppresses the notebook's echo of the output tensor.
net(x_test);

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Activation, Add, BatchNormalization, Conv2DTranspose, Softmax
from tensorflow.keras.models import Model

def conv1x1(in_channels, out_channels, stride=1):
    """Return a bias-free 1x1 Keras ``Conv2D`` with ``out_channels`` filters.

    ``stride`` is applied along both spatial axes. ``in_channels`` is accepted
    but not used when building the layer.
    """
    return Conv2D(
        out_channels,
        (1, 1),
        strides=(stride, stride),
        padding='same',
        use_bias=False,
    )

class AxialAttentionBlock(tf.keras.layers.Layer):
    """Sum of three parallel convolution branches applied to the same input.

    Despite the name, no attention weights are computed: the block adds a
    ``1 x kernel_size`` convolution, a ``kernel_size x 1`` convolution and a
    ``1 x 1`` convolution, each followed by batch normalisation and ReLU.

    Args:
        in_channels: kept for interface compatibility and stored for
            ``get_config``; it is not needed to build the layers.
        out_channels: number of filters in every branch.
        kernel_size: length of the long side of the two axial kernels.
        stride: spatial stride. It is applied along both axes in every branch
            so the three outputs always have the same shape and can be added.
            (Striding each branch along a different axis gives mismatched
            shapes for any ``stride != 1``.)
        **kwargs: standard Keras ``Layer`` arguments such as ``name``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=150, stride=1, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are kept so get_config() can rebuild the layer.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride

        strides = (stride, stride)
        self.conv1 = Conv2D(out_channels, (1, kernel_size), strides=strides, padding='same', use_bias=False)
        self.conv2 = Conv2D(out_channels, (kernel_size, 1), strides=strides, padding='same', use_bias=False)
        self.conv3 = Conv2D(out_channels, (1, 1), strides=strides, padding='same', use_bias=False)
        self.batch_norm1 = BatchNormalization()
        self.batch_norm2 = BatchNormalization()
        self.batch_norm3 = BatchNormalization()
        self.relu = Activation('relu')
        # Built once here instead of creating a new Add layer on every call.
        self.add = Add()

    def call(self, x):
        """Apply the three branches to ``x`` and return their element-wise sum."""
        x1 = self.relu(self.batch_norm1(self.conv1(x)))
        x2 = self.relu(self.batch_norm2(self.conv2(x)))
        x = self.relu(self.batch_norm3(self.conv3(x)))
        return self.add([x1, x2, x])

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({
            "in_channels": self.in_channels,
            "out_channels": self.out_channels,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
        })
        return config

def axial_attention_unet(input_shape=(300, 300, 1), num_classes=2):
    """Build a Keras encoder/decoder model from AxialAttentionBlock stages.

    The stem halves the spatial resolution once (stride-2 convolution). The
    four AxialAttentionBlock stages and the decoder then work at that
    resolution, and a final stride-2 transposed convolution restores the input
    resolution before the per-pixel softmax.

    Every decoder convolution produces the channel count of the encoder
    feature map it is added to (1024, 512, 256, 128), and the single
    upsampling step comes after the last skip connection. Upsampling before
    the first skip, or using the previous level's channel count, makes the
    ``Add`` inputs disagree in shape and the model cannot be built.

    Args:
        input_shape: (height, width, channels). The block kernel sizes are
            derived from ``input_shape[0]``. The output has the input's height
            and width only when both are even.
        num_classes: number of channels of the softmax output.

    Returns:
        An uncompiled ``Model`` mapping the input image to class probabilities.
    """
    inputs = Input(shape=input_shape)

    # Encoder stem: 64 -> 128 -> 64 channels, resolution halved by the first conv.
    x = Conv2D(64, kernel_size=7, strides=2, padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(128, kernel_size=3, strides=1, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, kernel_size=3, strides=1, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Encoder stages (stride 1, so all four share the stem's resolution).
    x1 = AxialAttentionBlock(64, 128, kernel_size=input_shape[0] // 2)(x)
    x2 = AxialAttentionBlock(128, 256, kernel_size=input_shape[0] // 2)(x1)
    x3 = AxialAttentionBlock(256, 512, kernel_size=input_shape[0] // 4)(x2)
    x4 = AxialAttentionBlock(512, 1024, kernel_size=input_shape[0] // 8)(x3)

    # Decoder: each step matches the shape of the skip tensor it is added to.
    x = Conv2DTranspose(1024, kernel_size=3, strides=1, padding='same')(x4)
    x = Add()([x, x4])
    x = Conv2DTranspose(512, kernel_size=3, strides=1, padding='same')(x)
    x = Add()([x, x3])
    x = Conv2DTranspose(256, kernel_size=3, strides=1, padding='same')(x)
    x = Add()([x, x2])
    x = Conv2DTranspose(128, kernel_size=3, strides=1, padding='same')(x)
    x = Add()([x, x1])
    # Undo the stem's stride-2 downsampling.
    x = Conv2DTranspose(128, kernel_size=3, strides=2, padding='same')(x)
    x = Conv2DTranspose(num_classes, kernel_size=1, strides=1, padding='same')(x)
    x = Softmax(axis=-1)(x)

    model = Model(inputs=inputs, outputs=x)
    return model

# Create the model
# Instantiate the Keras model for single-channel 300x300 inputs and two
# output classes; this binds the module-level name `model`.
model = axial_attention_unet(input_shape=(300, 300, 1), num_classes=2)

# Print model summary
model.summary()


In [None]:
def unet_model():
  """Build the U-Net for single-channel img_size x img_size inputs.

  Architecture: four contracting levels (16, 32, 64, 128 filters, dropout
  0.1-0.4) each followed by 2x2 max-pooling, a 256-filter middle block
  (dropout 0.5), and four expanding levels (128, 64, 32, 16 filters, dropout
  0.5) that upsample with a stride-2 transposed convolution and concatenate
  the matching contracting output. A 1x1 sigmoid convolution yields a
  one-channel mask.

  The input side length is read from the module-level ``img_size``. It does
  not have to be divisible by 16: max-pooling floors odd sizes (e.g.
  75 -> 37), so the upsampled tensor can come back one pixel short of its
  skip tensor (74 vs 75). In that case it is zero-padded on the bottom/right
  to the skip tensor's size before concatenation.

  Returns:
    An uncompiled ``keras.Model``.
  """

  def conv_block(tensor, filters, dropout_rate):
    # Conv -> BN -> ReLU -> Dropout -> Conv -> BN -> ReLU
    for with_dropout in (True, False):
      tensor = keras.layers.Conv2D(filters, (3, 3), kernel_initializer = "he_uniform", padding = "same")(tensor)
      tensor = keras.layers.BatchNormalization()(tensor)
      tensor = keras.layers.Activation("relu")(tensor)
      if with_dropout:
        tensor = keras.layers.Dropout(dropout_rate)(tensor)
    return tensor

  def upsample_and_merge(tensor, skip, filters):
    # Double the resolution, align with the skip tensor, then concatenate.
    up = keras.layers.Conv2DTranspose(filters, (3, 3), strides = (2, 2), padding = "same")(tensor)
    up_h, up_w = up.shape[1], up.shape[2]
    skip_h, skip_w = skip.shape[1], skip.shape[2]
    # Only pad when all sizes are statically known and actually differ.
    if None not in (up_h, up_w, skip_h, skip_w):
      pad_h, pad_w = skip_h - up_h, skip_w - up_w
      if pad_h or pad_w:
        up = keras.layers.ZeroPadding2D(((0, pad_h), (0, pad_w)))(up)
    return keras.layers.concatenate([up, skip])

  input_img = keras.layers.Input((img_size, img_size, 1), name = "img")

  # Contracting path: remember each level's output for the skip connections.
  skips = []
  x = input_img
  for filters, dropout_rate in ((16, 0.1), (32, 0.2), (64, 0.3), (128, 0.4)):
    x = conv_block(x, filters, dropout_rate)
    skips.append(x)
    x = keras.layers.MaxPooling2D((2, 2))(x)

  # Middle
  x = conv_block(x, 256, 0.5)

  # Expanding path: deepest skip first.
  for filters in (128, 64, 32, 16):
    x = upsample_and_merge(x, skips.pop(), filters)
    x = conv_block(x, filters, 0.5)

  output = keras.layers.Conv2D(1, (1, 1), activation = "sigmoid")(x)
  model = keras.Model(inputs = [input_img], outputs = [output])
  return model

In [None]:
class ResAxialAttentionUNet(nn.Module):
  """Encoder/decoder segmentation network with an axial-attention encoder.

  NOTE(review): a class with this same name is already defined earlier in
  this file; executing this cell rebinds the name to this definition.

  Encoder: a three-convolution stem (the first convolution has stride 2)
  followed by four stages built by ``_make_layer``; stages 2-4 are built with
  stride 2. Decoder: five 3x3 convolutions, each followed by 2x bilinear
  upsampling and ReLU, with the encoder stage outputs added back as skip
  connections, then a 1x1 convolution to ``num_classes`` channels.

  Channel widths are the base widths (64, 128, ...) scaled by ``s``.
  ``conv1x1`` must be in scope when the model is constructed.
  NOTE(review): ``zero_init_residual`` is accepted but unused, and
  ``self.soft`` is created but never applied -- ``_forward_impl`` returns the
  raw output of ``self.adjust``.
  """

  def __init__(self, block, layers, num_classes=2, zero_init_residual=True,
                groups=8, width_per_group=64, replace_stride_with_dilation=None,
                norm_layer=None, s=0.125, img_size = 128,imgchan = 3):
      super(ResAxialAttentionUNet, self).__init__()
      if norm_layer is None:
          norm_layer = nn.BatchNorm2d
      self._norm_layer = norm_layer

      # Running input-channel count; updated by _make_layer after each stage.
      self.inplanes = int(64 * s)
      self.dilation = 1
      if replace_stride_with_dilation is None:
          replace_stride_with_dilation = [False, False, False]
      if len(replace_stride_with_dilation) != 3:
          raise ValueError("replace_stride_with_dilation should be None "
                            "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
      self.groups = groups
      self.base_width = width_per_group
      # Stem: imgchan -> inplanes (stride 2) -> 128 -> inplanes.
      self.conv1 = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
                              bias=False)
      self.conv2 = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1, bias=False)
      self.conv3 = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
      self.bn1 = norm_layer(self.inplanes)
      self.bn2 = norm_layer(128)
      self.bn3 = norm_layer(self.inplanes)
      self.relu = nn.ReLU(inplace=True)
      # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
      # Encoder stages; kernel_size passed to each stage is derived from img_size.
      self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size= (img_size//2))
      self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2, kernel_size=(img_size//2),
                                      dilate=replace_stride_with_dilation[0])
      self.layer3 = self._make_layer(block, int(512 * s), layers[2], stride=2, kernel_size=(img_size//4),
                                      dilate=replace_stride_with_dilation[1])
      self.layer4 = self._make_layer(block, int(1024 * s), layers[3], stride=2, kernel_size=(img_size//8),
                                      dilate=replace_stride_with_dilation[2])

      # Decoder
      self.decoder1 = nn.Conv2d(int(1024 *2*s)      ,        int(1024*2*s), kernel_size=3, stride=2, padding=1)
      self.decoder2 = nn.Conv2d(int(1024  *2*s)     , int(1024*s), kernel_size=3, stride=1, padding=1)
      self.decoder3 = nn.Conv2d(int(1024*s),  int(512*s), kernel_size=3, stride=1, padding=1)
      self.decoder4 = nn.Conv2d(int(512*s) ,  int(256*s), kernel_size=3, stride=1, padding=1)
      self.decoder5 = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1)
      # 1x1 projection to the class channels.
      self.adjust   = nn.Conv2d(int(128*s) , num_classes, kernel_size=1, stride=1, padding=0)
      self.soft     = nn.Softmax(dim=1)


  def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
      """Build one encoder stage of ``blocks`` consecutive ``block`` modules.

      The first block receives ``stride`` and, when the stride or channel
      count changes, a 1x1-conv + norm shortcut. ``self.inplanes`` is updated
      to the stage's output channels (``planes * block.expansion``).
      """
      norm_layer = self._norm_layer
      downsample = None
      previous_dilation = self.dilation
      if dilate:
          # Trade the stride for dilation.
          self.dilation *= stride
          stride = 1
      if stride != 1 or self.inplanes != planes * block.expansion:
          downsample = nn.Sequential(
              conv1x1(self.inplanes, planes * block.expansion, stride),
              norm_layer(planes * block.expansion),
          )

      layers = []
      layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups,
                          base_width=self.base_width, dilation=previous_dilation, 
                          norm_layer=norm_layer, kernel_size=kernel_size))
      self.inplanes = planes * block.expansion
      if stride != 1:
          # The remaining blocks of a strided stage get half the kernel_size.
          kernel_size = kernel_size // 2

      for _ in range(1, blocks):
          layers.append(block(self.inplanes, planes, groups=self.groups,
                              base_width=self.base_width, dilation=self.dilation,
                              norm_layer=norm_layer, kernel_size=kernel_size))

      return nn.Sequential(*layers)

  def _forward_impl(self, x):
      """Encode ``x``, decode with additive skips, and return the class map."""

      # AxialAttention Encoder
      # pdb.set_trace()
      x = self.conv1(x)
      x = self.bn1(x)
      x = self.relu(x)
      x = self.conv2(x)
      x = self.bn2(x)
      x = self.relu(x)
      x = self.conv3(x)
      x = self.bn3(x)
      x = self.relu(x)
      #print(x.shape)
      x1 = self.layer1(x)
      #print(x.shape)
      x2 = self.layer2(x1)
      # print(x2.shape)
      x3 = self.layer3(x2)
      # print(x3.shape)
      x4 = self.layer4(x3)

      # decoder1 has stride 2, so the 2x upsampling that follows brings the
      # tensor back to a size that can be added to x4.
      x = F.relu(F.interpolate(self.decoder1(x4), scale_factor=(2,2), mode ='bilinear'))
      x = torch.add(x, x4)
      x = F.relu(F.interpolate(self.decoder2(x) , scale_factor=(2,2), mode ='bilinear'))
      x = torch.add(x, x3)
      x = F.relu(F.interpolate(self.decoder3(x) , scale_factor=(2,2), mode ='bilinear'))
      x = torch.add(x, x2)
      x = F.relu(F.interpolate(self.decoder4(x) , scale_factor=(2,2), mode ='bilinear'))
      x = torch.add(x, x1)
      x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear'))
      x = self.adjust(F.relu(x))
      # pdb.set_trace()
      return x

  def forward(self, x):
      """Forward pass; delegates to ``_forward_impl``."""
      return self._forward_impl(x)

In [None]:
# Test files are the names present in BOTH the image and the mask directory.
# NOTE(review): `image_path` and `check` are not defined in the visible part of
# this notebook, and `mask_path` was last bound as a loop variable holding a
# single mask file name in the loading cell -- confirm all three are set to the
# intended values before running this cell.
testing_files = set(os.listdir(image_path)) & set(os.listdir(mask_path))
training_files = check

def getData(X_shape, flag = "test"):
    """Load images and their masks, resized to ``X_shape`` x ``X_shape``.

    Reads the module-level ``image_path``, ``mask_path``, ``testing_files``
    and ``training_files`` (and uses ``tqdm`` for the progress bar), all of
    which must be defined before the call.

    Args:
        X_shape: target side length in pixels.
        flag: ``"test"`` loads ``testing_files``, where image and mask share
            one file name. ``"train"`` loads ``training_files``, where each
            entry is a mask name without extension and the image name is the
            part before ``"_mask"`` plus ``".png"``.

    Returns:
        ``(im_array, mask_array)``: two index-aligned lists of 2-D arrays
        (channel 0 of each resized file).

    Raises:
        ValueError: if ``flag`` is neither ``"test"`` nor ``"train"``.
        OSError: if a file cannot be read by ``cv2.imread``.
    """
    def load_first_channel(path):
        # cv2.imread signals failure by returning None instead of raising.
        raw = cv2.imread(path)
        if raw is None:
            raise OSError("could not read image: {}".format(path))
        return cv2.resize(raw, (X_shape, X_shape))[:, :, 0]

    if flag == "test":
        names = [(i, i) for i in testing_files]
    elif flag == "train":
        names = [(i.split("_mask")[0] + ".png", i + ".png") for i in training_files]
    else:
        # Without this, an unknown flag falls through and returns None.
        raise ValueError("flag must be 'test' or 'train', got {!r}".format(flag))

    im_array = []
    mask_array = []
    for image_name, mask_name in tqdm(names):
        # Index the resized array itself. The test branch must not pick up a
        # stray ",(224,224)[:,:,0]" suffix: that indexes a tuple and raises
        # TypeError.
        im_array.append(load_first_channel(os.path.join(image_path, image_name)))
        mask_array.append(load_first_channel(os.path.join(mask_path, mask_name)))

    return im_array,mask_array

In [None]:
def plotMask(X,y):
    """Show the first six image/mask pairs, three per figure.

    Each panel is ``X[i]`` and ``y[i]`` stacked horizontally. Two figures are
    shown: pairs 0-2 in the top row of the first, pairs 3-5 in the bottom row
    of the second.
    """
    panels = [np.hstack((X[k], y[k])) for k in range(6)]

    for first in (0, 3):
        plt.figure(figsize=(25,10))
        for offset in range(3):
            # Slots are numbered 1..6 across a 2 x 3 grid.
            plt.subplot(2, 3, first + offset + 1)
            plt.imshow(panels[first + offset])
        plt.show()


In [None]:
# Side length used when reshaping the loaded arrays later on.
dim = 224
# Load the training pairs and (default flag "test") the test pairs at 224 x 224.
X_train,y_train = getData(224,flag="train")
X_test, y_test = getData(224)

In [None]:
# Visual sanity check: first six image/mask pairs of each set.
print("training set")
plotMask(X_train,y_train)
print("testing set")
plotMask(X_test,y_test)

In [None]:
# Convert the lists to (n, dim, dim, channels) arrays; -1 lets numpy infer the
# channel count.
X_train = np.array(X_train).reshape(len(X_train),dim,dim,-1)
y_train = np.array(y_train).reshape(len(y_train),dim,dim,-1)
X_test = np.array(X_test).reshape(len(X_test),dim,dim,-1)
y_test = np.array(y_test).reshape(len(y_test),dim,dim,-1)
# Images and masks must line up one-to-one with identical shapes.
assert X_train.shape == y_train.shape
assert X_test.shape == y_test.shape
# Pool train and test into single arrays.
# NOTE(review): this rebinds the module-level name `images`, which an earlier
# cell uses for the list of loaded images.
images = np.concatenate((X_train,X_test),axis=0)
mask  = np.concatenate((y_train,y_test),axis=0)

In [None]:
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras import backend as keras
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint, LearningRateScheduler


def dice_coef(y_true, y_pred):
    """Dice coefficient over all elements of both tensors, smoothed by 1.

    Both inputs are flattened, so a single scalar is produced for the whole
    batch. ``keras`` refers here to the backend module imported in this cell.
    """
    truth = keras.flatten(y_true)
    pred = keras.flatten(y_pred)
    overlap = keras.sum(truth * pred)
    numerator = 2. * overlap + 1
    denominator = keras.sum(truth) + keras.sum(pred) + 1
    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    """Loss form of the Dice coefficient: its negation (lower is better)."""
    score = dice_coef(y_true, y_pred)
    return -score

def gated_axial_unet(input_size=(256, 256, 1)):
    """Build a U-Net whose decoder upsamples with ``axial_attention_block``.

    The contracting path is four double-convolution levels (32, 64, 128, 256
    filters), each followed by 2x2 max-pooling, and a 512-filter bottleneck.
    Each expanding level concatenates the output of ``axial_attention_block``
    with the matching contracting output and applies a double convolution. A
    1x1 sigmoid convolution yields a one-channel mask.

    NOTE(review): ``axial_attention_block`` is not defined in the visible part
    of this file. It is called as ``axial_attention_block(tensor, filters, 2,
    size)`` and must return a tensor that can be concatenated with the skip
    tensor on axis 3 -- confirm its definition before use.

    Args:
        input_size: (height, width, channels) of the input image.

    Returns:
        An uncompiled Keras ``Model``.
    """
    def double_conv(tensor, filters):
        # Two 3x3 ReLU convolutions with 'same' padding.
        tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)

    inputs = Input(input_size)

    conv1 = double_conv(inputs, 32)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = double_conv(pool1, 64)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = double_conv(pool2, 128)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = double_conv(pool3, 256)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = double_conv(pool4, 512)

    # Use axial_attention_block instead of Conv2DTranspose for the decoder part
    up6 = concatenate([axial_attention_block(conv5, 256, 2, 8), conv4], axis=3)
    up6 = double_conv(up6, 256)

    up7 = concatenate([axial_attention_block(up6, 128, 2, 16), conv3], axis=3)
    up7 = double_conv(up7, 128)

    up8 = concatenate([axial_attention_block(up7, 64, 2, 32), conv2], axis=3)
    up8 = double_conv(up8, 64)

    up9 = concatenate([axial_attention_block(up8, 32, 2, 64), conv1], axis=3)
    up9 = double_conv(up9, 32)

    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(up9)

    # The model input is `inputs`; no `input_img` exists in this function.
    # `Model` comes from keras.models -- in this cell the name `keras` is the
    # backend module, which has no `Model` attribute.
    model = Model(inputs=[inputs], outputs=[conv10])
    return model



In [None]:
# Print the layer table of whatever is currently bound to the module-level
# name `model`.
model.summary()

In [None]:
# Train the U-Net, or reuse a previously saved one.
# NOTE(review): an earlier cell runs `from keras import backend as keras`,
# which rebinds `keras` to the backend module; if that cell has been executed,
# `keras.callbacks`, `keras.models` and `keras.optimizers` below will not
# resolve as intended -- confirm the execution order / the binding of `keras`.

# Halve the learning rate after 3 epochs without improvement of the training loss.
reduce_learning_rate = keras.callbacks.ReduceLROnPlateau(
  monitor = "loss", 
  factor = 0.5, 
  patience = 3, 
  verbose = 1
)

# Keep only the best model seen so far in "unet.h5".
checkpointer = keras.callbacks.ModelCheckpoint(
  "unet.h5", 
  verbose = 1, 
  save_best_only = True
)

strategy = tf.distribute.MirroredStrategy()

if (os.path.exists("unet.h5")):
  # A checkpoint exists: load it (custom loss/metric must be supplied by
  # name) and skip training entirely.
  model = keras.models.load_model("unet.h5",
    custom_objects = {
      "jaccard_distance_loss": jaccard_distance_loss,
      "dice_coef": dice_coef
    }
  )
  
else:
  # No checkpoint: build and compile the model inside the distribution scope.
  with strategy.scope():
    model = unet_model()
    adam_opt = keras.optimizers.Adam(learning_rate = 0.001)
    model.compile(optimizer = adam_opt, loss = jaccard_distance_loss, metrics = [dice_coef])
    
  # Train on augmented batches; validate on the untouched validation split.
  fit = model.fit(train_generator, 
    steps_per_epoch = steps_per_epoch, 
    epochs = 100,
    validation_data = (X_val, Y_val),
    callbacks = [
      checkpointer,
      reduce_learning_rate
    ]
  )

# MedT architecture

In [None]:
import pdb
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
#from .utils import *
import pdb
import matplotlib.pyplot as plt
import random
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` from ``in_planes`` to ``out_planes`` channels."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


class AxialAttention(nn.Module):
    """Multi-head self-attention along one spatial axis with relative position terms.

    The input is a 4-D tensor laid out as (N, C, H, W). Attention runs along
    the H axis by default, or along the W axis when ``width=True``; the other
    spatial axis is folded into the batch dimension. The attended axis must
    have exactly ``kernel_size`` elements, because the relative position table
    is built for that length.

    Args:
        in_planes: Number of input channels; must be divisible by ``groups``.
        out_planes: Number of output channels; must be divisible by ``groups``.
            ``out_planes // groups`` has to be even so that it can be split
            into query/key halves.
        groups: Number of attention heads.
        kernel_size: Length of the attended axis.
        stride: When greater than 1, the output is average-pooled with this
            kernel size and stride.
        bias: Stored on the module; not used by the computation.
        width: Attend along W instead of H.
    """

    def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,
                 stride=1, bias=False, width=False):
        assert (in_planes % groups == 0) and (out_planes % groups == 0)
        super(AxialAttention, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.groups = groups
        self.group_planes = out_planes // groups
        self.kernel_size = kernel_size
        self.stride = stride
        self.bias = bias
        self.width = width

        # Multi-head self attention.
        # The `.utils` star import is commented out in this notebook, so no
        # `qkv_transform` helper is in scope; a plain bias-free 1x1 Conv1d
        # provides the joint query/key/value projection.
        self.qkv_transform = nn.Conv1d(in_planes, out_planes * 2, kernel_size=1, stride=1,
                                       padding=0, bias=False)
        self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
        self.bn_similarity = nn.BatchNorm2d(groups * 3)

        self.bn_output = nn.BatchNorm1d(out_planes * 2)

        # Position embedding: one learnable vector per relative offset in
        # [-(kernel_size - 1), kernel_size - 1], gathered through `flatten_index`.
        self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)
        query_index = torch.arange(kernel_size).unsqueeze(0)
        key_index = torch.arange(kernel_size).unsqueeze(1)
        relative_index = key_index - query_index + kernel_size - 1
        self.register_buffer('flatten_index', relative_index.view(-1))
        if stride > 1:
            self.pooling = nn.AvgPool2d(stride, stride=stride)

        self.reset_parameters()

    def forward(self, x):
        """Apply axial attention to ``x`` of shape (N, C, H, W).

        Returns a tensor of shape (N, out_planes, H, W), average-pooled by
        ``stride`` when ``stride > 1``.
        """
        # Move the attended axis last and fold the other spatial axis into the batch.
        if self.width:
            x = x.permute(0, 2, 1, 3)
        else:
            x = x.permute(0, 3, 1, 2)  # N, W, C, H
        N, W, C, H = x.shape
        x = x.contiguous().view(N * W, C, H)

        # Transformations
        qkv = self.bn_qkv(self.qkv_transform(x))
        q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)

        # Calculate position embedding
        all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)
        q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)

        qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
        kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)

        qk = torch.einsum('bgci, bgcj->bgij', q, k)

        # Normalize the three similarity terms jointly, then sum them per head.
        stacked_similarity = torch.cat([qk, qr, kr], dim=1)
        stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)
        similarity = F.softmax(stacked_similarity, dim=3)
        sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
        sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)
        # Content and positional outputs are normalized together and summed.
        stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
        output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)

        # Restore the (N, C, H, W) layout.
        if self.width:
            output = output.permute(0, 2, 1, 3)
        else:
            output = output.permute(0, 2, 3, 1)

        if self.stride > 1:
            output = self.pooling(output)

        return output

    def reset_parameters(self):
        """Re-draw the projection weights and the relative position table from normal distributions."""
        self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
        nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))

class AxialAttention_dynamic(nn.Module):
    """Axial self-attention whose positional terms are scaled by scalar gates.

    Same computation as plain axial attention with relative position terms,
    except that the query/key positional similarities and the content and
    positional outputs are multiplied by the scalars ``f_qr``, ``f_kr``,
    ``f_sv`` and ``f_sve``. The gates are created with ``requires_grad=False``,
    so an optimizer will not update them unless the caller enables gradients.

    The input is a 4-D tensor laid out as (N, C, H, W); the attended axis
    (H, or W when ``width=True``) must have exactly ``kernel_size`` elements.

    Args:
        in_planes: Number of input channels; must be divisible by ``groups``.
        out_planes: Number of output channels; must be divisible by ``groups``.
            ``out_planes // groups`` has to be even so that it can be split
            into query/key halves.
        groups: Number of attention heads.
        kernel_size: Length of the attended axis.
        stride: When greater than 1, the output is average-pooled with this
            kernel size and stride.
        bias: Stored on the module; not used by the computation.
        width: Attend along W instead of H.
    """

    def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,
                 stride=1, bias=False, width=False):
        assert (in_planes % groups == 0) and (out_planes % groups == 0)
        super(AxialAttention_dynamic, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.groups = groups
        self.group_planes = out_planes // groups
        self.kernel_size = kernel_size
        self.stride = stride
        self.bias = bias
        self.width = width

        # Multi-head self attention.
        # The `.utils` star import is commented out in this notebook, so no
        # `qkv_transform` helper is in scope; a plain bias-free 1x1 Conv1d
        # provides the joint query/key/value projection.
        self.qkv_transform = nn.Conv1d(in_planes, out_planes * 2, kernel_size=1, stride=1,
                                       padding=0, bias=False)
        self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
        self.bn_similarity = nn.BatchNorm2d(groups * 3)
        self.bn_output = nn.BatchNorm1d(out_planes * 2)

        # Gates: initial weights of the positional (0.1) and content (1.0) terms.
        self.f_qr = nn.Parameter(torch.tensor(0.1),  requires_grad=False)
        self.f_kr = nn.Parameter(torch.tensor(0.1),  requires_grad=False)
        self.f_sve = nn.Parameter(torch.tensor(0.1),  requires_grad=False)
        self.f_sv = nn.Parameter(torch.tensor(1.0),  requires_grad=False)

        # Position embedding: one learnable vector per relative offset in
        # [-(kernel_size - 1), kernel_size - 1], gathered through `flatten_index`.
        self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)
        query_index = torch.arange(kernel_size).unsqueeze(0)
        key_index = torch.arange(kernel_size).unsqueeze(1)
        relative_index = key_index - query_index + kernel_size - 1
        self.register_buffer('flatten_index', relative_index.view(-1))
        if stride > 1:
            self.pooling = nn.AvgPool2d(stride, stride=stride)

        self.reset_parameters()

    def forward(self, x):
        """Apply gated axial attention to ``x`` of shape (N, C, H, W).

        Returns a tensor of shape (N, out_planes, H, W), average-pooled by
        ``stride`` when ``stride > 1``.
        """
        # Move the attended axis last and fold the other spatial axis into the batch.
        if self.width:
            x = x.permute(0, 2, 1, 3)
        else:
            x = x.permute(0, 3, 1, 2)  # N, W, C, H
        N, W, C, H = x.shape
        x = x.contiguous().view(N * W, C, H)

        # Transformations
        qkv = self.bn_qkv(self.qkv_transform(x))
        q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)

        # Calculate position embedding
        all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)
        q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)
        qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
        kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)
        qk = torch.einsum('bgci, bgcj->bgij', q, k)

        # Scale the positional similarities by their gates.
        qr = torch.mul(qr, self.f_qr)
        kr = torch.mul(kr, self.f_kr)

        # Normalize the three similarity terms jointly, then sum them per head.
        stacked_similarity = torch.cat([qk, qr, kr], dim=1)
        stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)
        similarity = F.softmax(stacked_similarity, dim=3)
        sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
        sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)

        # Scale the content and positional outputs by their gates.
        sv = torch.mul(sv, self.f_sv)
        sve = torch.mul(sve, self.f_sve)

        stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
        output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)

        # Restore the (N, C, H, W) layout.
        if self.width:
            output = output.permute(0, 2, 1, 3)
        else:
            output = output.permute(0, 2, 3, 1)

        if self.stride > 1:
            output = self.pooling(output)

        return output

    def reset_parameters(self):
        """Re-draw the projection weights and the relative position table from normal distributions."""
        self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
        nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))

class AxialAttention_wopos(nn.Module):
    """Multi-head axial self-attention without any position embedding.

    The input is a 4-D tensor laid out as (N, C, H, W). Attention runs along
    the H axis by default, or along the W axis when ``width=True``; the other
    spatial axis is folded into the batch dimension. ``kernel_size`` is stored
    but, with no position table, it is not used by the computation.

    Args:
        in_planes: Number of input channels; must be divisible by ``groups``.
        out_planes: Number of output channels; must be divisible by ``groups``.
            ``out_planes // groups`` has to be even so that it can be split
            into query/key halves.
        groups: Number of attention heads.
        kernel_size: Stored on the module only.
        stride: When greater than 1, the output is average-pooled with this
            kernel size and stride.
        bias: Stored on the module; not used by the computation.
        width: Attend along W instead of H.
    """

    def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,
                 stride=1, bias=False, width=False):
        assert (in_planes % groups == 0) and (out_planes % groups == 0)
        super(AxialAttention_wopos, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.groups = groups
        self.group_planes = out_planes // groups
        self.kernel_size = kernel_size
        self.stride = stride
        self.bias = bias
        self.width = width

        # Multi-head self attention.
        # The `.utils` star import is commented out in this notebook, so no
        # `qkv_transform` helper is in scope; a plain bias-free 1x1 Conv1d
        # provides the joint query/key/value projection.
        self.qkv_transform = nn.Conv1d(in_planes, out_planes * 2, kernel_size=1, stride=1,
                                       padding=0, bias=False)
        self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
        self.bn_similarity = nn.BatchNorm2d(groups )

        self.bn_output = nn.BatchNorm1d(out_planes * 1)

        if stride > 1:
            self.pooling = nn.AvgPool2d(stride, stride=stride)

        self.reset_parameters()

    def forward(self, x):
        """Apply position-free axial attention to ``x`` of shape (N, C, H, W).

        Returns a tensor of shape (N, out_planes, H, W), average-pooled by
        ``stride`` when ``stride > 1``.
        """
        # Move the attended axis last and fold the other spatial axis into the batch.
        if self.width:
            x = x.permute(0, 2, 1, 3)
        else:
            x = x.permute(0, 3, 1, 2)  # N, W, C, H
        N, W, C, H = x.shape
        x = x.contiguous().view(N * W, C, H)

        # Transformations
        qkv = self.bn_qkv(self.qkv_transform(x))
        q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)

        qk = torch.einsum('bgci, bgcj->bgij', q, k)

        # Only the content similarity exists here, so the sum runs over a singleton axis.
        stacked_similarity = self.bn_similarity(qk).reshape(N * W, 1, self.groups, H, H).sum(dim=1).contiguous()

        similarity = F.softmax(stacked_similarity, dim=3)
        sv = torch.einsum('bgij,bgcj->bgci', similarity, v)

        sv = sv.reshape(N*W,self.out_planes * 1, H).contiguous()
        output = self.bn_output(sv).reshape(N, W, self.out_planes, 1, H).sum(dim=-2).contiguous()

        # Restore the (N, C, H, W) layout.
        if self.width:
            output = output.permute(0, 2, 1, 3)
        else:
            output = output.permute(0, 2, 3, 1)

        if self.stride > 1:
            output = self.pooling(output)

        return output

    def reset_parameters(self):
        """Re-draw the projection weights from a normal distribution."""
        self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))

#end of attn definition

class AxialBlock(nn.Module):
    """Residual bottleneck that mixes space with height- then width-axis attention.

    A 1x1 convolution reduces the channels, ``AxialAttention`` is applied along
    the height and then the width axis (the width pass carries ``stride``),
    and a second 1x1 convolution expands to ``planes * expansion`` channels
    before the shortcut is added.
    """

    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, kernel_size=56):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner = int(planes * (base_width / 64.))
        expanded = planes * self.expansion

        # Submodules are created in a fixed order so that parameter
        # registration and random initialization stay stable.
        self.conv_down = conv1x1(inplanes, inner)
        self.bn1 = make_norm(inner)
        self.hight_block = AxialAttention(inner, inner, groups=groups,
                                          kernel_size=kernel_size)
        self.width_block = AxialAttention(inner, inner, groups=groups,
                                          kernel_size=kernel_size,
                                          stride=stride, width=True)
        self.conv_up = conv1x1(inner, expanded)
        self.bn2 = make_norm(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(block(x) + shortcut(x))``."""
        reduced = self.relu(self.bn1(self.conv_down(x)))
        attended = self.relu(self.width_block(self.hight_block(reduced)))
        expanded = self.bn2(self.conv_up(attended))

        # The shortcut is projected only when a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        expanded += shortcut
        return self.relu(expanded)

class AxialBlock_dynamic(nn.Module):
    """Residual bottleneck built on the gated ``AxialAttention_dynamic`` layers.

    A 1x1 convolution reduces the channels, gated axial attention is applied
    along the height and then the width axis (the width pass carries
    ``stride``), and a second 1x1 convolution expands to
    ``planes * expansion`` channels before the shortcut is added.
    """

    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, kernel_size=56):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner = int(planes * (base_width / 64.))
        expanded = planes * self.expansion

        # Submodules are created in a fixed order so that parameter
        # registration and random initialization stay stable.
        self.conv_down = conv1x1(inplanes, inner)
        self.bn1 = make_norm(inner)
        self.hight_block = AxialAttention_dynamic(inner, inner, groups=groups,
                                                  kernel_size=kernel_size)
        self.width_block = AxialAttention_dynamic(inner, inner, groups=groups,
                                                  kernel_size=kernel_size,
                                                  stride=stride, width=True)
        self.conv_up = conv1x1(inner, expanded)
        self.bn2 = make_norm(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(block(x) + shortcut(x))``."""
        reduced = self.relu(self.bn1(self.conv_down(x)))
        attended = self.relu(self.width_block(self.hight_block(reduced)))
        expanded = self.bn2(self.conv_up(attended))

        # The shortcut is projected only when a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        expanded += shortcut
        return self.relu(expanded)

class AxialBlock_wopos(nn.Module):
    """Residual bottleneck built on the position-free ``AxialAttention_wopos`` layers.

    A 1x1 convolution reduces the channels, position-free axial attention is
    applied along the height and then the width axis (the width pass carries
    ``stride``), and a second 1x1 convolution expands to
    ``planes * expansion`` channels before the shortcut is added.
    """

    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, kernel_size=56):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner = int(planes * (base_width / 64.))
        expanded = planes * self.expansion

        # Submodules are created in a fixed order so that parameter
        # registration and random initialization stay stable.
        self.conv_down = conv1x1(inplanes, inner)
        # `conv1` is registered (and therefore part of the state dict) but is
        # not called in `forward`.
        self.conv1 = nn.Conv2d(inner, inner, kernel_size=1)
        self.bn1 = make_norm(inner)
        self.hight_block = AxialAttention_wopos(inner, inner, groups=groups,
                                                kernel_size=kernel_size)
        self.width_block = AxialAttention_wopos(inner, inner, groups=groups,
                                                kernel_size=kernel_size,
                                                stride=stride, width=True)
        self.conv_up = conv1x1(inner, expanded)
        self.bn2 = make_norm(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(block(x) + shortcut(x))``."""
        reduced = self.relu(self.bn1(self.conv_down(x)))
        attended = self.relu(self.width_block(self.hight_block(reduced)))
        expanded = self.bn2(self.conv_up(attended))

        # The shortcut is projected only when a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        expanded += shortcut
        return self.relu(expanded)


#end of block definition


class ResAxialAttentionUNet(nn.Module):
    """U-shaped segmentation network with an axial-attention encoder.

    A three-convolution stem (the first one has stride 2) feeds four encoder
    stages built from ``block``. Five convolutional decoder stages, each
    followed by 2x bilinear upsampling, add the encoder outputs back in as
    skip connections. The network returns per-pixel class scores; ``soft`` is
    registered but not applied in the forward pass.

    Args:
        block: Block class used for the encoder stages; must expose ``expansion``.
        layers: Number of blocks in each of the four encoder stages.
        num_classes: Number of output channels.
        zero_init_residual: Accepted but unused.
        groups: Attention heads passed to every block.
        width_per_group: ``base_width`` passed to every block.
        replace_stride_with_dilation: ``None`` or three booleans, one per
            strided stage; when true the stage keeps stride 1 and the tracked
            dilation is multiplied instead.
        norm_layer: Normalization layer factory; defaults to ``nn.BatchNorm2d``.
        s: Channel width multiplier.
        img_size: Input side length, used to size the attention kernels.
        imgchan: Number of input image channels.
    """

    def __init__(self, block, layers, num_classes=2, zero_init_residual=True,
                 groups=8, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, s=0.125, img_size = 128,imgchan = 3):
        super().__init__()
        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        make_norm = self._norm_layer

        self.inplanes = int(64 * s)
        self.dilation = 1
        dilate_flags = replace_stride_with_dilation
        if dilate_flags is None:
            dilate_flags = [False, False, False]
        if len(dilate_flags) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(dilate_flags))
        self.groups = groups
        self.base_width = width_per_group

        # Stem
        stem_planes = self.inplanes
        self.conv1 = nn.Conv2d(imgchan, stem_planes, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.conv2 = nn.Conv2d(stem_planes, 128, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.conv3 = nn.Conv2d(128, stem_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = make_norm(stem_planes)
        self.bn2 = make_norm(128)
        self.bn3 = make_norm(stem_planes)
        self.relu = nn.ReLU(inplace=True)

        # Encoder
        self.layer1 = self._make_layer(block, int(128 * s), layers[0],
                                       kernel_size=(img_size // 2))
        self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2,
                                       kernel_size=(img_size // 2),
                                       dilate=dilate_flags[0])
        self.layer3 = self._make_layer(block, int(512 * s), layers[2], stride=2,
                                       kernel_size=(img_size // 4),
                                       dilate=dilate_flags[1])
        self.layer4 = self._make_layer(block, int(1024 * s), layers[3], stride=2,
                                       kernel_size=(img_size // 8),
                                       dilate=dilate_flags[2])

        # Decoder
        self.decoder1 = nn.Conv2d(int(1024 * 2 * s), int(1024 * 2 * s),
                                  kernel_size=3, stride=2, padding=1)
        self.decoder2 = nn.Conv2d(int(1024 * 2 * s), int(1024 * s),
                                  kernel_size=3, stride=1, padding=1)
        self.decoder3 = nn.Conv2d(int(1024 * s), int(512 * s),
                                  kernel_size=3, stride=1, padding=1)
        self.decoder4 = nn.Conv2d(int(512 * s), int(256 * s),
                                  kernel_size=3, stride=1, padding=1)
        self.decoder5 = nn.Conv2d(int(256 * s), int(128 * s),
                                  kernel_size=3, stride=1, padding=1)
        self.adjust = nn.Conv2d(int(128 * s), num_classes,
                                kernel_size=1, stride=1, padding=0)
        self.soft = nn.Softmax(dim=1)

    def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
        """Build one encoder stage of ``blocks`` blocks and advance ``self.inplanes``."""
        make_norm = self._norm_layer
        dilation_before = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        out_planes = planes * block.expansion
        shortcut = None
        if stride != 1 or self.inplanes != out_planes:
            # Project the shortcut whenever the shape changes.
            shortcut = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                make_norm(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, shortcut, groups=self.groups,
                       base_width=self.base_width, dilation=dilation_before,
                       norm_layer=make_norm, kernel_size=kernel_size)]
        self.inplanes = out_planes
        if stride != 1:
            # The first block pooled its output, so later blocks attend over half the length.
            kernel_size = kernel_size // 2

        stage.extend(
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=make_norm, kernel_size=kernel_size)
            for _ in range(1, blocks)
        )

        return nn.Sequential(*stage)

    def _forward_impl(self, x):
        """Run stem, encoder and decoder; return the class-score map."""
        # Stem
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))

        # AxialAttention encoder
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        # Decoder: convolve, upsample 2x, then add the matching encoder output.
        x = x4
        skip_stages = (
            (self.decoder1, x4),
            (self.decoder2, x3),
            (self.decoder3, x2),
            (self.decoder4, x1),
        )
        for decoder, skip in skip_stages:
            x = F.relu(F.interpolate(decoder(x), scale_factor=(2,2), mode ='bilinear'))
            x = torch.add(x, skip)
        x = F.relu(F.interpolate(self.decoder5(x), scale_factor=(2,2), mode ='bilinear'))

        return self.adjust(F.relu(x))

    def forward(self, x):
        return self._forward_impl(x)

class medt_net(nn.Module):
    """Two-branch segmentation network: a shallow global branch plus a deeper patch-wise branch.

    The global branch runs a three-convolution stem, two encoder stages built
    from ``block`` and two decoder stages on the whole image. The local branch
    cuts the image into a ``patch_grid`` x ``patch_grid`` grid of patches and
    runs each patch through its own stem, four encoder stages built from
    ``block_2`` and five decoder stages. The stitched patch features are added
    to the global features and mapped to class scores.

    The patch size is derived from the input tensor at run time, so the local
    branch covers the whole image for any input whose side is a multiple of
    ``patch_grid`` (for 128x128 inputs this gives 32x32 patches). Rows or
    columns left over when the side is not a multiple keep the global features.

    ``soft``, ``adjust_p``, ``soft_p`` and ``relu`` / ``relu_p`` carry no extra
    behaviour beyond what ``_forward_impl`` shows; ``soft``, ``adjust_p`` and
    ``soft_p`` are registered but not applied.

    Args:
        block: Block class for the global branch; must expose ``expansion``.
        block_2: Block class for the local branch; must expose ``expansion``.
        layers: Number of blocks per encoder stage (four entries).
        num_classes: Number of output channels.
        zero_init_residual: Accepted but unused.
        groups: Attention heads passed to every block.
        width_per_group: ``base_width`` passed to every block.
        replace_stride_with_dilation: ``None`` or three booleans, one per
            strided stage.
        norm_layer: Normalization layer factory; defaults to ``nn.BatchNorm2d``.
        s: Channel width multiplier.
        img_size: Input side length, used to size the attention kernels.
        imgchan: Number of input image channels.
    """

    # Number of patches per image side in the local branch.
    patch_grid = 4

    def __init__(self, block, block_2, layers, num_classes=2, zero_init_residual=True,
                 groups=8, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, s=0.125, img_size = 128,imgchan = 3):
        super(medt_net, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = int(64 * s)
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # ---- Global branch -------------------------------------------------
        self.conv1 = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.conv2 = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.bn2 = norm_layer(128)
        self.bn3 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size=(img_size // 2))
        self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2, kernel_size=(img_size // 2),
                                       dilate=replace_stride_with_dilation[0])

        # Global decoder (only the two shallowest stages exist in this branch).
        self.decoder4 = nn.Conv2d(int(512 * s), int(256 * s), kernel_size=3, stride=1, padding=1)
        self.decoder5 = nn.Conv2d(int(256 * s), int(128 * s), kernel_size=3, stride=1, padding=1)
        self.adjust = nn.Conv2d(int(128 * s), num_classes, kernel_size=1, stride=1, padding=0)
        self.soft = nn.Softmax(dim=1)

        # ---- Local (patch-wise) branch --------------------------------------
        # NOTE: `self.inplanes` has been advanced by the global encoder stages,
        # so the local stem is sized from that updated value.
        self.conv1_p = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
                                 bias=False)
        self.conv2_p = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1,
                                 bias=False)
        self.conv3_p = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1,
                                 bias=False)
        self.bn1_p = norm_layer(self.inplanes)
        self.bn2_p = norm_layer(128)
        self.bn3_p = norm_layer(self.inplanes)

        self.relu_p = nn.ReLU(inplace=True)

        # Side length of one patch for the configured image size.
        img_size_p = img_size // self.patch_grid

        self.layer1_p = self._make_layer(block_2, int(128 * s), layers[0], kernel_size=(img_size_p // 2))
        self.layer2_p = self._make_layer(block_2, int(256 * s), layers[1], stride=2, kernel_size=(img_size_p // 2),
                                         dilate=replace_stride_with_dilation[0])
        self.layer3_p = self._make_layer(block_2, int(512 * s), layers[2], stride=2, kernel_size=(img_size_p // 4),
                                         dilate=replace_stride_with_dilation[1])
        self.layer4_p = self._make_layer(block_2, int(1024 * s), layers[3], stride=2, kernel_size=(img_size_p // 8),
                                         dilate=replace_stride_with_dilation[2])

        # Local decoder
        self.decoder1_p = nn.Conv2d(int(1024 * 2 * s), int(1024 * 2 * s), kernel_size=3, stride=2, padding=1)
        self.decoder2_p = nn.Conv2d(int(1024 * 2 * s), int(1024 * s), kernel_size=3, stride=1, padding=1)
        self.decoder3_p = nn.Conv2d(int(1024 * s), int(512 * s), kernel_size=3, stride=1, padding=1)
        self.decoder4_p = nn.Conv2d(int(512 * s), int(256 * s), kernel_size=3, stride=1, padding=1)
        self.decoder5_p = nn.Conv2d(int(256 * s), int(128 * s), kernel_size=3, stride=1, padding=1)

        self.decoderf = nn.Conv2d(int(128 * s), int(128 * s), kernel_size=3, stride=1, padding=1)
        self.adjust_p = nn.Conv2d(int(128 * s), num_classes, kernel_size=1, stride=1, padding=0)
        self.soft_p = nn.Softmax(dim=1)

    def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
        """Build one encoder stage of ``blocks`` blocks and advance ``self.inplanes``."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Bias-free 1x1 projection of the shortcut whenever the shape changes.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1,
                          stride=stride, bias=False),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups,
                            base_width=self.base_width, dilation=previous_dilation,
                            norm_layer=norm_layer, kernel_size=kernel_size))
        self.inplanes = planes * block.expansion
        if stride != 1:
            # The first block reduced the resolution, so later blocks attend over half the length.
            kernel_size = kernel_size // 2

        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer, kernel_size=kernel_size))

        return nn.Sequential(*layers)

    def _forward_global(self, x):
        """Whole-image branch: stem, two encoder stages, two decoder stages."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))

        x1 = self.layer1(x)
        x2 = self.layer2(x1)

        x = F.relu(F.interpolate(self.decoder4(x2), scale_factor=(2, 2), mode='bilinear'))
        x = torch.add(x, x1)
        x = F.relu(F.interpolate(self.decoder5(x), scale_factor=(2, 2), mode='bilinear'))
        return x

    def _forward_patch(self, patch):
        """Local branch for one image patch: stem, four encoder stages, five decoder stages."""
        p = self.relu_p(self.bn1_p(self.conv1_p(patch)))
        p = self.relu_p(self.bn2_p(self.conv2_p(p)))
        p = self.relu_p(self.bn3_p(self.conv3_p(p)))

        x1_p = self.layer1_p(p)
        x2_p = self.layer2_p(x1_p)
        x3_p = self.layer3_p(x2_p)
        x4_p = self.layer4_p(x3_p)

        # Convolve, upsample 2x, then add the matching encoder output.
        p = x4_p
        for decoder, skip in ((self.decoder1_p, x4_p),
                              (self.decoder2_p, x3_p),
                              (self.decoder3_p, x2_p),
                              (self.decoder4_p, x1_p)):
            p = F.relu(F.interpolate(decoder(p), scale_factor=(2, 2), mode='bilinear'))
            p = torch.add(p, skip)
        return F.relu(F.interpolate(self.decoder5_p(p), scale_factor=(2, 2), mode='bilinear'))

    def _forward_impl(self, x):
        """Fuse global and patch-wise features and return the class-score map."""
        xin = x
        x = self._forward_global(xin)

        # Start from a copy of the global features; each patch region is then
        # overwritten with the local branch output for that patch.
        x_loc = x.clone()

        # Derive the patch size from the actual input instead of assuming a
        # 128x128 image (which is what a fixed 32-pixel patch implied).
        grid = self.patch_grid
        patch_h = xin.shape[2] // grid
        patch_w = xin.shape[3] // grid
        for i in range(grid):
            rows = slice(patch_h * i, patch_h * (i + 1))
            for j in range(grid):
                cols = slice(patch_w * j, patch_w * (j + 1))
                x_loc[:, :, rows, cols] = self._forward_patch(xin[:, :, rows, cols])

        x = torch.add(x, x_loc)
        x = F.relu(self.decoderf(x))

        return self.adjust(F.relu(x))

    def forward(self, x):
        return self._forward_impl(x)


def axialunet(pretrained=False, **kwargs):
    """Build a ``ResAxialAttentionUNet`` from ``AxialBlock`` with stage depths [1, 2, 4, 1].

    ``pretrained`` is accepted but ignored; ``kwargs`` are forwarded to the model.
    """
    stage_depths = [1, 2, 4, 1]
    return ResAxialAttentionUNet(AxialBlock, stage_depths, s=0.125, **kwargs)

def gated(pretrained=False, **kwargs):
    """Build a ``ResAxialAttentionUNet`` from ``AxialBlock_dynamic`` with stage depths [1, 2, 4, 1].

    ``pretrained`` is accepted but ignored; ``kwargs`` are forwarded to the model.
    """
    stage_depths = [1, 2, 4, 1]
    return ResAxialAttentionUNet(AxialBlock_dynamic, stage_depths, s=0.125, **kwargs)

def MedT(pretrained=False, **kwargs):
    """Build a ``medt_net`` with gated blocks globally and position-free blocks locally.

    Stage depths are [1, 2, 4, 1]. ``pretrained`` is accepted but ignored;
    ``kwargs`` are forwarded to the model.
    """
    stage_depths = [1, 2, 4, 1]
    return medt_net(AxialBlock_dynamic, AxialBlock_wopos, stage_depths, s=0.125, **kwargs)

def logo(pretrained=False, **kwargs):
    """Build a ``medt_net`` that uses ``AxialBlock`` in both the global and the local branch.

    Stage depths are [1, 2, 4, 1]. ``pretrained`` is accepted but ignored;
    ``kwargs`` are forwarded to the model.
    """
    stage_depths = [1, 2, 4, 1]
    return medt_net(AxialBlock, AxialBlock, stage_depths, s=0.125, **kwargs)


# Medical Transformer Code 

In [None]:
import pdb
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
# use this seed function to make sure the result is reproducible
def reset_seed():
  """Seed Python's, NumPy's and PyTorch's random number generators with 42.

  NumPy is seeded as well so that code drawing from NumPy's global generator
  is repeatable too, and the CUDA generators of every visible device are
  seeded rather than only the current one.
  """
  # Local import: this notebook imports numpy only after this definition.
  import numpy
  seed = 42
  random.seed(seed)
  numpy.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data.dataset import Dataset
import matplotlib.pyplot as plt
import pandas as pd
import torchvision
from torch.utils.data import DataLoader
from PIL import Image
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import cv2

class QU_Dataset(Dataset):
  """Paired image / mask dataset read from two directories with matching file names.

  Args:
    image_dir: directory holding the input images; every entry returned by
      ``os.listdir`` is treated as one sample.
    mask_dir: directory holding one mask per image, stored under the same
      file name as the image.
    transform: optional callable invoked as ``transform(image=..., mask=...)``
      that returns a mapping with ``"image"`` and ``"mask"`` entries.
    state: unused; kept so existing callers keep working.
    image_size: ``(height, width)`` both the image and the mask are brought
      to before ``transform`` runs. Defaults to the previously hard-coded
      224 x 224.
  """

  def __init__(self, image_dir, mask_dir, transform=None, state=None,
               image_size=(224, 224)):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    self.image_size = tuple(image_size)
    # os.listdir() returns entries in arbitrary order; sorting makes an index
    # refer to the same file on every run and every machine.
    self.images = sorted(os.listdir(image_dir))

  def __len__(self):
    """Return the number of files found in ``image_dir``."""
    return len(self.images)

  def __getitem__(self, index):
    """Load sample ``index`` and return ``(image, mask)``.

    The image is read as RGB and resized to ``image_size``. The mask is
    binarised (every value >= 1 becomes 1) and, when its size differs from
    ``image_size``, resized with nearest-neighbour sampling so that no
    intermediate label values are introduced.
    """
    file_name = self.images[index]
    height, width = self.image_size

    # Context managers close the underlying files as soon as the pixel data
    # has been copied into NumPy arrays.
    with Image.open(os.path.join(self.image_dir, file_name)) as image_file:
      image = np.array(image_file.convert("RGB"))
    image = cv2.resize(image, (width, height))

    with Image.open(os.path.join(self.mask_dir, file_name)) as mask_file:
      # PIL reports size as (width, height). Previously only the image was
      # resized, leaving image and mask with different spatial sizes whenever
      # the files on disk were not already at the target size.
      if mask_file.size == (width, height):
        mask = np.array(mask_file)
      else:
        mask = np.array(mask_file.resize((width, height), resample=Image.NEAREST))
    mask[mask >= 1] = 1

    if self.transform is not None:
      augmentations = self.transform(image=image, mask=mask)
      image = augmentations["image"]
      mask = augmentations["mask"]
    return image, mask

# Nominal sample size. The code visible in this cell only mentions these names
# inside the commented-out A.Resize steps below.
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224

# Training pipeline: geometric augmentation, intensity scaling, then
# conversion to a tensor.
train_transform = A.Compose(
    [
        #A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),   # p=1.0: applied to every sample
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        # Zero mean / unit std: per the Albumentations Normalize formula this
        # should amount to dividing pixel values by max_pixel_value.
        A.Normalize(
            mean = [0.0, 0.0, 0.0],
            std = [1.0, 1.0, 1.0],
            max_pixel_value = 255.0
        ),
        ToTensorV2(),
    ])
# Validation pipeline: the same scaling and tensor conversion, without the
# random geometric augmentation.
val_transform = A.Compose(
    [
        #A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean = [0.0, 0.0, 0.0],
            std = [1.0, 1.0, 1.0],
            max_pixel_value = 255.0
        ),
        ToTensorV2(),
    ])

# Machine-specific Windows paths.
# NOTE(review): another cell of this notebook reads from /kaggle/input --
# confirm these directories exist wherever this cell is executed.
TRAIN_DIR = r'C:\Users\umaiskhan\Desktop\Project Files\NIH dataset\train_img'
TRAIN_MASK = r'C:\Users\umaiskhan\Desktop\Project Files\NIH dataset\train_masks'
# The TEST_* names point at the val_img / val_masks folders.
TEST_DIR = r'C:\Users\umaiskhan\Desktop\Project Files\NIH dataset\val_img'
TEST_MASK = r'C:\Users\umaiskhan\Desktop\Project Files\NIH dataset\val_masks'

# Training samples use the augmenting pipeline, held-out samples the plain one.
train_ds1 = QU_Dataset(TRAIN_DIR, TRAIN_MASK, transform=train_transform)
#train_ds, val_ds = torch.utils.data.random_split(train_ds1,[55,5], generator=torch.Generator().manual_seed(42))
test_ds = QU_Dataset(TEST_DIR, TEST_MASK, transform=val_transform)


In [None]:
def conv1x1(in_planes, out_planes, stride=1):
  """Return a bias-free 1x1 ``nn.Conv2d`` mapping ``in_planes`` to ``out_planes`` channels.

  ``stride`` is passed straight to the convolution.
  """
  pointwise = nn.Conv2d(
      in_planes,
      out_planes,
      kernel_size=1,
      stride=stride,
      bias=False,
  )
  return pointwise

class qkv_transform(nn.Conv1d):
  """``nn.Conv1d`` subclass used for the joint query/key/value projection."""


# define Medical Transformer architecture
class AxialAttention(nn.Module):
  """Multi-head self-attention along one spatial axis with relative position terms.

  The input is a 4-D tensor ``(N, C, H, W)``. Attention runs along the height
  axis by default, or along the width axis when ``width=True``; the other
  spatial axis is folded into the batch. The length of the attended axis must
  equal ``kernel_size`` because the relative position table is indexed with
  a ``kernel_size x kernel_size`` grid.

  Args:
    in_planes: number of input channels.
    out_planes: number of output channels; split evenly across ``groups``.
    groups: number of attention heads.
    kernel_size: length of the attended axis.
    stride: when > 1 the output is average-pooled with this kernel and stride.
    bias: stored on the instance; the projection is always built with
      ``bias=False``.
    width: attend along the width axis instead of the height axis.
  """

  def __init__(self,
               in_planes,
               out_planes,
               groups=8,
               kernel_size=56,
               stride=1,
               bias=False,
               width=False):
    super().__init__()
    self.in_planes = in_planes
    self.out_planes = out_planes
    self.groups = groups
    self.group_planes = out_planes // groups
    self.kernel_size = kernel_size
    self.stride = stride
    self.bias = bias
    self.width = width

    # One 1x1 convolution produces q, k and v together: per head, q and k get
    # group_planes // 2 channels each and v gets group_planes channels.
    self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,
                                       padding=0, bias=False)
    self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
    self.bn_similarity = nn.BatchNorm2d(groups * 3)
    self.bn_output = nn.BatchNorm1d(out_planes * 2)

    # Relative position table: one column per offset in
    # -(kernel_size - 1) .. (kernel_size - 1); its rows are split into the
    # query, key and value embeddings in forward().
    self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)
    query_index = torch.arange(kernel_size).unsqueeze(0)
    key_index = torch.arange(kernel_size).unsqueeze(1)
    relative_index = key_index - query_index + kernel_size - 1
    self.register_buffer('flatten_index', relative_index.view(-1))
    if stride > 1:
      self.pooling = nn.AvgPool2d(stride, stride=stride)

    self.reset_paremeters()

  def forward(self, x):
    """Apply axial attention to ``x`` of shape ``(N, C, H, W)``.

    Returns a tensor laid out as ``(N, out_planes, H, W)``, average-pooled
    when ``stride > 1``.
    """
    # Move the attended axis last and fold the other spatial axis into the batch.
    if self.width:
      x = x.permute(0, 2, 1, 3)  # N, H, C, W
    else:
      x = x.permute(0, 3, 1, 2)  # N, W, C, H
    N, W, C, H = x.shape
    x = x.contiguous().view(N * W, C, H)

    # Joint projection, then split per head into q, k (half width) and v.
    qkv = self.bn_qkv(self.qkv_transform(x))
    q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H),
                          [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)

    # Gather the (kernel_size, kernel_size) relative embeddings for q, k and v.
    # The per-call NumPy copies of the first embedding rows that used to be
    # made here were never read; they only forced a host copy on every forward.
    all_embedding = torch.index_select(self.relative, 1, self.flatten_index).view(
        self.group_planes * 2, self.kernel_size, self.kernel_size)
    q_embedding, k_embedding, v_embedding = torch.split(
        all_embedding, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)

    # Three similarity terms, each of shape (N*W, groups, H, H).
    qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
    kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)
    qk = torch.einsum('bgci,bgcj->bgij', q, k)

    # Batch-normalise the three terms jointly, then sum them.
    stacked_similarity = torch.cat([qk, qr, kr], dim=1)
    stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)
    similarity = F.softmax(stacked_similarity, dim=3)

    # Content term and positional term of the output, normalised then summed.
    sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
    sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)
    stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
    output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)

    # Undo the permutation applied at the top.
    if self.width:
      output = output.permute(0, 2, 1, 3)
    else:
      output = output.permute(0, 2, 3, 1)

    if self.stride > 1:
      output = self.pooling(output)
    return output

  def reset_paremeters(self):
    """Re-draw the projection weights and the relative position table from normal distributions."""
    self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
    nn.init.normal_(self.relative, 0, math.sqrt(1. / self.group_planes))

  # Correctly spelled alias, matching the method name used by the sibling
  # attention classes; the misspelled original stays for existing callers.
  reset_parameters = reset_paremeters

class AxialAttention_dynamic(nn.Module):
  """Axial self-attention whose positional terms are scaled by scalar factors.

  Same layout as the plain axial attention: a 4-D input ``(N, C, H, W)`` is
  attended along the height axis, or along the width axis when ``width=True``,
  and the attended axis must have length ``kernel_size``. The query/key
  positional similarities and the positional value term are multiplied by
  ``f_qr``, ``f_kr`` and ``f_sve``; the content value term by ``f_sv``. These
  factors are created with ``requires_grad=False``.
  """

  def __init__(self,
               in_planes,
               out_planes,
               groups=8,
               kernel_size=56,
               stride=1,
               bias=False,
               width=False):
    assert (in_planes % groups == 0) and (out_planes % groups == 0)
    super().__init__()
    self.in_planes, self.out_planes = in_planes, out_planes
    self.groups = groups
    self.group_planes = out_planes // groups
    self.kernel_size = kernel_size
    self.stride = stride
    self.bias = bias
    self.width = width

    # Joint q/k/v projection and the three normalisation layers.
    doubled = out_planes * 2
    self.qkv_transform = qkv_transform(in_planes, doubled, kernel_size=1, stride=1,
                                       padding=0, bias=False)
    self.bn_qkv = nn.BatchNorm1d(doubled)
    self.bn_similarity = nn.BatchNorm2d(groups * 3)
    self.bn_output = nn.BatchNorm1d(doubled)

    # Scaling factors with their initial values, registered in this order.
    for factor_name, initial in (("f_qr", 0.1), ("f_kr", 0.1), ("f_sve", 0.1), ("f_sv", 1.0)):
      setattr(self, factor_name, nn.Parameter(torch.tensor(initial), requires_grad=False))

    # Relative position table plus the flattened lookup grid into it.
    self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)
    positions = torch.arange(kernel_size)
    relative_index = positions.unsqueeze(1) - positions.unsqueeze(0) + kernel_size - 1
    self.register_buffer('flatten_index', relative_index.view(-1))
    if stride > 1:
      self.pooling = nn.AvgPool2d(stride, stride=stride)

    self.reset_parameters()

  def forward(self, x):
    """Apply the scaled axial attention to ``x`` and return ``(N, out_planes, H, W)``,
    average-pooled when ``stride > 1``."""
    # Attended axis last, the other spatial axis folded into the batch.
    axis_order = (0, 2, 1, 3) if self.width else (0, 3, 1, 2)
    x = x.permute(*axis_order)
    batch, rows, channels, length = x.shape
    folded = batch * rows
    x = x.contiguous().view(folded, channels, length)

    head_dim = self.group_planes
    split_sizes = [head_dim // 2, head_dim // 2, head_dim]

    # Project and split per head into q, k and v.
    qkv = self.bn_qkv(self.qkv_transform(x))
    qkv = qkv.reshape(folded, self.groups, head_dim * 2, length)
    q, k, v = torch.split(qkv, split_sizes, dim=2)

    # Look up the relative embeddings for q, k and v.
    table = torch.index_select(self.relative, 1, self.flatten_index)
    table = table.view(head_dim * 2, self.kernel_size, self.kernel_size)
    q_embedding, k_embedding, v_embedding = torch.split(table, split_sizes, dim=0)

    # Similarity terms; the two positional ones are scaled.
    qr = torch.einsum('bgci,cij->bgij', q, q_embedding) * self.f_qr
    kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3) * self.f_kr
    qk = torch.einsum('bgci, bgcj->bgij', q, k)

    logits = self.bn_similarity(torch.cat([qk, qr, kr], dim=1))
    logits = logits.view(folded, 3, self.groups, length, length).sum(dim=1)
    similarity = F.softmax(logits, dim=3)

    # Content and positional value terms, each scaled by its factor.
    sv = torch.einsum('bgij,bgcj->bgci', similarity, v) * self.f_sv
    sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding) * self.f_sve

    stacked = torch.cat([sv, sve], dim=-1).view(folded, self.out_planes * 2, length)
    output = self.bn_output(stacked).view(batch, rows, self.out_planes, 2, length).sum(dim=-2)

    # Restore the channel-first layout.
    output = output.permute(0, 2, 1, 3) if self.width else output.permute(0, 2, 3, 1)
    if self.stride > 1:
      output = self.pooling(output)
    return output

  def reset_parameters(self):
    """Re-draw the projection weights and the relative position table from normal distributions."""
    projection_std = math.sqrt(1. / self.in_planes)
    self.qkv_transform.weight.data.normal_(0, projection_std)
    nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))
class AxialAttention_wopos(nn.Module):
  """Axial self-attention without any positional embedding.

  A 4-D input ``(N, C, H, W)`` is attended along the height axis, or along the
  width axis when ``width=True``. Only the query-key similarity is used, so
  ``kernel_size`` is stored but does not constrain the input size here.
  """

  def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,
               stride=1, bias=False, width=False):
    assert (in_planes % groups == 0) and (out_planes % groups == 0)
    super().__init__()
    self.in_planes, self.out_planes = in_planes, out_planes
    self.groups = groups
    self.group_planes = out_planes // groups
    self.kernel_size = kernel_size
    self.stride = stride
    self.bias = bias
    self.width = width

    # Joint q/k/v projection and the normalisation layers.
    self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,
                                       padding=0, bias=False)
    self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
    self.bn_similarity = nn.BatchNorm2d(groups)
    self.bn_output = nn.BatchNorm1d(out_planes * 1)

    if stride > 1:
      self.pooling = nn.AvgPool2d(stride, stride=stride)

    self.reset_parameters()

  def forward(self, x):
    """Apply position-free axial attention to ``x`` and return ``(N, out_planes, H, W)``,
    average-pooled when ``stride > 1``."""
    # Attended axis last, the other spatial axis folded into the batch.
    axis_order = (0, 2, 1, 3) if self.width else (0, 3, 1, 2)
    x = x.permute(*axis_order)
    batch, rows, channels, length = x.shape
    folded = batch * rows
    x = x.contiguous().view(folded, channels, length)

    # Project and split per head into q, k and v.
    head_dim = self.group_planes
    qkv = self.bn_qkv(self.qkv_transform(x))
    qkv = qkv.reshape(folded, self.groups, head_dim * 2, length)
    q, k, v = torch.split(qkv, [head_dim // 2, head_dim // 2, head_dim], dim=2)

    qk = torch.einsum('bgci, bgcj->bgij', q, k)

    # Single similarity term: the sum runs over an axis of length one.
    logits = self.bn_similarity(qk)
    logits = logits.reshape(folded, 1, self.groups, length, length).sum(dim=1).contiguous()

    similarity = F.softmax(logits, dim=3)
    sv = torch.einsum('bgij,bgcj->bgci', similarity, v)

    sv = sv.reshape(folded, self.out_planes * 1, length).contiguous()
    output = self.bn_output(sv)
    output = output.reshape(batch, rows, self.out_planes, 1, length).sum(dim=-2).contiguous()

    # Restore the channel-first layout.
    output = output.permute(0, 2, 1, 3) if self.width else output.permute(0, 2, 3, 1)
    if self.stride > 1:
      output = self.pooling(output)
    return output

  def reset_parameters(self):
    """Re-draw the projection weights from a normal distribution."""
    projection_std = math.sqrt(1. / self.in_planes)
    self.qkv_transform.weight.data.normal_(0, projection_std)
class AxialBlock(nn.Module):
  """Residual bottleneck whose spatial mixing is height- then width-axis ``AxialAttention``.

  Pipeline: 1x1 conv down to ``width`` channels, BatchNorm, ReLU, attention
  along the height axis, attention along the width axis, ReLU, 1x1 conv up to
  ``planes * expansion`` channels, BatchNorm, residual add, ReLU.

  Args:
    inplanes: number of input channels.
    planes: base channel count; the block outputs ``planes * expansion``.
    stride: spatial downsampling factor, applied by the width-axis attention.
    downsample: optional module applied to the input to produce the shortcut
      when its shape has to change.
    groups: number of attention heads.
    base_width: scales the internal width as ``planes * base_width / 64``.
    dilation: accepted for signature compatibility; not used by this block.
    norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.
    kernel_size: spatial size of the feature map entering the block.
  """
  expansion = 2

  def __init__(self,
               inplanes,
               planes,
               stride=1,
               downsample=None,
               groups=1,
               base_width=64,
               dilation=1,
               norm_layer=None,
               kernel_size=56):
    super().__init__()
    if norm_layer is None:
      norm_layer = nn.BatchNorm2d
    width = int(planes * (base_width / 64.))
    self.groups = groups
    self.conv_down = conv1x1(inplanes, width)
    self.bn1 = norm_layer(width)
    self.hight_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size)
    # The width-axis attention must carry the block's stride. Without it the
    # main path keeps its resolution while a strided `downsample` shortcut is
    # reduced, and the residual addition in forward() fails on mismatched
    # shapes. The other block variants in this file already pass it.
    self.width_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size,
                                      stride=stride, width=True)
    self.conv_up = conv1x1(width, planes*self.expansion)
    self.bn2 = norm_layer(planes * self.expansion)
    self.relu = nn.ReLU(inplace=True)
    self.downsample = downsample
    self.stride = stride

  def forward(self, x):
    """Return the block output for ``x``; the shortcut is ``downsample(x)`` when a
    ``downsample`` module was given, otherwise ``x`` itself."""
    identity = x

    out = self.conv_down(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.hight_block(out)
    out = self.width_block(out)
    out = self.relu(out)

    out = self.conv_up(out)
    out = self.bn2(out)

    if self.downsample is not None:
      identity = self.downsample(x)

    out += identity
    out = self.relu(out)

    return out
class AxialBlock_dynamic(nn.Module):
  """Residual bottleneck built around two ``AxialAttention_dynamic`` layers.

  Pipeline: 1x1 conv down, BatchNorm, ReLU, height-axis attention, width-axis
  attention (which carries ``stride``), ReLU, 1x1 conv up to
  ``planes * expansion`` channels, BatchNorm, residual add, ReLU.
  ``dilation`` is accepted but not used by this block.
  """
  expansion = 2

  def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
               base_width=64, dilation=1, norm_layer=None, kernel_size=56):
    super().__init__()
    make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
    width = int(planes * (base_width / 64.))
    shared = dict(groups=groups, kernel_size=kernel_size)

    self.conv_down = conv1x1(inplanes, width)
    self.bn1 = make_norm(width)
    self.hight_block = AxialAttention_dynamic(width, width, **shared)
    self.width_block = AxialAttention_dynamic(width, width, stride=stride, width=True, **shared)
    self.conv_up = conv1x1(width, planes * self.expansion)
    self.bn2 = make_norm(planes * self.expansion)
    self.relu = nn.ReLU(inplace=True)
    self.downsample = downsample
    self.stride = stride

  def forward(self, x):
    """Return the block output for ``x``; the shortcut is ``downsample(x)`` when a
    ``downsample`` module was given, otherwise ``x`` itself."""
    # Reduce channels, then attend along height and width.
    out = self.relu(self.bn1(self.conv_down(x)))
    out = self.relu(self.width_block(self.hight_block(out)))

    # Expand channels again before the residual connection.
    out = self.bn2(self.conv_up(out))

    shortcut = x if self.downsample is None else self.downsample(x)
    out += shortcut
    return self.relu(out)
class AxialBlock_wopos(nn.Module):
  """Residual bottleneck built around two ``AxialAttention_wopos`` layers.

  Pipeline: 1x1 conv down, BatchNorm, ReLU, height-axis attention, width-axis
  attention (which carries ``stride``), ReLU, 1x1 conv up to
  ``planes * expansion`` channels, BatchNorm, residual add, ReLU.
  ``conv1`` is created but not used in ``forward``; ``dilation`` is accepted
  but not used by this block.
  """
  expansion = 2

  def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
               base_width=64, dilation=1, norm_layer=None, kernel_size=56):
    super().__init__()
    make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
    width = int(planes * (base_width / 64.))
    shared = dict(groups=groups, kernel_size=kernel_size)

    self.conv_down = conv1x1(inplanes, width)
    self.conv1 = nn.Conv2d(width, width, kernel_size=1)
    self.bn1 = make_norm(width)
    self.hight_block = AxialAttention_wopos(width, width, **shared)
    self.width_block = AxialAttention_wopos(width, width, stride=stride, width=True, **shared)
    self.conv_up = conv1x1(width, planes * self.expansion)
    self.bn2 = make_norm(planes * self.expansion)
    self.relu = nn.ReLU(inplace=True)
    self.downsample = downsample
    self.stride = stride

  def forward(self, x):
    """Return the block output for ``x``; the shortcut is ``downsample(x)`` when a
    ``downsample`` module was given, otherwise ``x`` itself."""
    # Reduce channels, then attend along height and width.
    out = self.relu(self.bn1(self.conv_down(x)))
    out = self.relu(self.width_block(self.hight_block(out)))

    # Expand channels again before the residual connection.
    out = self.bn2(self.conv_up(out))

    shortcut = x if self.downsample is None else self.downsample(x)
    out += shortcut
    return self.relu(out)
class ResAxialAttentionUNet(nn.Module):
  """Encoder-decoder segmentation network with an axial-attention encoder.

  A three-convolution stem (the first with stride 2) feeds four encoder stages
  built from ``block``. Five decoder convolutions, each followed by a 2x
  bilinear upsampling and ReLU, rebuild the resolution while adding the
  encoder outputs as skip connections. A final 1x1 convolution maps to
  ``num_classes`` channels; no softmax is applied in ``forward``.

  Args:
    block: block class used for every encoder stage; must expose ``expansion``.
    layers: number of blocks in each of the four stages.
    num_classes: number of output channels.
    zero_init_residual: accepted but not used.
    groups: number of attention heads handed to every block.
    width_per_group: ``base_width`` handed to every block.
    replace_stride_with_dilation: ``None`` or three flags, one per strided
      stage; a true flag turns that stage's stride into a dilation factor.
    norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.
    s: channel width multiplier.
    img_size: input size used to derive each stage's attention ``kernel_size``.
    imgchan: number of input image channels.
  """

  def __init__(self, block, layers, num_classes=2, zero_init_residual=True,
               groups=8, width_per_group=64, replace_stride_with_dilation=None,
               norm_layer=None, s=0.125, img_size=128, imgchan=3):
    super().__init__()
    self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
    norm_layer = self._norm_layer

    self.inplanes = int(64 * s)
    self.dilation = 1
    dilate_flags = replace_stride_with_dilation
    if dilate_flags is None:
      dilate_flags = [False, False, False]
    if len(dilate_flags) != 3:
      raise ValueError("replace_stride_with_dilation should be None "
                       "or a 3-element tuple, got {}".format(dilate_flags))
    self.groups = groups
    self.base_width = width_per_group

    # Stem: the strided 7x7 conv halves the resolution.
    stem_width = self.inplanes
    self.conv1 = nn.Conv2d(imgchan, stem_width, kernel_size=7, stride=2, padding=3, bias=False)
    self.conv2 = nn.Conv2d(stem_width, 128, kernel_size=3, stride=1, padding=1, bias=False)
    self.conv3 = nn.Conv2d(128, stem_width, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = norm_layer(stem_width)
    self.bn2 = norm_layer(128)
    self.bn3 = norm_layer(stem_width)
    self.relu = nn.ReLU(inplace=True)

    # Encoder stages; _make_layer advances self.inplanes as it goes.
    self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size=img_size // 2)
    self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2,
                                   kernel_size=img_size // 2, dilate=dilate_flags[0])
    self.layer3 = self._make_layer(block, int(512 * s), layers[2], stride=2,
                                   kernel_size=img_size // 4, dilate=dilate_flags[1])
    self.layer4 = self._make_layer(block, int(1024 * s), layers[3], stride=2,
                                   kernel_size=img_size // 8, dilate=dilate_flags[2])

    # Decoder convolutions and the classification head.
    deepest = int(1024 * 2 * s)
    self.decoder1 = nn.Conv2d(deepest, deepest, kernel_size=3, stride=2, padding=1)
    self.decoder2 = nn.Conv2d(deepest, int(1024 * s), kernel_size=3, stride=1, padding=1)
    self.decoder3 = nn.Conv2d(int(1024 * s), int(512 * s), kernel_size=3, stride=1, padding=1)
    self.decoder4 = nn.Conv2d(int(512 * s), int(256 * s), kernel_size=3, stride=1, padding=1)
    self.decoder5 = nn.Conv2d(int(256 * s), int(128 * s), kernel_size=3, stride=1, padding=1)
    self.adjust = nn.Conv2d(int(128 * s), num_classes, kernel_size=1, stride=1, padding=0)
    self.soft = nn.Softmax(dim=1)

  def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
    """Build one encoder stage of ``blocks`` blocks and update ``self.inplanes``.

    Only the first block receives ``stride`` and the shortcut ``downsample``;
    when the stage is strided, the remaining blocks use half the
    ``kernel_size``.
    """
    norm_layer = self._norm_layer
    previous_dilation = self.dilation
    if dilate:
      self.dilation *= stride
      stride = 1

    out_channels = planes * block.expansion
    downsample = None
    if stride != 1 or self.inplanes != out_channels:
      downsample = nn.Sequential(
          conv1x1(self.inplanes, out_channels, stride),
          norm_layer(out_channels),
      )

    stage = [block(self.inplanes, planes, stride, downsample, groups=self.groups,
                   base_width=self.base_width, dilation=previous_dilation,
                   norm_layer=norm_layer, kernel_size=kernel_size)]
    self.inplanes = out_channels
    if stride != 1:
      kernel_size = kernel_size // 2

    stage.extend(
        block(self.inplanes, planes, groups=self.groups,
              base_width=self.base_width, dilation=self.dilation,
              norm_layer=norm_layer, kernel_size=kernel_size)
        for _ in range(1, blocks)
    )
    return nn.Sequential(*stage)

  def _forward_impl(self, x):
    """Run stem, encoder and decoder on ``x`` and return the ``adjust`` output."""
    # Stem: conv -> norm -> ReLU, three times.
    stem = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
    for conv, norm in stem:
      x = self.relu(norm(conv(x)))

    # Encoder; every stage output is kept for its skip connection.
    x1 = self.layer1(x)
    x2 = self.layer2(x1)
    x3 = self.layer3(x2)
    x4 = self.layer4(x3)

    def upsample(features):
      return F.relu(F.interpolate(features, scale_factor=(2, 2), mode='bilinear'))

    # Decoder: add the skip, convolve, upsample by two.
    x = upsample(self.decoder1(x4))
    skips = (x4, x3, x2, x1)
    decoders = (self.decoder2, self.decoder3, self.decoder4, self.decoder5)
    for skip, decoder in zip(skips, decoders):
      x = upsample(decoder(torch.add(x, skip)))

    return self.adjust(F.relu(x))

  def forward(self, x):
    """Run the network on ``x`` by delegating to ``_forward_impl``."""
    return self._forward_impl(x)
class medt_net(nn.Module):
  """Two-branch segmentation network: a full-image branch plus a patch-wise branch.

  The full-image branch (``conv1``-``conv3``, ``layer1``, ``layer2`` built
  from ``block``, then ``decoder4`` and ``decoder5``) processes the whole
  input. The patch branch (the ``*_p`` layers, built from ``block_2``)
  processes 32x32 crops of the input one at a time and writes each result
  into the matching region of a copy of the full-image features. The two
  feature maps are added, passed through ``decoderf`` and mapped to
  ``num_classes`` channels by ``adjust``. No softmax is applied in ``forward``.

  ``soft``, ``relu_p``, ``adjust_p`` and ``soft_p`` are created but not used
  by ``_forward_impl``.
  """

  def __init__(self, block, block_2, layers, num_classes=2, zero_init_residual=True,
              groups=8, width_per_group=64, replace_stride_with_dilation=None,
              norm_layer=None, s=0.125, img_size = 128,imgchan = 3):
    """Create both branches.

    Args:
      block: block class for the full-image branch; must expose ``expansion``.
      block_2: block class for the patch branch; must expose ``expansion``.
      layers: number of blocks per stage (four entries are read).
      num_classes: number of output channels.
      zero_init_residual: accepted but not used.
      groups: number of attention heads handed to every block.
      width_per_group: ``base_width`` handed to every block.
      replace_stride_with_dilation: ``None`` or three flags, one per strided stage.
      norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.
      s: channel width multiplier.
      img_size: input size used to derive the attention ``kernel_size`` of each
        stage; the patch branch uses ``img_size // 4``.
      imgchan: number of input image channels.
    """
    super(medt_net, self).__init__()
    if norm_layer is None:
        norm_layer = nn.BatchNorm2d
    self._norm_layer = norm_layer

    self.inplanes = int(64 * s)
    self.dilation = 1
    if replace_stride_with_dilation is None:
        replace_stride_with_dilation = [False, False, False]
    if len(replace_stride_with_dilation) != 3:
        raise ValueError("replace_stride_with_dilation should be None "
                          "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
    self.groups = groups
    self.base_width = width_per_group
    # Full-image stem: the strided 7x7 conv halves the resolution.
    self.conv1 = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
                            bias=False)
    self.conv2 = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1, bias=False)
    self.conv3 = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = norm_layer(self.inplanes)
    self.bn2 = norm_layer(128)
    self.bn3 = norm_layer(self.inplanes)
    # self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
    # Second assignment of bn1 with the same arguments; it replaces the one above.
    self.bn1 = norm_layer(self.inplanes)
    self.relu = nn.ReLU(inplace=True)
    # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    # Full-image encoder: only two stages are active (layer3/layer4 are disabled).
    self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size= (img_size//2))
    self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2, kernel_size=(img_size//2),
                                    dilate=replace_stride_with_dilation[0])
    # self.layer3 = self._make_layer(block, int(512 * s), layers[2], stride=2, kernel_size=(img_size//4),
    #                                dilate=replace_stride_with_dilation[1])
    # self.layer4 = self._make_layer(block, int(1024 * s), layers[3], stride=2, kernel_size=(img_size//8),
    #                                dilate=replace_stride_with_dilation[2])

    # Decoder of the full-image branch (decoder1-3 belong to the disabled stages).
    # self.decoder1 = nn.Conv2d(int(1024 *2*s)      ,        int(1024*2*s), kernel_size=3, stride=2, padding=1)
    # self.decoder2 = nn.Conv2d(int(1024  *2*s)     , int(1024*s), kernel_size=3, stride=1, padding=1)
    # self.decoder3 = nn.Conv2d(int(1024*s),  int(512*s), kernel_size=3, stride=1, padding=1)
    self.decoder4 = nn.Conv2d(int(512*s) ,  int(256*s), kernel_size=3, stride=1, padding=1)
    self.decoder5 = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1)
    self.adjust   = nn.Conv2d(int(128*s) , num_classes, kernel_size=1, stride=1, padding=0)
    self.soft     = nn.Softmax(dim=1)


    # Patch-branch stem. self.inplanes is no longer int(64 * s) here: the
    # _make_layer calls above advanced it to the output width of layer2, so
    # this stem (and the first patch stage) starts from that wider value.
    self.conv1_p = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
                            bias=False)
    self.conv2_p = nn.Conv2d(self.inplanes,128, kernel_size=3, stride=1, padding=1,
                            bias=False)
    self.conv3_p = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1,
                            bias=False)
    # self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1_p = norm_layer(self.inplanes)
    self.bn2_p = norm_layer(128)
    self.bn3_p = norm_layer(self.inplanes)

    self.relu_p = nn.ReLU(inplace=True)

    # Nominal patch size used to derive the patch stages' attention kernel sizes.
    img_size_p = img_size // 4

    # Patch-branch encoder: all four stages are active.
    self.layer1_p = self._make_layer(block_2, int(128 * s), layers[0], kernel_size= (img_size_p//2))
    self.layer2_p = self._make_layer(block_2, int(256 * s), layers[1], stride=2, kernel_size=(img_size_p//2),
                                    dilate=replace_stride_with_dilation[0])
    self.layer3_p = self._make_layer(block_2, int(512 * s), layers[2], stride=2, kernel_size=(img_size_p//4),
                                    dilate=replace_stride_with_dilation[1])
    self.layer4_p = self._make_layer(block_2, int(1024 * s), layers[3], stride=2, kernel_size=(img_size_p//8),
                                    dilate=replace_stride_with_dilation[2])

    # Decoder of the patch branch.
    self.decoder1_p = nn.Conv2d(int(1024 *2*s)      ,        int(1024*2*s), kernel_size=3, stride=2, padding=1)
    self.decoder2_p = nn.Conv2d(int(1024  *2*s)     , int(1024*s), kernel_size=3, stride=1, padding=1)
    self.decoder3_p = nn.Conv2d(int(1024*s),  int(512*s), kernel_size=3, stride=1, padding=1)
    self.decoder4_p = nn.Conv2d(int(512*s) ,  int(256*s), kernel_size=3, stride=1, padding=1)
    self.decoder5_p = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1)

    # Fusion conv applied after the two branches are added.
    self.decoderf = nn.Conv2d(int(128*s) , int(128*s) , kernel_size=3, stride=1, padding=1)
    self.adjust_p   = nn.Conv2d(int(128*s) , num_classes, kernel_size=1, stride=1, padding=0)
    self.soft_p     = nn.Softmax(dim=1)


  def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
    """Build one encoder stage of ``blocks`` blocks and update ``self.inplanes``.

    Only the first block receives ``stride`` and the shortcut ``downsample``
    (a 1x1 conv plus norm, created when the stride or the channel count
    changes). When the stage is strided, the remaining blocks are built with
    half the ``kernel_size``. With ``dilate`` set, the stride is folded into
    ``self.dilation`` and the stage itself uses stride 1.
    """
    norm_layer = self._norm_layer
    downsample = None
    previous_dilation = self.dilation
    if dilate:
        self.dilation *= stride
        stride = 1
    if stride != 1 or self.inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            conv1x1(self.inplanes, planes * block.expansion, stride),
            norm_layer(planes * block.expansion),
        )

    layers = []
    layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups,
                        base_width=self.base_width, dilation=previous_dilation,
                        norm_layer=norm_layer, kernel_size=kernel_size))
    # Later blocks (and later stages) see the expanded channel count.
    self.inplanes = planes * block.expansion
    if stride != 1:
        kernel_size = kernel_size // 2

    for _ in range(1, blocks):
        layers.append(block(self.inplanes, planes, groups=self.groups,
                            base_width=self.base_width, dilation=self.dilation,
                            norm_layer=norm_layer, kernel_size=kernel_size))

    return nn.Sequential(*layers)

  def _forward_impl(self, x):
    """Run both branches on ``x`` and return the fused ``adjust`` output.

    NOTE(review): the patch loop below hard-codes a 4x4 grid of 32-pixel
    crops, so it covers exactly the top-left 128x128 region of the input
    regardless of ``img_size``. Confirm inputs are always 128x128, or derive
    the crop size from the input.
    """

    # Keep an untouched copy of the input for the patch branch.
    xin = x.clone()
    # Full-image stem: conv -> norm -> ReLU, three times.
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.bn2(x)
    x = self.relu(x)
    x = self.conv3(x)
    x = self.bn3(x)
    # x = F.max_pool2d(x,2,2)
    x = self.relu(x)

    # Full-image encoder (two stages).
    x1 = self.layer1(x)
    x2 = self.layer2(x1)
    # x3 = self.layer3(x2)
    # x4 = self.layer4(x3)
    # x = F.relu(F.interpolate(self.decoder1(x4), scale_factor=(2,2), mode ='bilinear'))
    # x = torch.add(x, x4)
    # x = F.relu(F.interpolate(self.decoder2(x4) , scale_factor=(2,2), mode ='bilinear'))
    # x = torch.add(x, x3)
    # x = F.relu(F.interpolate(self.decoder3(x3) , scale_factor=(2,2), mode ='bilinear'))
    # x = torch.add(x, x2)
    # Full-image decoder: two conv + 2x upsample steps with one skip from layer1.
    x = F.relu(F.interpolate(self.decoder4(x2) , scale_factor=(2,2), mode ='bilinear'))
    x = torch.add(x, x1)
    x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear'))

    # end of full image training

    # Canvas for the patch results; regions the loop does not overwrite keep
    # the full-image features copied here.
    x_loc = x.clone()
    # x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear'))
    for i in range(0,4):
        for j in range(0,4):

            # Crop patch (i, j) from the original input.
            x_p = xin[:,:,32*i:32*(i+1),32*j:32*(j+1)]
            # Patch stem. Note it reuses self.relu; self.relu_p is not called.
            x_p = self.conv1_p(x_p)
            x_p = self.bn1_p(x_p)
            # x = F.max_pool2d(x,2,2)
            x_p = self.relu(x_p)

            x_p = self.conv2_p(x_p)
            x_p = self.bn2_p(x_p)
            # x = F.max_pool2d(x,2,2)
            x_p = self.relu(x_p)
            x_p = self.conv3_p(x_p)
            x_p = self.bn3_p(x_p)
            # x = F.max_pool2d(x,2,2)
            x_p = self.relu(x_p)

            # Patch encoder (four stages); every output is kept as a skip.
            x1_p = self.layer1_p(x_p)
            x2_p = self.layer2_p(x1_p)
            x3_p = self.layer3_p(x2_p)
            x4_p = self.layer4_p(x3_p)

            # Patch decoder: conv, 2x bilinear upsample, ReLU, then add the skip.
            x_p = F.relu(F.interpolate(self.decoder1_p(x4_p), scale_factor=(2,2), mode ='bilinear'))
            x_p = torch.add(x_p, x4_p)
            x_p = F.relu(F.interpolate(self.decoder2_p(x_p) , scale_factor=(2,2), mode ='bilinear'))
            x_p = torch.add(x_p, x3_p)
            x_p = F.relu(F.interpolate(self.decoder3_p(x_p) , scale_factor=(2,2), mode ='bilinear'))
            x_p = torch.add(x_p, x2_p)
            x_p = F.relu(F.interpolate(self.decoder4_p(x_p) , scale_factor=(2,2), mode ='bilinear'))
            x_p = torch.add(x_p, x1_p)
            x_p = F.relu(F.interpolate(self.decoder5_p(x_p) , scale_factor=(2,2), mode ='bilinear'))

            # Write the decoded patch into its region of the canvas.
            x_loc[:,:,32*i:32*(i+1),32*j:32*(j+1)] = x_p

    # Fuse the two branches and map to the output channels.
    x = torch.add(x,x_loc)
    x = F.relu(self.decoderf(x))

    x = self.adjust(F.relu(x))

    return x

  def forward(self, x):
    """Run the network on ``x`` by delegating to ``_forward_impl``."""
    return self._forward_impl(x)

# MedNet – GlaS dataset
https://www.kaggle.com/code/umaiskhan19/mednet-glas-dataset/edit

In [None]:
import pdb
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
def reset_seed(seed=42):
  """Seed Python's, NumPy's and PyTorch's global random number generators.

  Call this before creating models, splits or data loaders so that repeated
  runs draw the same random numbers.

  Args:
    seed: Integer applied to every generator. Defaults to 42, the value this
      function always used before the parameter existed.
  """
  # Local import so the function does not depend on where the module-level
  # NumPy import sits relative to this definition.
  import numpy as np

  random.seed(seed)
  np.random.seed(seed)  # NumPy keeps its own global state, separate from `random`
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)  # silently ignored when CUDA is unavailable
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data.dataset import Dataset
import matplotlib.pyplot as plt
import pandas as pd
import torchvision
from torch.utils.data import DataLoader
from PIL import Image
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import cv2

class QU_Dataset(Dataset):
  """Segmentation dataset pairing each image with the same-named mask file.

  Every entry of ``image_dir`` is treated as one sample; its mask is looked up
  under the identical file name inside ``mask_dir``.

  Args:
    image_dir: Directory containing the input images.
    mask_dir: Directory containing the masks, named exactly like their images.
    transform: Optional callable invoked as ``transform(image=..., mask=...)``
      that returns a mapping with ``"image"`` and ``"mask"`` entries.
    state: Unused; kept so existing callers that pass it keep working.
    image_size: ``(height, width)`` both image and mask are brought to.
      Defaults to ``(224, 224)``, the size that was previously hard-coded.
  """

  def __init__(self, image_dir, mask_dir, transform=None, state=None,
               image_size=(224, 224)):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    self.image_size = tuple(image_size)
    # os.listdir returns entries in arbitrary order; sorting makes an index
    # refer to the same sample on every run and every machine.
    self.images = sorted(os.listdir(image_dir))

  def __len__(self):
    """Return the number of files found in ``image_dir``."""
    return len(self.images)

  def __getitem__(self, index):
    """Load, resize and (optionally) transform the sample at ``index``.

    Returns:
      An ``(image, mask)`` pair. The image is read as RGB and resized to
      ``image_size``; mask values ``>= 1`` are collapsed to 1. If a transform
      was supplied, both are replaced by the transform's outputs.
    """
    file_name = self.images[index]
    img_path = os.path.join(self.image_dir, file_name)
    mask_path = os.path.join(self.mask_dir, file_name)
    height, width = self.image_size

    image = np.array(Image.open(img_path).convert("RGB"))
    image = cv2.resize(image, (width, height))  # cv2 takes (width, height)

    mask = np.array(Image.open(mask_path))
    # The image was just resized, so the mask has to follow or the pair no
    # longer lines up. Masks that already match are left untouched.
    if mask.shape[:2] != (height, width):
      if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8)  # cv2.resize does not accept bool arrays
      # Nearest-neighbour keeps label values intact (no blended in-between values).
      mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    mask[mask >= 1] = 1  # binarise: any non-zero label becomes foreground

    if self.transform is not None:
      augmentations = self.transform(image=image, mask=mask)
      image = augmentations["image"]
      mask = augmentations["mask"]
    return image, mask

IMAGE_HEIGHT, IMAGE_WIDTH = 224, 224


def _normalize_to_tensor():
  """Return fresh Normalize + ToTensorV2 steps that close both pipelines."""
  return [
      # Zero mean and unit std on every channel, max pixel value of 255.
      A.Normalize(
          mean=[0.0, 0.0, 0.0],
          std=[1.0, 1.0, 1.0],
          max_pixel_value=255.0,
      ),
      ToTensorV2(),
  ]


# Training pipeline: rotation (always applied, up to 35 degrees) and random
# flips, followed by the shared closing steps.
# A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH) is currently disabled.
train_transform = A.Compose(
    [
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
    ]
    + _normalize_to_tensor()
)

# Validation pipeline: no augmentation, only the shared closing steps.
# A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH) is currently disabled.
val_transform = A.Compose(_normalize_to_tensor())

# Absolute Windows paths to the image folders and their mask folders.
TRAIN_DIR, TRAIN_MASK = (
    r'C:\Users\umaiskhan\Desktop\Project Files\NIH dataset\train_img',
    r'C:\Users\umaiskhan\Desktop\Project Files\NIH dataset\train_masks',
)
TEST_DIR, TEST_MASK = (
    r'C:\Users\umaiskhan\Desktop\Project Files\NIH dataset\val_img',
    r'C:\Users\umaiskhan\Desktop\Project Files\NIH dataset\val_masks',
)

# Training samples go through the augmenting pipeline; the held-out samples
# only go through val_transform.
train_ds1 = QU_Dataset(
    image_dir=TRAIN_DIR,
    mask_dir=TRAIN_MASK,
    transform=train_transform,
)
# Disabled train/validation split:
#train_ds, val_ds = torch.utils.data.random_split(train_ds1,[55,5], generator=torch.Generator().manual_seed(42))
test_ds = QU_Dataset(
    image_dir=TEST_DIR,
    mask_dir=TEST_MASK,
    transform=val_transform,
)
