In [None]:
import pdb
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random

In [None]:
def reset_seed(seed=42):
    """Seed every random number generator this notebook draws from.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices) with the same value so that a cell
    sequence can be replayed with identical random draws.

    Args:
        seed (int): Seed applied to all generators. Default: 42.
    """
    # Local import: numpy is only imported in a later cell of the notebook.
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Lazy when CUDA is not initialised, so this is safe on CPU-only machines.
    torch.cuda.manual_seed_all(seed)

# Custom Dataset

In [None]:
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data.dataset import Dataset
import matplotlib.pyplot as plt
import pandas as pd
import torchvision
from torch.utils.data import DataLoader
from PIL import Image
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np

class QU_Dataset(Dataset):
  """Paired image / binary-mask dataset read from two directories.

  Every entry found in ``image_dir`` is treated as one sample. Its mask is
  looked up in ``mask_dir`` under the same file name with ``.bmp`` replaced
  by ``_anno.bmp``.

  Args:
    image_dir: Directory holding the input images.
    mask_dir: Directory holding the annotation masks.
    transform: Optional callable invoked as ``transform(image=..., mask=...)``
      whose result is indexed with ``"image"`` and ``"mask"``.
    state: Accepted for backward compatibility; not used.
  """
  def __init__(self, image_dir, mask_dir, transform=None, state=None):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    # os.listdir() returns entries in arbitrary order; sorting makes
    # index -> file stable across runs and machines.
    self.images = sorted(os.listdir(image_dir))

  def __len__(self):
    """Number of files found in ``image_dir``."""
    return len(self.images)

  def _mask_path(self, index):
    """Path of the mask belonging to sample ``index``."""
    # Rewrite only the file name, never the directory part of the path.
    mask_name = self.images[index].replace(".bmp", "_anno.bmp")
    return os.path.join(self.mask_dir, mask_name)

  def __getitem__(self, index):
    """Load sample ``index`` and return ``(image, mask)``.

    The image is read as RGB. Every mask value ``>= 1`` is set to 1, so the
    mask only contains values below 1 (unchanged) and 1.
    """
    img_path = os.path.join(self.image_dir, self.images[index])
    mask_path = self._mask_path(index)
    image = np.array(Image.open(img_path).convert("RGB"))
    mask = np.array(Image.open(mask_path))
    mask[mask >= 1] = 1.0

    if self.transform is not None:
      augmentations = self.transform(image=image, mask=mask)
      image = augmentations["image"]
      mask = augmentations["mask"]
    return image, mask
class DataScienceBowl(Dataset):
  """Dataset backed by two ``.npy`` files (images and masks).

  The image array is loaded once, given a trailing channel axis and repeated
  three times along it, so each stored sample is ``(..., 3)``. Masks are
  loaded as ``float64``.

  Args:
    image_dir: Path of the ``.npy`` file with the images (despite the name,
      this is passed straight to ``np.load``).
    mask_dir: Path of the ``.npy`` file with the masks.
    transform: Optional callable invoked as ``transform(image=..., mask=...)``
      whose result is indexed with ``"image"`` and ``"mask"``.
  """
  def __init__(self, image_dir, mask_dir, transform=None):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.image = np.load(image_dir)
    self.image = np.expand_dims(self.image, axis = -1)
    self.image = np.concatenate((self.image, self.image, self.image), axis = -1)
    self.mask = np.load(mask_dir).astype(np.double)
    self.transform = transform

  def __len__(self):
    """Number of samples (first axis of the image array)."""
    return self.image.shape[0]

  def __getitem__(self, index):
    """Return ``(image, mask)`` where ``image`` is a ``(1, H, W)`` tensor.

    The stored image is multiplied by 255 before the transform is applied.
    Only the first channel of the channel-first result is returned.
    """
    image = self.image[index]*255
    mask = self.mask[index].squeeze()

    if self.transform is not None:
      augmentations = self.transform(image=image, mask=mask)
      image = augmentations["image"]
      mask = augmentations["mask"]

    # Without a tensor-producing transform the image is still an HWC ndarray,
    # which has no ``unsqueeze``; convert it to a CHW tensor first.
    if isinstance(image, np.ndarray):
      image = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1)
    return image[0,:,:].unsqueeze(0), mask
# Spatial size every sample is resized to before it reaches the network.
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256

# Normalisation shared by both pipelines: zero mean, unit std and a maximum
# pixel value of 255 (ImageNet statistics are deliberately not used here).
_NORMALIZE_KWARGS = dict(
    mean=(0., 0., 0.),
    std=(1., 1., 1.),
    max_pixel_value=255.0,
)

# Training pipeline: resize, always rotate by up to 20 degrees, random flips,
# normalise, then convert to a tensor.
train_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Rotate(limit=20, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    A.Normalize(**_NORMALIZE_KWARGS),
    ToTensorV2(),
])

# Validation pipeline: same resize / normalise / to-tensor steps, no
# random augmentation.
val_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Normalize(**_NORMALIZE_KWARGS),
    ToTensorV2(),
])

In [None]:
# Re-seed immediately before building the loaders (presumably so the shuffle
# order drawn by the training loader is repeatable — confirm).
reset_seed()
BATCH_SIZE = 16
NUM_WORKERS = 0
# NOTE(review): `train_ds` is not defined in any cell visible here; it has to
# exist in the kernel before this cell runs. Presumably it is a QU_Dataset or
# DataScienceBowl instance built with `train_transform` — verify, otherwise
# "Restart & Run All" fails with a NameError on this line.
train_loader = DataLoader(
    train_ds,
    batch_size = BATCH_SIZE,
    num_workers = NUM_WORKERS,
    shuffle=True,
)
"""
val_loader = DataLoader(
    val_ds,
    
    batch_size=BATCH_SIZE,
    num_workers = NUM_WORKERS,
    pin_memory = PIN_MEMORY,
    shuffle=False
)
"""
# Evaluation loader: samples are served in dataset order (shuffle=False).
# NOTE(review): `test_ds` is not defined in any cell visible here — confirm it
# is created before this cell runs.
# num_workers is hard-coded to 0 here instead of using NUM_WORKERS.
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    num_workers = 0,
    shuffle=False,
)

# AMG Mixer

In [None]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 41.6/41.6 KB 1.8 MB/s eta 0:00:00
Installing collected packages: einops
Successfully installed einops-0.6.0


## Axial attention

In [None]:
pip install timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.12-py3-none-any.whl (549 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 549.1/549.1 KB 27.6 MB/s eta 0:00:00
Collecting huggingface-hub
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 182.4/182.4 KB 24.9 MB/s eta 0:00:00
Installing collected packages: huggingface-hub, timm
Successfully installed huggingface-hub-0.11.1 timm-0.6.12


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

In [None]:
class PatchEmbed(nn.Module):
    r"""Project an image onto a grid of patch embeddings with a strided conv.

    Unlike the token-sequence variant, ``forward`` keeps the spatial layout:
    the convolution output is returned without flattening.

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer, built as
            ``norm_layer(embed_dim)``. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        # Number of whole patches that fit along each spatial axis.
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]

        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = [grid_h, grid_w]
        self.num_patches = grid_h * grid_w

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Non-overlapping patches: kernel and stride are both the patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Embed ``x`` of shape (B, C, H, W); H and W must equal ``img_size``."""
        _, _, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        embedded = self.proj(x)
        if self.norm is None:
            return embedded
        return self.norm(embedded)

    def flops(self):
        """Multiply-accumulate estimate for the projection (plus the norm, if any)."""
        grid_h, grid_w = self.patches_resolution
        positions = grid_h * grid_w
        total = positions * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            total += positions * self.embed_dim
        return total

In [None]:
from einops import rearrange, reduce
import math
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data.dataset import Dataset
import matplotlib.pyplot as plt
import pandas as pd
import torchvision
# from src.modules.eprop import eprop
import torch.utils.model_zoo as model_zoo
# from scripts.SEAM.network import resnet38_SEAM, resnet38_aff
from torch.utils.data import DataLoader
from PIL import Image
import os
import numpy as np

# class SeparableConv2d(nn.Module):
#     def __init__(self, inplanes, planes, kernel_size=3, stride=2, dilation=2, bias=False):
#         super().__init__()
#         self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding=dilation, dilation=dilation,bias=bias)
#         self.bn = nn.BatchNorm2d(inplanes)
#         self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
#         self.silu = nn.SiLU(inplace=True)
#     def forward(self, x):
#         x = self.conv1(x)
#         #x = self.silu(x)
#         x = self.bn(x)
#         x = self.pointwise(x)
#         return x

def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` mapping ``in_planes`` to ``out_planes`` channels."""
    pointwise = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise

class qkv_transform(nn.Conv1d):
    """Conv1d for qkv_transform.

    A plain ``nn.Conv1d`` subclass with no overrides; the distinct type lets
    the query/key/value projection be told apart from other ``Conv1d`` layers.
    """


class AxialAttention(nn.Module):
    """Multi-head self-attention along one spatial axis with relative position terms.

    Input and output are ``(N, C, H, W)`` tensors. With ``width=False`` the
    attention runs along dim 2, with ``width=True`` along dim 3. The length of
    the attended axis must equal ``kernel_size`` because the relative-position
    table is built for exactly that length.

    Args:
        in_planes (int): Input channels; must be divisible by ``groups``.
        out_planes (int): Output channels; must be divisible by ``groups``.
        groups (int): Number of attention heads. Default: 8.
        kernel_size (int): Length of the attended axis. Default: 56.
        stride (int): If greater than 1 the result is pooled. Default: 1.
        bias (bool): Stored on the module only; the qkv projection is always
            created without bias. Default: False.
        width (bool): Attend along the last axis instead of dim 2. Default: False.
        sep (bool): When ``stride > 1``, pool with a strided dilated conv
            followed by BatchNorm and a 1x1 conv instead of average pooling.
            Default: False.
    """

    def __init__(self,
                 in_planes,
                 out_planes,
                 groups=8,
                 kernel_size=56,
                 stride=1,
                 bias=False,
                 width=False,
                 sep=False,
                 ):
        assert (in_planes % groups == 0) and (out_planes % groups == 0)
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.groups = groups
        self.group_planes = out_planes // groups
        self.kernel_size = kernel_size
        self.stride = stride
        self.bias = bias
        self.width = width

        # Multi-head self attention. One projection yields q, k and v together;
        # no bias because a BatchNorm follows directly.
        self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,
                                           padding=0, bias=False)
        self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
        # Three similarity maps (qk, qr, kr) per head are normalised jointly.
        self.bn_similarity = nn.BatchNorm2d(groups * 3)

        self.bn_output = nn.BatchNorm1d(out_planes * 2)

        # Position embedding: one vector per relative offset in
        # [-(kernel_size - 1), kernel_size - 1].
        self.relative = nn.Parameter(torch.rand(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)
        query_index = torch.arange(kernel_size).unsqueeze(0)
        key_index = torch.arange(kernel_size).unsqueeze(1)
        relative_index = key_index - query_index + kernel_size - 1
        self.register_buffer('flatten_index', relative_index.view(-1))
        if stride > 1:
            if sep:
                self.pooling = self._separable_pool(out_planes)
            else:
                self.pooling = nn.AvgPool2d(stride, stride=stride)

        self.reset_parameters()

    @staticmethod
    def _separable_pool(planes):
        """Build the down-sampling block used when ``sep=True``.

        The SeparableConv2d class this module used to reference only exists as
        commented-out code, so ``sep=True`` raised NameError. This rebuilds the
        same layers (3x3 conv with stride 2 / dilation 2 -> BatchNorm -> 1x1
        conv) under the same sub-module names.
        """
        pool = nn.Sequential()
        pool.add_module('conv1', nn.Conv2d(planes, planes, 3, 2, padding=2, dilation=2, bias=False))
        pool.add_module('bn', nn.BatchNorm2d(planes))
        pool.add_module('pointwise', nn.Conv2d(planes, planes, 1, 1, 0, 1, 1, bias=False))
        return pool

    def forward(self, x):
        """Apply axial attention to ``x`` of shape (N, C, H, W)."""
        # Move the attended axis last and fold the other spatial axis into the batch.
        if self.width:
            x = x.permute(0, 2, 1, 3)
        else:
            x = x.permute(0, 3, 1, 2)
        # After the permute "H" is the attended length and "W" the folded axis.
        N, W, C, H = x.shape
        x = x.contiguous().view(N * W, C, H)

        # Transformations: per head, q and k get group_planes // 2 channels
        # each and v gets group_planes channels.
        qkv = self.bn_qkv(self.qkv_transform(x))
        q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H),
                              [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)

        # Calculate position embedding: look up the (kernel_size x kernel_size)
        # table of relative-offset embeddings.
        all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(
            self.group_planes * 2, self.kernel_size, self.kernel_size)
        q_embedding, k_embedding, v_embedding = torch.split(
            all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)

        qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
        kr = torch.einsum('bgcj,cij->bgij', k, k_embedding)
        qk = torch.einsum('bgci, bgcj -> bgij', q, k)

        stacked_similarity = torch.cat([qk, qr, kr], dim=1)
        # bn_similarity channel dim = self.groups * 3; the three terms are summed afterwards.
        stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)
        similarity = F.softmax(stacked_similarity, dim=3)

        # Aggregate the values over the key index j. The previous subscripts
        # ('bgij,bgci->bgci') multiplied v[i] by the softmax row sum (== 1),
        # so the content term never mixed information across positions.
        sv = torch.einsum("bgij,bgcj->bgci", similarity, v)
        sve = torch.einsum("bgij,cij->bgci", similarity, v_embedding)
        stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
        # Content and positional outputs are normalised together, then summed.
        output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)

        # Restore the (N, C, H, W) layout.
        if self.width:
            output = output.permute(0, 2, 1, 3)
        else:
            output = output.permute(0, 2, 3, 1)
        if self.stride > 1:
            output = self.pooling(output)
        return output

    def reset_parameters(self):
        """Re-initialise the qkv projection and the relative embeddings (normal init)."""
        self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
        nn.init.normal_(self.relative, 0, math.sqrt(1. / self.group_planes))

class AxialTokenMix(nn.Module):
    """Residual block: 1x1 conv -> height/width axial attention -> MBConv -> BatchNorm.

    Args:
        inplanes (int): Channels of the input tensor.
        planes (int): Base channel count; the internal width is
            ``int(planes * base_width / 64)``.
        stride (int): Passed to the width-axis attention block. Default: 1.
        downsample (callable, optional): Applied to the input to form the
            residual branch; the raw input is used when None.
        groups (int): Attention heads of both axial blocks. Default: 8.
        base_width (int): Width scaling factor (see ``planes``). Default: 64.
        dilation (int): Accepted but not used.
        norm_layer (callable, optional): Normalisation constructor.
            Default: ``nn.BatchNorm2d``.
        kernel_size (int): Axis length expected by the axial blocks. Default: 56.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=8,
                 base_width=64, dilation=1, norm_layer=None, kernel_size=56):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.))

        # Channel reduction ahead of the attention blocks.
        self.conv_down = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)

        # Attention along the height axis, then along the width axis; only the
        # width block receives the stride.
        self.hight_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size)
        self.width_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size,
                                          stride=stride, width=True)

        # NOTE(review): MBConv is not defined in the visible part of the
        # notebook; its positional arguments are forwarded unchanged.
        self.mbconv = MBConv(width, width, 2, 3, 1, True, 0.25)

        self.bn2 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the block on ``x`` and add the (optionally down-sampled) input."""
        out = self.relu(self.bn1(self.conv_down(x)))
        out = self.relu(self.width_block(self.hight_block(out)))
        out = self.bn2(self.mbconv(out))

        # Residual branch is evaluated after the main branch, as before.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


## ConvBlock

In [None]:
from torch.nn.modules.pooling import MaxPool2d
class ConvBlock(nn.Module):
    """Two stacked 3x3 conv -> BatchNorm -> ReLU stages that keep the spatial size.

    Args:
        int: Number of input channels (the name shadows the builtin but is
            kept because it is part of the call signature).
        out: Number of output channels of both stages.
    """

    def __init__(self, int , out):
        super().__init__()
        self.conv1 = self._conv_bn_relu(int, out)
        self.conv2 = self._conv_bn_relu(out, out)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """One 3x3 (stride 1, padding 1) conv followed by BatchNorm and ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply both stages in order."""
        return self.conv2(self.conv1(x))

In [None]:
from torch.nn.modules.pooling import MaxPool2d
class ConvBlockpool(nn.Module):
    """Two 3x3 conv -> BatchNorm -> ReLU stages; the second ends with 2x2 max pooling.

    Args:
        int: Number of input channels (name kept for signature compatibility).
        out: Number of output channels of both stages.
    """

    def __init__(self, int , out):
        super().__init__()
        self.conv1 = nn.Sequential(*self._conv_bn_relu(int, out))
        # Same stage again, followed by the pooling that halves H and W.
        self.conv2 = nn.Sequential(*self._conv_bn_relu(out, out), nn.MaxPool2d(2, 2))

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Layers of one 3x3 (stride 1, padding 1) conv / BatchNorm / ReLU stage."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def forward(self, x):
        """Apply both stages in order."""
        return self.conv2(self.conv1(x))

In [None]:
class feature_conv(nn.Module):
    """Three chained ``ConvBlockpool`` stages (64, 128 and 256 output channels).

    ``forward`` returns the output of every stage so that the intermediate
    feature maps can be reused by the caller.
    """

    def __init__(self, inchannels ):
        super().__init__()
        self.conv1 = ConvBlockpool(inchannels, 64)
        self.conv2 = ConvBlockpool(64, 128)
        self.conv3 = ConvBlockpool(128, 256)

    def forward(self, x):
        """Return the tuple ``(stage1, stage2, stage3)`` of feature maps."""
        features = []
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
            features.append(x)
        return tuple(features)

## PASPP module

In [None]:
import torch.nn as nn
import torchvision
import torch
from skimage import morphology as morph
#from src.modules.eprop import eprop
import torch.utils.model_zoo as model_zoo
#from scripts.SEAM.network import resnet38_SEAM, resnet38_aff
import numpy as np
from torch import optim
import torch.nn.functional as F

import torch
import torch.nn.functional as F
import numpy as np
from skimage.morphology import watershed
from skimage.segmentation import find_boundaries
from scipy import ndimage

class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super().__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                      stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = BatchNorm(planes)
        self.silu = nn.SiLU(inplace=True)
        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.silu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    """Atrous spatial pyramid pooling with three dilated branches and a global-pool branch.

    Args:
        inplanes: Input channels.
        outplanes: Output channels of every branch and of the fused result.
        output_stride: Selects the dilation set; only 4, 8 and 2 are supported.
        BatchNorm: Normalisation constructor.

    Raises:
        NotImplementedError: For any other ``output_stride``.
    """

    def __init__(self, inplanes, outplanes, output_stride, BatchNorm):
        super().__init__()
        supported = (
            (4, [1, 6, 12, 18]),
            (8, [1, 4, 6, 10]),
            (2, [1, 12, 24, 36]),
        )
        for stride_key, rates in supported:
            if output_stride == stride_key:
                dilations = rates
                break
        else:
            raise NotImplementedError

        # The first rate (1x1 branch) is not used; only the three dilated
        # 3x3 branches are built.
        self.aspp2 = _ASPPModule(inplanes, outplanes, 3, padding=dilations[1],
                                 dilation=dilations[1], BatchNorm=BatchNorm)
        self.aspp3 = _ASPPModule(inplanes, outplanes, 3, padding=dilations[2],
                                 dilation=dilations[2], BatchNorm=BatchNorm)
        self.aspp4 = _ASPPModule(inplanes, outplanes, 3, padding=dilations[3],
                                 dilation=dilations[3], BatchNorm=BatchNorm)

        # Image-level branch: global average -> 1x1 conv -> SiLU (no norm).
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, outplanes, 1, stride=1, bias=False),
            nn.SiLU(inplace=True),
        )
        # Fuses the four concatenated branches back to ``outplanes`` channels.
        self.conv1 = nn.Conv2d(outplanes * 4, outplanes, 1, bias=False)
        self.bn1 = BatchNorm(outplanes)
        self.silu = nn.SiLU(inplace=True)
        self.dropout = nn.Dropout(0.0)
        self._init_weight()

    def forward(self, x):
        """Run all branches on ``x``, concatenate along channels and fuse."""
        branches = [self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        pooled = self.global_avg_pool(x)
        # Broadcast the 1x1 pooled map back to the branch resolution.
        pooled = F.interpolate(pooled, size=branches[-1].size()[2:], mode='bilinear', align_corners=True)
        fused = torch.cat((*branches, pooled), dim=1)
        fused = self.silu(self.bn1(self.conv1(fused)))
        return self.dropout(fused)

    def _init_weight(self):
        """Kaiming-normal init for convs; weight 1 / bias 0 for ``nn.BatchNorm2d``."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

class PASPP(nn.Module):
    """Progressive atrous spatial pyramid pooling.

    The input is projected into four quarter-width branches, paired (1+2,
    3+4), refined with dilated 3x3 convs, fused per pair with 1x1 convs and
    finally fused into a map with ``inplanes`` channels.

    Args:
        inplanes (int): Input channels; should be divisible by 4.
        outplanes (int): Accepted but not used; the output has ``inplanes``
            channels.
        output_stride (int): Selects the dilation set (1, 2, 4, 8 or 16).
            Default: 4.
        BatchNorm: Normalisation constructor. Default: ``nn.BatchNorm2d``.

    Raises:
        NotImplementedError: For an unsupported ``output_stride``.
    """

    def __init__(self, inplanes, outplanes, output_stride=4, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        if output_stride == 4:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 4, 6, 10]
        elif output_stride == 2:
            dilations = [1, 12, 24, 36]
        elif output_stride == 16:
            dilations = [1, 2, 3, 4]
        elif output_stride == 1:
            dilations = [1, 16, 32, 48]
        else:
            raise NotImplementedError
        self._norm_layer = BatchNorm
        # One activation module shared by every 1x1 projection built below.
        self.silu = nn.SiLU(inplace=True)
        quarter = inplanes // 4
        half = inplanes // 2
        self.conv1 = self._make_layer(inplanes, quarter)
        self.conv2 = self._make_layer(inplanes, quarter)
        self.conv3 = self._make_layer(inplanes, quarter)
        self.conv4 = self._make_layer(inplanes, quarter)
        # padding == dilation keeps the spatial size for a 3x3 kernel.
        self.atrous_conv1 = nn.Conv2d(quarter, quarter, kernel_size=3, dilation=dilations[0], padding=dilations[0])
        self.atrous_conv2 = nn.Conv2d(quarter, quarter, kernel_size=3, dilation=dilations[1], padding=dilations[1])
        self.atrous_conv3 = nn.Conv2d(quarter, quarter, kernel_size=3, dilation=dilations[2], padding=dilations[2])
        self.atrous_conv4 = nn.Conv2d(quarter, quarter, kernel_size=3, dilation=dilations[3], padding=dilations[3])
        self.conv5 = self._make_layer(half, half)
        self.conv6 = self._make_layer(half, half)
        self.convout = self._make_layer(inplanes, inplanes)

    def _make_layer(self, inplanes, outplanes):
        """Build a 1x1 conv -> norm -> SiLU projection."""
        layer = []
        layer.append(nn.Conv2d(inplanes, outplanes, kernel_size = 1))
        layer.append(self._norm_layer(outplanes))
        layer.append(self.silu)
        return nn.Sequential(*layer)

    def forward(self, X):
        """Return a feature map with the same shape as ``X``."""
        x1 = self.conv1(X)
        x2 = self.conv2(X)
        x3 = self.conv3(X)
        x4 = self.conv4(X)

        # Each pair shares a summed context that is added to both members.
        x12 = torch.add(x1, x2)
        x34 = torch.add(x3, x4)

        x1 = torch.add(self.atrous_conv1(x1), x12)
        x2 = torch.add(self.atrous_conv2(x2), x12)
        x3 = torch.add(self.atrous_conv3(x3), x34)
        x4 = torch.add(self.atrous_conv4(x4), x34)

        x12 = torch.cat([x1, x2], dim = 1)
        x34 = torch.cat([x3, x4], dim = 1)

        x12 = self.conv5(x12)
        # The second pair has its own fusion layer; it previously went through
        # conv5 again, leaving conv6 unused and untrained.
        x34 = self.conv6(x34)
        x = torch.cat([x12, x34], dim=1)
        x = self.convout(x)
        return x

## Skip Connection

In [None]:
import torch
import torch.nn as nn

'''pixel-level module'''


class PixLevelModule(nn.Module):
    """Pixel-level gating of a token sequence reshaped to a square feature map.

    A single-channel gate is derived from channel-wise mean and max
    statistics (and their sum) and multiplied onto the feature map.

    Args:
        in_channels: Channel count of the tokens / feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.middle_layer_size_ratio = 2
        hidden = 3 * self.middle_layer_size_ratio

        self.conv_avg = nn.Conv2d(in_channels, out_channels=in_channels, kernel_size=1, bias=False)
        self.relu_avg = nn.ReLU(inplace=True)
        self.conv_max = nn.Conv2d(in_channels, out_channels=in_channels, kernel_size=1, bias=False)
        self.relu_max = nn.ReLU(inplace=True)
        # Maps the three per-pixel statistics (avg, max, avg + max) to one gate value.
        self.bottleneck = nn.Sequential(
            nn.Linear(3, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 1),
        )
        # Built but not used by ``forward``.
        self.conv_sig = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    '''forward'''

    def forward(self, x, patch):
        """Gate tokens ``x`` of shape (b, patch * patch, c); returns (b, c, patch, patch)."""
        feat = rearrange(x, 'b (p1 p2) c -> b c p1 p2', p1 = patch, p2 = patch)

        # Channel-wise mean of one projection and channel-wise max of another,
        # each kept as a single-channel map.
        avg_map = self.relu_avg(self.conv_avg(feat)).mean(dim=1, keepdim=True)
        max_map = self.relu_max(self.conv_max(feat)).max(dim=1, keepdim=True).values

        stats = torch.cat((avg_map, max_map, max_map + avg_map), dim=1)
        # Put the three statistics on the last axis for the Linear layers,
        # then restore the channel-first layout.
        gate = self.bottleneck(stats.transpose(1, 3)).transpose(1, 3)
        return gate * feat

## Decoder Upsample

In [None]:
class DecoderUNit(nn.Module):
    """Decoder stage: optional skip fusion, upsampling to a fixed size, 3x3 conv.

    Args:
        inchannels: Two-element sequence. ``inchannels[0]`` is the channel
            count of the skip-fusion conv; ``inchannels[1]`` is the channel
            count entering the final conv (after the optional concatenation).
        outchannels: Output channels of the final conv.
        size: Target spatial size handed to ``nn.Upsample``.
    """

    def __init__(self, inchannels, outchannels, size):
        super().__init__()
        self.up = nn.Upsample(size = size)
        # Refinement applied to the sum of decoder input and skip feature.
        self.conv1 = nn.Sequential(
            nn.Conv2d(inchannels[0], inchannels[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(inchannels[0]),
            nn.ReLU(),
        )
        # Final projection after upsampling.
        self.conv2 = nn.Sequential(
            nn.Conv2d(inchannels[1], outchannels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(outchannels),
            nn.ReLU(),
        )

    def forward(self, x, en = None, patch = None):
        """Decode ``x``; ``en`` is an optional skip feature, ``patch`` is unused."""
        if en is None:
            return self.conv2(self.up(x))

        # Residual refinement of (x + en), then concatenate the skip feature
        # once more along the channel axis.
        fused = x + en
        shortcut = fused.clone()
        fused = self.conv1(fused) + shortcut
        fused = torch.cat([fused, en], dim=1)
        return self.conv2(self.up(fused))

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Groups every 2x2 neighbourhood of tokens into one token with ``4 * dim``
    channels, normalises it and projects it to ``2 * dim`` channels.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C  ->  B, H/2*W/2, 2*C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)

        # (row, col) start offsets of the four members of each 2x2 block, in
        # the order they are concatenated along the channel axis.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        merged = torch.cat([grid[:, r::2, c::2, :] for r, c in offsets], -1)  # B H/2 W/2 4*C
        merged = merged.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))

## Multi axis attention



In [None]:
import torch.nn as nn
import torch
import numpy as np
import einops

def block_images_einops(x, patch_size):
  """Split a channels-last image batch into non-overlapping patches.

  Args:
    x: Array of rank 4 laid out as (n, height, width, c).
    patch_size: ``(fh, fw)`` patch height and width.

  Returns:
    Array of shape (n, gh * gw, fh * fw, c) with ``gh = height // fh`` and
    ``gw = width // fw``.
  """
  _, height, width, _ = x.shape
  patch_h, patch_w = patch_size[0], patch_size[1]
  return einops.rearrange(
      x, "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
      gh=height // patch_h, gw=width // patch_w, fh=patch_h, fw=patch_w)


def unblock_images_einops(x, grid_size, patch_size):
  """Reassemble patches into a channels-last image batch.

  Args:
    x: Array laid out as (n, gh * gw, fh * fw, c).
    grid_size: ``(gh, gw)`` number of patches per axis.
    patch_size: ``(fh, fw)`` patch height and width.

  Returns:
    Array of shape (n, gh * fh, gw * fw, c).
  """
  grid_h, grid_w = grid_size[0], grid_size[1]
  patch_h, patch_w = patch_size[0], patch_size[1]
  return einops.rearrange(
      x, "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c",
      gh=grid_h, gw=grid_w, fh=patch_h, fw=patch_w)


# MFI
class GetSpatialGatingWeights_2D_Multi_Scale_Cascade_Grid(nn.Module):
    """Get gating weights for cross-gating MLP block.

    Three cascaded stages are applied to a channels-last map. Each stage mixes
    information across the grid of patches with one Linear layer and then
    across the positions inside each block with another.

    Args:
        nIn: Channels of the input (size of the LayerNorm).
        Nout: Channels produced by the final Linear layer.
        H_size, W_size: Spatial size used to derive the per-stage block sizes.
        input_proj_factor: Stored only.
        dropout_rate: Dropout probability applied to the output.
        use_bias: Whether the grid / block Linear layers use a bias.
        train_size: 512 selects grids 8x8, 4x4, 2x2; any other value selects
            6x6, 3x3, 2x2.
    """
    def __init__(self,nIn:int,Nout:int,H_size:int=128,W_size:int=128,input_proj_factor:int=2,dropout_rate:float=0.0,use_bias:bool=True,train_size:int=512):
        super().__init__()

        self.H = H_size
        self.W = W_size
        self.IN = nIn
        self.OUT = Nout
        if train_size == 512:
            self.grid_size = [[8, 8], [4, 4], [2, 2]]
        else:
            self.grid_size = [[6, 6], [3, 3], [2, 2]]

        # Block size that pairs with each grid size for an H_size x W_size map.
        self.block_size = [[int(H_size / l[0]), int(W_size / l[1])] for l in self.grid_size]
        self.input_proj_factor = input_proj_factor
        self.dropout_rate = dropout_rate
        self.use_bias = use_bias
        self.dropout = nn.Dropout(self.dropout_rate)
        self.LayerNorm = nn.LayerNorm(self.IN)
        self.Linear_end = nn.Linear(self.IN,self.OUT)
        self.Gelu = nn.GELU()

        def square_linear(extent):
            # Linear layer mixing the extent[0] * extent[1] flattened positions.
            count = extent[0] * extent[1]
            return nn.Linear(count, count, bias=use_bias)

        # Layers are created in grid/block order per stage.
        self.Linear_grid_MLP_1 = square_linear(self.grid_size[0])
        self.Linear_Block_MLP_1 = square_linear(self.block_size[0])
        self.Linear_grid_MLP_2 = square_linear(self.grid_size[1])
        self.Linear_Block_MLP_2 = square_linear(self.block_size[1])
        self.Linear_grid_MLP_3 = square_linear(self.grid_size[2])
        self.Linear_Block_MLP_3 = square_linear(self.block_size[2])

    def forward(self, x):
        """Compute gating weights for a channels-last input of shape (n, h, w, c)."""
        n, h, w, num_channels = x.shape

        feat = self.Gelu(self.LayerNorm(x.float()))

        stages = (
            (self.Linear_grid_MLP_1, self.Linear_Block_MLP_1),
            (self.Linear_grid_MLP_2, self.Linear_Block_MLP_2),
            (self.Linear_grid_MLP_3, self.Linear_Block_MLP_3),
        )
        for idx, (grid_mlp, block_mlp) in enumerate(stages):
            # Grid mixing: the patch size follows from the actual input size;
            # the Linear layer acts on the (gh * gw) axis, moved last.
            gh, gw = self.grid_size[idx]
            fh, fw = h // gh, w // gw
            tokens = block_images_einops(feat, patch_size=(fh, fw))
            tokens = grid_mlp(tokens.permute(0, 3, 2, 1)).permute(0, 3, 2, 1)
            feat = unblock_images_einops(tokens, grid_size=(gh, gw), patch_size=(fh, fw))

            # Block mixing: the block size is the configured one; the Linear
            # layer acts on the (fh * fw) axis, moved last.
            fh, fw = self.block_size[idx]
            gh, gw = h // fh, w // fw
            tokens = block_images_einops(feat, patch_size=(fh, fw))
            tokens = block_mlp(tokens.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
            feat = unblock_images_einops(tokens, grid_size=(gh, gw), patch_size=(fh, fw))

        return self.dropout(self.Linear_end(feat))


class conv_T_y_2_x(nn.Module):
    """ Unified y Dimensional to x

    Upsamples ``y`` with a stride-2 transposed conv to ``x_nOut`` channels and
    then resizes it bilinearly to the spatial size of ``x``.

    Args:
        y_nIn: Channels of ``y``.
        x_nOut: Channels of the result.
    """
    def __init__(self,y_nIn,x_nOut):
        super().__init__()
        self.x_c = x_nOut
        self.y_c = y_nIn
        self.convT = nn.ConvTranspose2d(
            in_channels=self.y_c,
            out_channels=self.x_c,
            kernel_size=(3, 3),
            stride=(2, 2),
        )

    def forward(self,x,y):
        """Return ``y`` mapped to the channel count and spatial size of ``x``."""
        upsampled = self.convT(y)
        # Only the spatial size of x is used; x must have rank 4.
        _, _, target_h, target_w = x.shape
        return nn.functional.interpolate(
            upsampled, size=(target_h, target_w), mode='bilinear', align_corners=True)



class CrossGatingBlock(nn.Module):
    """Cross-gating MLP block, currently executed with a single input.

    The constructor still creates both the ``x`` and the ``y`` branch, but
    ``forward`` only accepts ``x``; the cross-gating path (``y`` input,
    ``x * gy`` / ``y * gx``) is commented out. What ``forward`` effectively
    computes is::

        x -> 1x1 conv -> [LayerNorm -> Linear -> GELU -> Linear -> Dropout] + residual
          -> BN/SiLU/depthwise-3x3/Dropout/BN/SiLU/1x1 conv

    Args:
        x_in: channel count of the input ``x``.
        y_in: channel count of the (currently unused) gating input ``y``.
        out_features: channel count produced by the 1x1 projection and kept
            through the rest of the block.
        patch_size: ``(H, W)`` forwarded to the spatial gating modules as
            ``H_size`` / ``W_size``.
        block_size, grid_size: stored on the instance; not read in this class.
        dropout_rate: dropout probability for ``dropout_x`` / ``dropout_y`` and
            for the spatial gating modules.
        input_proj_factor, upsample_y: stored on the instance only; the gating
            modules are built with a literal ``input_proj_factor=2``.
        use_bias: bias flag for the ``Linear`` layers of this block (the gating
            modules are built with a literal ``use_bias=True``).
        train_size: forwarded to the spatial gating modules.
    """
    def __init__(self,x_in:int,y_in:int,out_features:int,patch_size:[int,int],block_size:[int,int],grid_size:[int,int],dropout_rate:float=0.0,input_proj_factor:int=2,upsample_y:bool=True,use_bias:bool=True, train_size:int=512):
        super(CrossGatingBlock, self).__init__()
        # Output head applied in channels-first layout at the end of forward():
        # BN -> SiLU -> depthwise 3x3 -> Dropout -> BN -> SiLU -> pointwise 1x1.
        self.conv = nn.Sequential(
            nn.BatchNorm2d(out_features),
            nn.SiLU(inplace=True),
            nn.Conv2d(out_features, out_features, kernel_size=3, stride=1, padding=1, groups = out_features, bias=False),
            nn.Dropout(p=0.1), # author's note (translated): drop this Dropout when saving/loading
            nn.BatchNorm2d(out_features),
            nn.SiLU(inplace=True),
            nn.Conv2d(out_features, out_features, kernel_size=1)
        )
        self.IN_x = x_in
        self.IN_y = y_in
        self._h = patch_size[0]
        self._w = patch_size[1]
        self.features = out_features
        self.block_size=block_size
        self.grid_size = grid_size
        self.dropout_rate = dropout_rate
        self.input_proj_factor = input_proj_factor
        self.upsample_y = upsample_y
        self.use_bias = use_bias
        self.Conv1X1_x = nn.Conv2d(self.IN_x,self.features,(1,1))
        # NOTE(review): built with IN_x input channels, not IN_y. That matches a y
        # that went through ConvT (IN_y -> IN_x) first; with upsample_y=False it
        # would only fit when x_in == y_in. Unused by forward() today.
        self.Conv1X1_y = nn.Conv2d(self.IN_x,self.features,(1,1))
        # LayerNorm / Linear act on the last dimension, i.e. on channels once the
        # tensor has been permuted to channels-last in forward().
        self.LayerNorm_x = nn.LayerNorm(self.features)
        self.LayerNorm_y = nn.LayerNorm(self.features)
        self.Linear_x = nn.Linear(self.features,self.features,bias=use_bias)
        self.Linear_y = nn.Linear(self.features,self.features,bias=use_bias)
        self.Gelu_x = nn.GELU()
        self.Gelu_y = nn.GELU()
        self.Linear_x_end = nn.Linear(self.features,self.features,bias=use_bias)
        self.Linear_y_end = nn.Linear(self.features,self.features,bias=use_bias)
        self.dropout_x = nn.Dropout(self.dropout_rate)
        self.dropout_y = nn.Dropout(self.dropout_rate)

        # Modules of the disabled y branch are still registered (and therefore
        # still appear in the state_dict) even though forward() does not call them.
        self.ConvT = conv_T_y_2_x(self.IN_y,self.IN_x)
        self.fun_gx = GetSpatialGatingWeights_2D_Multi_Scale_Cascade_Grid(nIn=self.features, Nout=self.features, H_size=self._h, W_size=self._w,
                                                 input_proj_factor=2, dropout_rate=dropout_rate, use_bias=True, train_size=train_size)

        self.fun_gy = GetSpatialGatingWeights_2D_Multi_Scale_Cascade_Grid(nIn=self.features, Nout=self.features, H_size=self._h, W_size=self._w,
                                                 input_proj_factor=2, dropout_rate=dropout_rate, use_bias=True, train_size=train_size)

    def forward(self, x):
        """Refine a channels-first 4-D feature map ``x``.

        Returns a tensor in the same (n, c, h, w) layout with ``out_features``
        channels.
        """
        # Disabled two-input path (kept for reference): y would be reshaped,
        # optionally upsampled with self.ConvT(x, y) and projected with Conv1X1_y.
        # x = rearrange(x, 'b (p1 p2) c -> b c p1 p2', p1 = patch1, p2 = patch1)
        # y = rearrange(y, 'b (p1 p2) c -> b c p1 p2', p1 = patch2, p2 = patch2)
        # if self.upsample_y:
        #     y = self.ConvT(x,y)

        x = self.Conv1X1_x(x)
        # y = self.Conv1X1_y(y)
        # assert y.shape == x.shape
        x = x.permute(0, 2, 3, 1)  # n x h x w x c
        # y = y.permute(0, 2, 3, 1)
        # Residual is taken after the 1x1 projection, in channels-last layout.
        shortcut_x = x
        # shortcut_y = y
        # Get gating weights from X
        x = self.LayerNorm_x(x)
        x = self.Linear_x(x)
        x = self.Gelu_x(x)

        # NOTE(review): gx is computed but never consumed below (the `y * gx`
        # line is commented out), so this call only costs time/memory.
        gx = self.fun_gx(x)
        # n x h x w x c
        # Get gating weights from Y (disabled)
        # y = self.LayerNorm_y(y)
        # y = self.Linear_y(y)
        # y = self.Gelu_y(y)

        # gy = self.fun_gy(y)

        # y = y * gx
        # y = self.Linear_y_end(y)
        # y = self.dropout_y(y)
        # y = y + shortcut_y
        # x = x * gy  # gating x using y
        # NOTE(review): the x branch is projected with Linear_y_end, leaving
        # Linear_x_end unused; confirm whether that is intended.
        x = self.Linear_y_end(x)
        x = self.dropout_x(x)
        x = x  + shortcut_x  
        x = x.permute(0, 3, 1, 2)  # n x h x w x c --> n x c x h x w
        # y = y.permute(0, 3, 1, 2)
        # logit = torch.cat([x,y], dim = 1)
        x = self.conv(x)
        return x


In [None]:
import torch
import torch.nn as nn

'''pixel-level module'''


class SkipConnection(nn.Module):
    """Fuse two token sequences taken at neighbouring resolutions.

    Both sequences go through the same ``Linear -> GELU`` layer, are reshaped
    into square feature maps, the second one is upsampled by a factor of two,
    and the channel-wise concatenation is reduced back to ``in_channels`` by
    two ``Conv3x3 -> ReLU -> BatchNorm`` stages.
    """

    def __init__(self, in_channels):
        super().__init__()
        # One MLP shared by both inputs.
        self.mlp = nn.Sequential(nn.Linear(in_channels, in_channels), nn.GELU())
        self.up = nn.Upsample(scale_factor=2)

        fuse_layers = []
        # First stage consumes the concatenated (2 * in_channels) map.
        for conv_in in (in_channels * 2, in_channels):
            fuse_layers += [
                nn.Conv2d(conv_in, in_channels, kernel_size=3, padding=1, stride=1),
                nn.ReLU(),
                nn.BatchNorm2d(in_channels),
            ]
        self.conv = nn.Sequential(*fuse_layers)

    def _tokens_to_map(self, tokens, side):
        """Apply the shared MLP and fold ``(b, side*side, c)`` into ``(b, c, side, side)``."""
        return rearrange(self.mlp(tokens), 'b (p1 p2) c -> b c p1 p2', p1=side, p2=side)

    def forward(self, em1, em2, patch1, patch2):
        """Fuse ``em1`` (grid side ``patch1``) with ``em2`` (grid side ``patch2``).

        After the 2x upsampling of ``em2`` both maps must share the same
        spatial size for the concatenation to succeed.
        """
        fine = self._tokens_to_map(em1, patch1)
        coarse = self.up(self._tokens_to_map(em2, patch2))
        return self.conv(torch.cat([fine, coarse], dim=1))

In [None]:
import math

import torch
import torch.nn as nn

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn_act(in_, out_, kernel_size,
                stride=1, groups=1, bias=True,
                eps=1e-3, momentum=0.01):
    """Build a ``SamePadConv2d -> BatchNorm2d -> Swish`` stack.

    Args:
        in_, out_: input / output channel counts of the convolution.
        kernel_size, stride, groups, bias: forwarded to ``SamePadConv2d``.
        eps, momentum: forwarded to ``nn.BatchNorm2d``.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    conv = SamePadConv2d(in_, out_, kernel_size, stride, groups=groups, bias=bias)
    norm = nn.BatchNorm2d(out_, eps=eps, momentum=momentum)
    return nn.Sequential(conv, norm, Swish())


class SamePadConv2d(nn.Conv2d):
    """
    Conv with TF padding='same'
    https://github.com/pytorch/pytorch/issues/3867#issuecomment-349279036

    The output spatial size is ``ceil(input_size / stride)``. When the total
    padding required along an axis is odd, the extra row/column is added at the
    bottom/right before the (symmetric) convolution padding is applied.

    Note: ``padding_mode`` is forwarded to ``nn.Conv2d.__init__`` but
    ``forward`` always pads with zeros.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True, padding_mode="zeros"):
        # Built-in padding is fixed to 0; the real padding is computed per input in forward().
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias, padding_mode)

    def get_pad_odd(self, in_, weight, stride, dilation):
        """Return ``(total_padding, is_odd)`` for one spatial axis.

        Args:
            in_: input size along the axis.
            weight: kernel size along the axis.
            stride, dilation: convolution stride / dilation along the axis.
        """
        effective_filter_size = (weight - 1) * dilation + 1
        out_size = (in_ + stride - 1) // stride  # ceil(in_ / stride)
        padding = max(0, (out_size - 1) * stride + effective_filter_size - in_)
        return padding, padding % 2 != 0

    def forward(self, x):
        padding_rows, rows_odd = self.get_pad_odd(x.shape[2], self.weight.shape[2], self.stride[0], self.dilation[0])
        padding_cols, cols_odd = self.get_pad_odd(x.shape[3], self.weight.shape[3], self.stride[1], self.dilation[1])

        if rows_odd or cols_odd:
            # F.pad order is (left, right, top, bottom): add the odd unit right/bottom.
            x = F.pad(x, [0, int(cols_odd), 0, int(rows_odd)])

        return F.conv2d(x, self.weight, self.bias, self.stride,
                        padding=(padding_rows // 2, padding_cols // 2),
                        dilation=self.dilation, groups=self.groups)


class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)`` applied element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * x


class Flatten(nn.Module):
    """Collapse every dimension after the first into one (uses ``view``)."""

    def forward(self, x):
        leading = x.size(0)
        return x.view(leading, -1)


class SEModule(nn.Module):
    """Squeeze-and-excitation gate.

    Global average pooling, a 1x1 conv down to ``squeeze_ch`` channels, Swish,
    and a 1x1 conv back to ``in_`` channels produce per-channel logits; the
    input is multiplied by their sigmoid.
    """

    def __init__(self, in_, squeeze_ch):
        super().__init__()
        squeeze = nn.Conv2d(in_, squeeze_ch, kernel_size=1, stride=1, padding=0, bias=True)
        excite = nn.Conv2d(squeeze_ch, in_, kernel_size=1, stride=1, padding=0, bias=True)
        self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1), squeeze, Swish(), excite)

    def forward(self, x):
        gate = torch.sigmoid(self.se(x))
        return x * gate


class DropConnect(nn.Module):
    """Randomly zero whole samples of a batch during training.

    ``ratio`` is the drop probability; ``self.ratio`` stores the complementary
    keep probability. Surviving samples are divided by the keep probability.
    In eval mode the input is returned untouched.
    """

    def __init__(self, ratio):
        super().__init__()
        self.ratio = 1.0 - ratio

    def forward(self, x):
        if self.training:
            keep_prob = self.ratio
            # One uniform draw per sample, broadcast over the remaining dims.
            noise = torch.rand([x.shape[0], 1, 1, 1], dtype=torch.float, device=x.device)
            mask = keep_prob + noise
            mask.requires_grad_(False)
            # floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob, else 0.
            return x / keep_prob * mask.floor()
        return x

class MBConv(nn.Module):
    """Mobile inverted bottleneck block (expand -> depthwise -> SE -> project).

    Args:
        in_, out_: input / output channel counts.
        expand: channel expansion factor; ``1`` skips the expansion conv.
        kernel_size, stride: settings of the depthwise convolution.
        skip: request a residual connection; it is only used when ``stride == 1``
            and ``in_ == out_``.
        se_ratio: squeeze channels are ``int(in_ * se_ratio)``; values ``<= 0``
            disable the squeeze-and-excitation module.
        dc_ratio: accepted but unused, drop-connect is replaced by ``nn.Identity``.
    """

    def __init__(self, in_, out_, expand,
                 kernel_size, stride, skip,
                 se_ratio, dc_ratio=0.2):
        super().__init__()
        mid_ = in_ * expand

        if expand != 1:
            self.expand_conv = conv_bn_act(in_, mid_, kernel_size=1, bias=False)
        else:
            self.expand_conv = nn.Identity()

        self.depth_wise_conv = conv_bn_act(
            mid_, mid_,
            kernel_size=kernel_size, stride=stride,
            groups=mid_, bias=False,
        )

        if se_ratio > 0:
            self.se = SEModule(mid_, int(in_ * se_ratio))
        else:
            self.se = nn.Identity()

        pointwise = SamePadConv2d(mid_, out_, kernel_size=1, stride=1, bias=False)
        self.project_conv = nn.Sequential(pointwise, nn.BatchNorm2d(out_, 1e-3, 0.01))

        # Residual only when the block keeps both resolution and channel count.
        self.skip = skip and (stride == 1) and (in_ == out_)

        # Drop-connect is disabled, following the original TF implementation:
        # https://github.com/tensorflow/tpu/blob/05f7b15cdf0ae36bac84beb4aef0a09983ce8f66/models/official/efficientnet/efficientnet_model.py#L408
        self.dropconnect = nn.Identity()

    def forward(self, inputs):
        features = self.expand_conv(inputs)
        features = self.depth_wise_conv(features)
        features = self.se(features)
        features = self.project_conv(features)
        if not self.skip:
            return features
        return self.dropconnect(features) + inputs


## Build Model

In [None]:
from torch import nn
from torch.nn import functional as F


class MLP(nn.Module):
    """Two-layer perceptron acting on the last dimension.

    ``Linear(num_features -> num_features * expansion_factor) -> GELU -> Dropout
    -> Linear(back to num_features) -> Dropout``.
    """

    def __init__(self, num_features, expansion_factor, dropout):
        super().__init__()
        hidden_width = num_features * expansion_factor
        self.fc1 = nn.Linear(num_features, hidden_width)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_width, num_features)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.gelu(self.fc1(x))
        hidden = self.dropout1(hidden)
        projected = self.fc2(hidden)
        return self.dropout2(projected)


class TokenMixer(nn.Module):
    """Token-mixing sub-layer: LayerNorm, axial mixing on the 2-D grid, residual.

    Note that ``num_patches`` here is the number of patches *per side* of the
    square token grid (``MixerLayer`` passes ``sqrt_num_patches``). ``self.mlp``
    is constructed but not called in ``forward``.
    """

    def __init__(self, num_features, num_patches, expansion_factor, dropout):
        super().__init__()
        self.num_patches = num_patches
        self.norm = nn.LayerNorm(num_features)
        self.mlp = MLP(num_patches, expansion_factor, dropout)
        self.axialatt = AxialTokenMix(num_features, num_features, kernel_size = num_patches)

    def forward(self, x):
        # x.shape == (batch_size, side * side, num_features)
        side = self.num_patches
        grid = rearrange(self.norm(x), 'b (p1 p2) c -> b c p1 p2', p1=side, p2=side)
        grid = self.axialatt(grid)
        tokens = rearrange(grid, 'b c p1 p2 -> b (p1 p2) c', p1=side, p2=side)
        # Residual connection around the whole sub-layer.
        return tokens + x


class ChannelMixer(nn.Module):
    """Channel-mixing sub-layer: ``x + MLP(LayerNorm(x))`` over the feature axis.

    ``num_patches`` is accepted for signature symmetry with ``TokenMixer`` and
    is not used.
    """

    def __init__(self, num_features, num_patches, expansion_factor, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(num_features)
        self.mlp = MLP(num_features, expansion_factor, dropout)

    def forward(self, x):
        # Input and output are both (batch_size, num_patches, num_features).
        mixed = self.mlp(self.norm(x))
        return mixed + x


class MixerLayer(nn.Module):
    """One mixer layer: token mixing followed by channel mixing.

    ``TokenMixer`` receives ``sqrt_num_patches`` (grid side length), while
    ``ChannelMixer`` receives ``num_patches``.
    """

    def __init__(self, num_features, num_patches, sqrt_num_patches, expansion_factor, dropout):
        super().__init__()
        self.token_mixer = TokenMixer(num_features, sqrt_num_patches, expansion_factor, dropout)
        self.channel_mixer = ChannelMixer(num_features, num_patches, expansion_factor, dropout)

    def forward(self, x):
        # Shape is preserved: (batch_size, num_patches, num_features).
        return self.channel_mixer(self.token_mixer(x))


def check_sizes(image_size, patch_size):
    """Return ``(patches_per_side, total_patches)`` for a square image.

    Fails an assertion when ``image_size`` is not a multiple of ``patch_size``.
    """
    assert image_size % patch_size == 0, "`image_size` must be divisibe by `patch_size`"
    per_side = image_size // patch_size
    return per_side, per_side ** 2

class MLPMixer(nn.Module):
    """Hierarchical mixer encoder with a skip-connected decoder for segmentation.

    The image is cut into ``patch_size`` x ``patch_size`` patches by a strided
    convolution, then processed by four mixer stages. Between stages a
    ``PatchMerging`` module is applied; the stages are built for a token grid
    whose side halves and whose feature width doubles each time. The deepest
    features go through ``PASPP``; the three shallower stage outputs are refined
    by ``CrossGatingBlock`` skips and combined in four ``DecoderUNit`` steps,
    followed by a 2x upsample and a final ``ConvBlock`` to ``num_classes``.

    Args:
        image_size: side length of the (square) input image; must be divisible
            by ``patch_size``.
        patch_size: side length of one patch.
        in_channels: channel count of the input image.
        num_features: embedding width of the first stage.
        expansion_factor: hidden-width multiplier of every mixer MLP.
        num_layers: number of ``MixerLayer`` blocks in each of the four stages.
        num_classes: number of output channels.
        dropout: dropout probability inside the mixer MLPs.

    NOTE(review): ``PatchMerging``, ``PASPP``, the ``CrossGatingBlock`` skips,
    the ``DecoderUNit`` modules and the sizes passed in ``forward`` use literal
    resolutions/widths (64, 128, 256, 512, ...). They line up with the defaults
    (256 / 4 = 64 patches per side, ``num_features=64``); other settings
    presumably need those literals changed as well - TODO confirm.
    """

    def __init__(
        self,
        image_size=256,
        patch_size=4,
        in_channels=1,
        num_features=64,
        expansion_factor=2,
        num_layers = [2,2,6,2],
        num_classes=2,
        dropout=0.1
    ):
        # nn.Module must be initialised before attributes are assigned on it.
        super().__init__()
        self.sqrt_num_patches, self.num_patches = check_sizes(image_size, patch_size)
        self.patch_size = patch_size

        def make_stage(level):
            """Mixer stack for stage `level`: grid side / 2**level, width * 2**level."""
            return nn.Sequential(
                *[
                    MixerLayer(
                        num_features * 2 ** level,
                        self.num_patches // 4 ** level,
                        self.sqrt_num_patches // 2 ** level,
                        expansion_factor,
                        dropout,
                    )
                    for _ in range(num_layers[level])
                ]
            )

        # per-patch fully-connected is equivalent to strided conv2d
        self.patcher = nn.Conv2d(
            in_channels, num_features, kernel_size=patch_size, stride=patch_size
        )
        # Registration order is kept as-is so existing checkpoints and seeded
        # initialisation stay compatible.
        self.mixers0 = make_stage(0)
        self.patch_merge0 = PatchMerging([64, 64], 64)
        self.mixers1 = make_stage(1)
        self.patch_merge1 = PatchMerging([32, 32], 128)
        self.mixers2 = make_stage(2)
        self.patch_merge2 = PatchMerging([16, 16], 256)
        # Stage 3 works on a grid of side sqrt/8, i.e. num_patches // 64 tokens.
        self.mixers3 = make_stage(3)
        self.paspp = PASPP(512,512, output_stride = 4)
        # self.skip4 = CrossGatingBlock(512, 512, 512, [8, 8], [8, 8], [4, 4], 0.1, upsample_y=False, train_size=512)
        self.skip3 = CrossGatingBlock(256, 256, 256, [16, 16], [8, 8], [4, 4], 0.1, upsample_y=False, train_size=512)
        self.skip2 = CrossGatingBlock(128, 128, 128, [32, 32], [8, 8], [4, 4], 0.1, upsample_y=False, train_size=512)
        self.skip1 = CrossGatingBlock(64, 64, 64,[64, 64], [8, 8], [4, 4], 0.1, upsample_y=False, train_size=512)
        self.conv_last = nn.Sequential(
            nn.Upsample(scale_factor=2),
            ConvBlock(32, num_classes)
        )
        self.du1 = DecoderUNit([512, 512], 256, 16)
        self.du2 = DecoderUNit([256, 512], 128, 32)
        self.du3 = DecoderUNit([128, 256 ], 64, 64)
        self.du4 = DecoderUNit([64, 128 ], 32, 128)

    def forward(self, x):
        """Map an image batch ``(batch, in_channels, H, W)`` to ``conv_last`` output."""
        patches = self.patcher(x)
        batch_size, num_features, _ , _ = patches.shape
        # (b, c, h, w) -> (b, h*w, c): one token per patch.
        patches = patches.permute(0, 2, 3, 1)
        patches = patches.view(batch_size, -1, num_features)

        # Encoder: each stage output is kept for the skip connections.
        # stage 1 (64 x 64 grid with the defaults)
        embedding0 = self.mixers0(patches)
        embedding1 = self.patch_merge0(embedding0)
        # stage 2 (32 x 32)
        embedding1 = self.mixers1(embedding1)
        embedding2 = self.patch_merge1(embedding1)
        # stage 3 (16 x 16)
        embedding2 = self.mixers2(embedding2)
        embedding3 = self.patch_merge2(embedding2)
        # stage 4 (8 x 8)
        embedding = self.mixers3(embedding3)

        # Fold every token sequence back into a channels-first feature map.
        side = self.sqrt_num_patches
        embedding = rearrange(embedding, 'b (p1 p2) c -> b c p1 p2', p1 = side//8, p2 = side//8 )
        embedding = self.paspp(embedding)

        embedding2 = rearrange(embedding2, 'b (p1 p2) c -> b c p1 p2', p1 = side//4, p2 = side//4 )
        embedding1 = rearrange(embedding1, 'b (p1 p2) c -> b c p1 p2', p1 = side//2, p2 = side//2 )
        embedding0 = rearrange(embedding0, 'b (p1 p2) c -> b c p1 p2', p1 = side, p2 = side )

        # Refine the skip features.
        embedding2 = self.skip3(embedding2)
        embedding1 = self.skip2(embedding1)
        embedding0 = self.skip1(embedding0)

        # Decoder: the third argument is a literal size tied to the default
        # configuration (see class docstring).
        x = self.du1(embedding)
        x = self.du2(x, embedding2, 16)
        x = self.du3(x, embedding1, 32)
        x = self.du4(x, embedding0, 64)
        x = self.conv_last(x)

        return x

In [None]:
# Build the network with its default hyper-parameters (note: default
# in_channels=1) and move it to the GPU; this cell requires a CUDA device.
model = MLPMixer().cuda()

# Loss and Metric

In [None]:
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _WeightedLoss

In [None]:
import torch
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _WeightedLoss


# Tiny constant added to numerators and denominators of the metric ratios in
# classwise_iou / classwise_f1 so that they never divide by zero.
EPSILON = 1e-32


class LogNLLLoss(_WeightedLoss):
    """Cross-entropy loss on raw logits with optional class weights.

    Despite the name, no explicit ``log`` is applied to the input (that step is
    commented out); ``cross_entropy`` performs log-softmax internally.

    Args:
        weight: optional per-class rescaling weights.
        size_average, reduce: legacy reduction flags, handled by the base class.
        reduction: ``'none'``, ``'mean'`` or ``'sum'``. ``None`` (the default)
            is treated as ``'mean'``, which is what this loss always returned.
        ignore_index: target value that is excluded from the loss.
    """
    __constants__ = ['weight', 'reduction', 'ignore_index']

    def __init__(self, weight=None, size_average=None, reduce=None, reduction=None,
                 ignore_index=-100):
        super(LogNLLLoss, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    def forward(self, y_input, y_target):
        # y_input = torch.log(y_input + EPSILON)
        # Honour the configured reduction instead of silently using the default.
        reduction = self.reduction if self.reduction is not None else 'mean'
        return cross_entropy(y_input, y_target, weight=self.weight,
                             ignore_index=self.ignore_index,
                             reduction=reduction)


def classwise_iou(output, gt):
    """Foreground IoU accumulated over the whole batch.

    Args:
        output: torch.Tensor of shape (n_batch, n_classes, image.shape)
        gt: torch.LongTensor of shape (n_batch, image.shape)

    The prediction is the arg-max over the class axis; both prediction and
    ``gt`` are used directly as masks, so this is meaningful for 0/1 labels.
    Returns a 0-dim tensor.
    """
    pred = output.argmax(dim=1)
    overlap = pred * gt
    combined = pred + gt - overlap
    # EPSILON makes the empty-vs-empty case evaluate to 1 instead of 0/0.
    return (overlap.sum().float() + EPSILON) / (combined.sum() + EPSILON)


def classwise_f1(output, gt):
    """F1 score of the foreground class (label 1), accumulated over the batch.

    Args:
        output: torch.Tensor of shape (n_batch, n_classes, image.shape)
        gt: torch.LongTensor of shape (n_batch, image.shape)

    Returns:
        0-dim float tensor with the F1 score of class 1.
    """
    epsilon = 1e-20

    pred = torch.argmax(output, dim=1)
    # Only class 1 is scored, so count its true positives directly rather than
    # building a per-class tensor and indexing it.
    true_positives = ((pred == 1) & (gt == 1)).sum().float()
    selected = (pred == 1).sum().float()
    relevant = (gt == 1).sum().float()

    precision = (true_positives + epsilon) / (selected + epsilon)
    recall = (true_positives + epsilon) / (relevant + epsilon)
    f1 = 2 * (precision * recall + EPSILON) / (precision + recall + EPSILON)

    return f1
def classwise_dicescore(output, gt):
    """Dice score of the foreground class (label 1), accumulated over the batch.

    Args:
        output: torch.Tensor of shape (n_batch, n_classes, image.shape)
        gt: torch.LongTensor of shape (n_batch, image.shape)

    Returns:
        0-dim float tensor. When neither the prediction nor ``gt`` contains any
        foreground pixel the score is defined as 1 (instead of NaN from 0/0).
    """
    pred = torch.argmax(output, dim=1)
    true_positives = ((pred == 1) & (gt == 1)).sum().float()
    selected = (pred == 1).sum().float()
    relevant = (gt == 1).sum().float()

    denominator = selected + relevant
    if denominator.item() == 0:
        # Empty prediction matches empty ground truth perfectly.
        return torch.ones_like(denominator)
    return 2 * true_positives / denominator

def make_weighted_metric(classwise_metric):
    """Wrap a metric function so it accepts an optional ``weights`` argument.

    Args:
        classwise_metric: classwise metric like classwise_IOU or classwise_F1

    Returns:
        ``weighted_metric(output, gt, weights=None)``, which returns
        ``classwise_metric(output, gt)`` moved to the CPU.

    The ``weights`` are validated but not applied to the result: the wrapped
    score is returned unchanged.
    """

    def weighted_metric(output, gt, weights=None):
        # Validate only; the caller's weights are never modified.
        if weights is not None and len(weights) != output.shape[1]:
            raise ValueError("The number of weights must match with the number of classes")

        classwise_scores = classwise_metric(output, gt).cpu()

        return classwise_scores

    return weighted_metric

In [None]:
# Implement Focal Loss
class FocalLoss(nn.Module):
    """Focal loss on raw logits: ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``.

    Args:
        alpha: class weighting. ``None`` disables it; a number ``a`` becomes the
            two-class weights ``[a, 1 - a]``; a list/tuple/tensor gives one
            weight per class.
        gamma: focusing exponent; ``0`` reduces to (weighted) cross-entropy.
        size_average: return the mean over all elements if true, else the sum.
    """
    def __init__(self, alpha=None, gamma=0, size_average=True,):
        super().__init__()
        self.gamma = gamma
        # Always define the attribute so forward() can test it when alpha is None.
        self.alpha = None
        if isinstance(alpha, (float, int)):
            self.alpha = torch.tensor([alpha, 1-alpha])
        elif isinstance(alpha, (list, tuple)):
            self.alpha = torch.tensor(alpha)
        elif isinstance(alpha, torch.Tensor):
            self.alpha = alpha.detach().clone()
        self.size_average = size_average

    def forward(self, inputs, targets):
        """
        Inputs:
        targets : class indices, dtype = long, with one entry per spatial
                  location of `inputs` (e.g. (N, H, W) or (N, 1, H, W))
        inputs : shape (N, C, ...) or (M, C) - unnormalised scores per class

        Returns:
        Focal loss between groundtruth and predict (0-dim tensor)
        """
        if inputs.dim() > 2:
            # (N, C, d1, ...) -> (N * d1 * ..., C), class axis last.
            num_classes = inputs.shape[1]
            inputs = inputs.flatten(2).transpose(1, 2).reshape(-1, num_classes)
        targets = targets.reshape(-1, 1)

        # Log-probability (and probability) assigned to the true class.
        logpt = F.log_softmax(inputs, dim = 1).gather(1, targets).view(-1)
        pt = logpt.exp()

        if self.alpha is not None:
            # Keep the weights on the device / dtype of the current batch.
            self.alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
            at = self.alpha.gather(0, targets.view(-1))
            logpt = logpt * at

        loss = -1. * (1 - pt)**self.gamma * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()
        
        
        

In [None]:
# Implement soft Dice Loss (foreground = class 1)
class DiceLoss(nn.Module):
    """Soft Dice loss for the foreground class (label 1) on raw logits.

    ``loss = 1 - 2 * sum(p * t) / (sum(p) + sum(t))`` where ``p`` is the softmax
    probability of class 1 at every location and ``t`` is the binary foreground
    mask ``targets == 1``. The probability of class 1 is used everywhere, so
    confident, correct background predictions no longer enlarge the denominator
    and a perfect prediction reaches a loss of ~0.

    Args:
        size_average: accepted for signature compatibility; not used.
    """
    def __init__(self, size_average=True,):
        super().__init__()

    def forward(self, inputs, targets):
        """
        Inputs:
        targets : class indices, dtype = long, with one entry per spatial
                  location of `inputs` (e.g. (N, H, W) or (N, 1, H, W))
        inputs : shape (N, C, ...) or (M, C) - unnormalised scores per class,
                 C >= 2

        Returns:
        Dice loss between groundtruth and predict (0-dim tensor)
        """
        if inputs.dim() > 2:
            # (N, C, d1, ...) -> (N * d1 * ..., C), class axis last.
            num_classes = inputs.shape[1]
            inputs = inputs.flatten(2).transpose(1, 2).reshape(-1, num_classes)

        prob_fg = F.softmax(inputs, dim = 1)[:, 1]
        target_fg = (targets.reshape(-1) == 1).to(prob_fg.dtype)

        intersection = (prob_fg * target_fg).sum()
        dice = (2. * intersection + 1e-32) / (prob_fg.sum() + target_fg.sum() + 1e-32)
        return 1 - dice
        

In [None]:
class CEDiceloss(nn.Module):
    """Convex combination of Dice loss and cross-entropy.

    ``loss = alpha * DiceLoss + (1 - alpha) * CrossEntropyLoss``; both terms are
    computed from the same logits and targets.
    """

    def __init__(self, alpha = 0.5):
        super().__init__()
        self.alpha = alpha

    def forward(self, inputs, targets):
        dice_term = DiceLoss()(inputs, targets)
        ce_term = nn.CrossEntropyLoss()(inputs, targets)
        return self.alpha * dice_term + (1 - self.alpha) * ce_term

In [None]:
# Smoke test: evaluate FocalLoss(alpha=0.4, gamma=0.8) on random two-class
# scores. Requires a CUDA device because the tensors are moved to "cuda".
criterion = FocalLoss(0.4, 0.8)
inputs_test = torch.rand(1, 2, 128, 128).to("cuda")
#pred_test = torch.randint(0, 1, size=(1, 128, 128)).to("cuda")
# Despite its name, pred_test is passed as the *target*: every pixel is class 1.
pred_test = torch.ones(1,128,128, dtype=torch.int64).to("cuda")
criterion(inputs_test, pred_test)

tensor(0.2562, device='cuda:0')

In [None]:
# Sanity check: scores that always select class 1, compared with an all-ones
# mask, should give a Dice score of 1.
#output, gt = torch.ones(3, 2, 5, 5), torch.ones(3, 5, 5).long()
gt = torch.ones(3, 5, 5).long()
output = torch.stack([torch.zeros(3,5,5), torch.ones(3,5,5)], dim=1)
print(classwise_dicescore(output, gt))

tensor(1.)


In [None]:
# Same sanity check for IoU; relies on `output` and `gt` from the previous cell.
classwise_iou(output, gt)
#print(output.shape)

tensor(1.)

# Train and Test

In [None]:
from tqdm import tqdm

In [None]:
# Training configuration.
lr = 1e-3  # learning rate handed to the Adam optimizer
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): the visible part of train_normal loops over range(0, 300) and
# does not read NUM_EPOCHS - confirm which value is intended.
NUM_EPOCHS = 200
IMG_SIZE = 256
IMG_CHANEL = 3
# Checkpoint/evaluate when epoch % save_freq == 0. With 300 epochs and
# save_freq = 1000 that is only epoch 0 - NOTE(review): confirm.
save_freq = 1000

In [None]:
# def save_checkpoint(state, filename="my_checkpoint_ceil.pth.tar"):
#     print("=>Saving checkpoint")
#     torch.save(state, filename)
# def load_checkpoint(checkpoint, model):
#     print("=> Loading checkpoint")
#     model.load_state_dict(checkpoint["state_dict"])

In [None]:
def save_checkpoint(state, filename="my_checkpoint_ceil_1.pth.tar"):
    """Announce the save on stdout, then serialise ``state`` to ``filename`` with ``torch.save``."""
    print("=>Saving checkpoint")
    torch.save(obj=state, f=filename)
def load_checkpoint(checkpoint, model):
    """Announce the load on stdout, then restore ``model`` from ``checkpoint["state_dict"]``."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

In [None]:
# model = Permute_Unet(image_size=256,patch_size=8, depth=24, segments=16, dim = 256)
# model.cuda();

In [None]:
# Loss, optimizer, AMP gradient scaler and SWA helpers used by the training loop.
# `criterion` replaces the FocalLoss instance created in the smoke-test cell.
criterion = CEDiceloss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)
scaler = torch.cuda.amp.GradScaler()
# Averaged copy of the model, updated by train_fn_swa.
swa_model = torch.optim.swa_utils.AveragedModel(model)
swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=1e-5)
swa_model.cuda()
# First epoch trained with train_fn_swa. NOTE(review): train_normal runs
# range(0, 300), so with 2000 the SWA branch is never reached - confirm.
swa_start = 2000
#criterion = LogNLLLoss()
reset_seed()

In [None]:
#optimizer_1 = torch.optim.Adam(model_fake.parameters(), lr=1e-4, weight_decay=1e-4)
#swa_scheduler = torch.optim.swa_utils.SWALR(optimizer_1, swa_lr=1e-4)

In [None]:
def train_fn(loader, model,optimizer, loss_fn, scaler):
    """Train ``model`` for one pass over ``loader`` with mixed precision.

    For every batch: forward under ``torch.cuda.amp.autocast``, accumulate F1
    and IoU of the batch on the CPU, then run the scaled backward pass and the
    optimizer step. Inputs are moved to the module-level ``DEVICE``.

    Returns:
        ``(mean_loss, mean_f1, mean_iou)`` averaged over the batches.
    """
    model.train()
    loss_total = 0
    f1_total = 0
    iou_total = 0
    num_batches = 0

    progress = tqdm(loader)
    for data, targets in progress:
        data = data.to(device=DEVICE)
        targets = targets.long().to(DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            prediction = model(data)
            loss = loss_fn(prediction, targets)

        # metrics are computed on detached CPU copies
        prediction_cpu = prediction.detach().cpu()
        targets_cpu = targets.detach().cpu()
        f1_total += classwise_f1(prediction_cpu, targets_cpu).item()
        iou_total += classwise_iou(prediction_cpu, targets_cpu).item()

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # progress bar and running totals
        batch_loss = loss.item()
        progress.set_postfix(loss = batch_loss)
        loss_total += batch_loss
        num_batches += 1

    return loss_total / num_batches, f1_total / num_batches, iou_total / num_batches

def check_accuracy(loader, model, loss_fn, device="cuda"):
    """Evaluate ``model`` on ``loader`` without gradients.

    The model is switched to eval mode for the pass and back to train mode
    afterwards. Mean IoU and F1 are printed.

    Returns:
        ``(mean_loss, mean_f1, mean_iou)`` averaged over ``len(loader)`` batches.
    """
    f1_total = 0
    iou_total = 0
    loss_total = 0

    model.eval()
    with torch.no_grad():
        for X, y in loader:
            X = X.to(device)
            y = y.long().to(device)
            preds = model(X)
            loss_total += loss_fn(preds, y).item()

            preds_cpu = preds.detach().cpu()
            y_cpu = y.detach().cpu()
            f1_total += classwise_f1(preds_cpu, y_cpu).item()
            iou_total += classwise_iou(preds_cpu, y_cpu).item()
    model.train()

    num_batches = len(loader)
    mean_iou = iou_total / num_batches
    mean_f1 = f1_total / num_batches
    print(f"IoU score: {mean_iou}")
    print(f"F1 score: {mean_f1} ")
    return loss_total / num_batches, mean_f1, mean_iou


In [None]:
def train_fn_swa(loader, model, swa_model, optimizer, loss_fn, scaler):
    """Train one epoch like ``train_fn`` and then update the SWA model.

    After the pass over ``loader`` the averaged weights in ``swa_model`` are
    updated from ``model``, the module-level ``swa_scheduler`` is stepped, and
    the batch-norm statistics of ``swa_model`` are refreshed on the same
    ``loader`` that was used for training.

    Returns:
        ``(mean_loss, mean_f1, mean_iou)`` averaged over the batches.
    """
    model.train()
    train_running_loss = 0
    my_f1 = 0
    my_iou = 0
    counter = 0
    loop = tqdm(loader)
    for data, targets in loop:
        data = data.to(device=DEVICE)
        targets = targets.long().to(DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            prediction = model(data)
            loss = loss_fn(prediction, targets)

        # metrics on detached CPU copies
        tmp = prediction.detach().cpu()
        tmp2 = targets.detach().cpu()
        my_f1 += classwise_f1(tmp, tmp2).item()
        my_iou += classwise_iou(tmp, tmp2).item()

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss = loss.item())
        train_running_loss += loss.item()
        counter += 1

    swa_model.update_parameters(model)
    # swa_scheduler is a notebook-level global created next to the optimizer.
    swa_scheduler.step()
    # Use the loader that was passed in rather than the global train_loader.
    bn_update(swa_model, loader)
    return train_running_loss / counter, my_f1/counter, my_iou/counter

# batchnormalize update running mean + running var
def bn_update(swa_model, loader):
    """Recompute the batch-norm running statistics of ``swa_model`` on ``loader``.

    Delegates to ``torch.optim.swa_utils.update_bn``, which resets the running
    mean/variance, does one gradient-free pass over the loader and restores the
    module's momentum and training mode afterwards. For batches that are
    lists/tuples (e.g. ``(data, target)``) only the first element is fed to the
    model; inputs are moved to the module-level ``DEVICE``.
    """
    torch.optim.swa_utils.update_bn(loader, swa_model, device=DEVICE)

In [None]:
# Containers filled by the training loop: per-evaluation loss and metric curves.
loss_history = {"train":[],"test":[]}
acc_history = {"train_f1":[],"test_f1":[],"train_iou":[],"test_iou":[]}
checkpoint_history = {}
# Best test F1 seen so far. NOTE(review): the visible part of train_normal
# assigns best_accuracy without a `global` declaration, which would make the
# name local there - confirm.
best_accuracy = 0

In [None]:
def train_normal(num_epochs=300, unfreeze_epoch=10):
    """Run the training loop, checkpointing and evaluating every `save_freq` epochs.

    Epochs before the module-level `swa_start` train with `train_fn`; later
    epochs train with `train_fn_swa`. On every epoch divisible by `save_freq`
    the current weights are saved, the test and train loaders are scored with
    `check_accuracy`, and the results are appended to the module-level
    `loss_history` / `acc_history`. From `swa_start` onwards the averaged
    model (`swa_model`) is the one saved and evaluated; before that it is
    `model`.

    The best test F1 is tracked in the module-level `best_accuracy`, which
    must already exist; it is declared `global` here because assigning to it
    without the declaration made it a local and raised `UnboundLocalError` on
    the first comparison. As a consequence the best score persists across
    calls until the module-level value is reset.

    Args:
        num_epochs: number of epochs to run (default 300).
        unfreeze_epoch: epoch index after whose training step every parameter
            of `model` gets `requires_grad = True` (default 10).
    """
    global best_accuracy

    for epoch in range(num_epochs):
        print(f"=============Epoch {epoch}============>")
        use_swa = epoch >= swa_start
        if use_swa:
            train_loss, train_f1, train_iou = train_fn_swa(
                train_loader, model, swa_model, optimizer, criterion, scaler
            )
        else:
            train_loss, train_f1, train_iou = train_fn(
                train_loader, model, optimizer, criterion, scaler
            )

        if epoch == unfreeze_epoch:
            for param in model.parameters():
                param.requires_grad = True

        if epoch % save_freq != 0:
            continue

        # From swa_start on, the averaged model is what gets saved and scored.
        eval_model = swa_model if use_swa else model
        weights = swa_model.module.state_dict() if use_swa else model.state_dict()
        checkpoint = {
            "state_dict": weights,
            "optimizer": optimizer.state_dict(),
            "epoch": epoch + 1,
        }
        save_checkpoint(checkpoint)

        print("test score:")
        test_loss, test_f1, test_iou = check_accuracy(test_loader, eval_model, criterion, DEVICE)
        if test_f1 > best_accuracy:
            best_accuracy = test_f1
            name = f"dice score {test_f1:.4f},test_score {test_iou:.4f},epoch {epoch}".replace(".", ",")
            save_checkpoint(checkpoint, filename=name + ".pth.tar")

        # The recorded train metrics come from a full evaluation pass, not
        # from the running averages returned by the training step above.
        print("train score:")
        train_loss, train_f1, train_iou = check_accuracy(train_loader, eval_model, criterion, DEVICE)

        loss_history["test"].append(test_loss)
        loss_history["train"].append(train_loss)
        acc_history["train_f1"].append(train_f1)
        acc_history["train_iou"].append(train_iou)
        acc_history["test_f1"].append(test_f1)
        acc_history["test_iou"].append(test_iou)



In [None]:
def train_eff(num_epochs=300, unfreeze_epoch=10, checkpoint_every=1000):
    """Run the training loop, evaluating on the test loader after every epoch.

    Epochs before the module-level `swa_start` train with `train_fn` and are
    scored/saved from `model`; later epochs train with `train_fn_swa` and are
    scored/saved from the averaged `swa_model`. After each epoch the running
    train metrics returned by the training step are printed, the test loader
    is scored with `check_accuracy`, and both are appended to the
    module-level `loss_history` / `acc_history`. A checkpoint (weights and
    epoch, no optimiser state) is written whenever the test F1 beats the best
    value seen during this call.

    The periodic save now builds a checkpoint of the current epoch itself.
    Previously it reused whatever `checkpoint` the last best-score save had
    left behind, which was stale and raised `UnboundLocalError` when no best
    checkpoint had been written yet.

    Args:
        num_epochs: number of epochs to run (default 300).
        unfreeze_epoch: epoch index after whose training step every parameter
            of `model` gets `requires_grad = True` (default 10).
        checkpoint_every: additionally save `"epoch {epoch}.pth.tar"` on every
            positive epoch divisible by this value (default 1000); a falsy
            value disables the periodic save.
    """
    best_accuracy = 0
    for epoch in range(num_epochs):
        print(f"=============Epoch {epoch}============>")
        use_swa = epoch >= swa_start
        if use_swa:
            train_loss, train_f1, train_iou = train_fn_swa(
                train_loader, model, swa_model, optimizer, criterion, scaler
            )
        else:
            train_loss, train_f1, train_iou = train_fn(
                train_loader, model, optimizer, criterion, scaler
            )
        print("train score:")
        print(f"IoU score: {train_iou}")
        print(f"Dice score: {train_f1}")

        if epoch == unfreeze_epoch:
            for param in model.parameters():
                param.requires_grad = True

        # From swa_start on, the averaged model is what gets saved and scored.
        eval_model = swa_model if use_swa else model
        weights_source = swa_model.module if use_swa else model

        if checkpoint_every and epoch > 0 and epoch % checkpoint_every == 0:
            periodic = {
                "state_dict": weights_source.state_dict(),
                "epoch": epoch + 1,
            }
            save_checkpoint(periodic, filename=f"epoch {epoch}.pth.tar")

        print("test score:")
        test_loss, test_f1, test_iou = check_accuracy(test_loader, eval_model, criterion, DEVICE)
        if test_f1 > best_accuracy:
            best_accuracy = test_f1
            checkpoint = {
                "state_dict": weights_source.state_dict(),
                "epoch": epoch + 1,
            }
            name = f"dice score {test_f1:.4f},test_score {test_iou:.4f},epoch {epoch}".replace(".", ",")
            save_checkpoint(checkpoint, filename=name + ".pth.tar")

        loss_history["test"].append(test_loss)
        loss_history["train"].append(train_loss)
        acc_history["train_f1"].append(train_f1)
        acc_history["train_iou"].append(train_iou)
        acc_history["test_f1"].append(test_f1)
        acc_history["test_iou"].append(test_iou)



