In [1]:
import torch
import torch.nn as nn   

In [2]:
class patchEmbed(nn.Module):
    '''Split an image into non-overlapping square patches and linearly embed each one.

    Parameters:
        img_size: int, side length of the (square) image
        patch_size: int, side length of each (square) patch
        in_chans: int, number of input channels
        embed_dim: int, dimension of each patch embedding
    '''
    def __init__(self, img_size: int, patch_size: int, in_chans: int = 3, embed_dim: int = 768):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size
        # Patches per side use floor division, so a trailing partial patch is not counted.
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side
        # kernel_size == stride == patch_size: each output pixel sees exactly one patch,
        # which makes this conv a per-patch linear projection.
        self.prog = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Embed patches: (B, in_chans, H, W) -> (B, n_tokens, embed_dim).

        Assumes a batched 4-D input; tokens are ordered row-major over the patch grid.
        '''
        return self.prog(x).flatten(start_dim=2).transpose(1, 2)

In [3]:
class Attention(nn.Module):
    '''Multi-head scaled dot-product self-attention.

    parameters:
        dim: int, dimension of input tokens; must be divisible by num_heads
        num_heads: int, number of heads
        qkv_bias: bool, whether to include bias in qkv projection
        attn_p: float, dropout probability applied to the attention weights
        proj_p: float, dropout probability applied after the output projection
    raises:
        ValueError: at construction if num_heads is not positive or does not divide dim;
            in forward if the input's last dimension is not dim
    '''
    def __init__(self, dim: int, num_heads: int = 12, qkv_bias: bool = True, attn_p: float = 0., proj_p: float = 0.):
        super().__init__()
        # Without this check a non-divisible dim is silently truncated here and only
        # fails later inside forward's reshape with an unhelpful message.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(f'dim {dim} must be divisible by a positive num_heads, got num_heads={num_heads}')
        self.n_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Self-attention over tokens: (n_samples, n_tokens, dim) -> (n_samples, n_tokens, dim).'''
        n_samples, n_tokens, dim = x.shape

        if dim != self.dim:
            raise ValueError(f'Input dim {dim} should be {self.dim}')

        # (B, N, 3*dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(n_samples, n_tokens, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # (B, heads, N, N) attention weights, softmax over the key axis
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(attn.softmax(dim=-1))

        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, dim)
        weighted_avg = (attn @ v).transpose(1, 2).flatten(2)
        return self.proj_drop(self.proj(weighted_avg))

In [4]:
class MLP(nn.Module):
    '''Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    parameters:
        in_features: int, size of the last dimension of the input
        hidden_features: int, width of the hidden layer
        out_features: int, size of the last dimension of the output
        p: float, dropout probability (the same Dropout module is used after both layers)
    '''
    def __init__(self, in_features: int, hidden_features: int, out_features: int, p: float = 0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Apply the MLP to the last dimension of x.'''
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

In [5]:
class Block(nn.Module):
    '''Pre-norm transformer block.

    Computes ``x = x + attn(norm1(x))`` then ``x = x + mlp(norm2(x))``. The
    LayerNorms sit only on the residual branches, so the skip path carries the
    un-normalized input.

    parameters:
        dim: int, dimension of input
        num_heads: int, number of heads
        mlp_ratio: int, ratio of hidden to input dimension of the MLP (a non-integer
            ratio is allowed; the hidden width is truncated to int(dim * mlp_ratio))
        qkv_bias: bool, whether to include bias in qkv projection
        p: float, dropout probability used for attention weights, attention projection and MLP
    '''
    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4, qkv_bias: bool = True, p: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_p=p, proj_p=p)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), out_features=dim, p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The residual must add the block *input*. Overwriting x with norm1(x)
        # first (as before) would normalize the skip path as well and break the
        # identity shortcut that deep pre-norm stacks rely on.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

In [6]:
class ConvBlock(nn.Module):
    '''Conv2d -> BatchNorm2d -> ReLU.

    With the defaults (kernel_size=3, padding=1, stride 1) the spatial size is preserved.

    parameters:
        in_channels: int, number of input channels
        out_channels: int, number of output channels
        kernel_size: int, convolution kernel size
        padding: int, zero padding on each side
    '''
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, padding=padding)
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.norm(self.conv(x)))

class DeconvBlock(nn.Module):
    '''Learned 2x upsampling via a ConvTranspose2d with kernel 2, stride 2, no padding.

    The output height and width are exactly double the input's.

    parameters:
        in_channels: int, number of input channels
        out_channels: int, number of output channels
    '''
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        upsampled = self.deconv(x)
        return upsampled

In [23]:
class Unetr2D(nn.Module):
    '''2D UNETR: a ViT encoder with a convolutional U-Net-style decoder.

    The image is cut into patches and encoded by a stack of transformer blocks.
    Token sequences from three intermediate blocks (``skip_layers``) and from the
    last block are reshaped back into feature maps and progressively upsampled.
    A full-resolution conv branch on the raw input provides the final skip.
    Input (B, in_chans, H, W) -> output (B, num_classes, H, W).

    Parameters:
        img_size: int, side length of the (square) image the positional embedding is sized for;
            other input sizes divisible by patch_size are handled by resizing that embedding
        patch_size: int, patch side length; must be 16 because the decoder upsamples by exactly 16
        in_chans: int, number of input channels
        num_classes: int, number of output channels of the final 1x1 convolution
        embed_dim: int, transformer embedding dimension
        depth: int, number of transformer blocks
        num_heads: int, number of attention heads per block
        mlp_ratio: int, hidden-to-input width ratio of each block's MLP
        qkv_bias: bool, whether the qkv projections have a bias
        p: float, dropout probability
        skip_layers: three strictly increasing 0-based block indices whose outputs feed the
            decoder skips, shallow -> deep; default (2, 5, 8) = blocks 3, 6 and 9
    Raises:
        ValueError: if patch_size is not 16, img_size < patch_size, or skip_layers is
            not three strictly increasing indices in [0, depth)
    '''
    # Total upsampling of the deepest decoder path: deconv1 plus the three
    # conv_conv_deconv stages, each a stride-2 DeconvBlock.
    _DECODER_SCALE = 16

    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int = 3,
                 num_classes: int = 1,
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: int = 4,
                 qkv_bias: bool = True,
                 p: float = 0.,
                 skip_layers: tuple = (2, 5, 8)):
        super().__init__()
        # Validate before building any submodule so misconfiguration fails fast.
        if patch_size != self._DECODER_SCALE:
            raise ValueError(f'patch_size must be {self._DECODER_SCALE} (the decoder upsamples by exactly '
                             f'that factor), got {patch_size}')
        if img_size < patch_size:
            raise ValueError(f'img_size {img_size} must be at least patch_size {patch_size}')
        skip_layers = tuple(skip_layers)
        if (len(skip_layers) != 3
                or not all(0 <= i < depth for i in skip_layers)
                or not skip_layers[0] < skip_layers[1] < skip_layers[2]):
            raise ValueError(f'skip_layers must be three strictly increasing block indices in [0, {depth}), '
                             f'got {skip_layers}')

        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.skip_layers = skip_layers

        self.patch_embed = patchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        # Learnable positional embedding, one vector per patch of an img_size x img_size image.
        # Without it self-attention has no notion of where a token came from.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.n_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.pos_drop = nn.Dropout(p)
        self.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, p=p) for _ in range(depth)])

        # Deepest path: final encoder output, upsampled 2x.
        self.deconv1 = DeconvBlock(in_channels=embed_dim, out_channels=512)

        # Skip from skip_layers[2], upsampled 2x.
        self.deconv_conv1 = nn.Sequential(
            DeconvBlock(in_channels=embed_dim, out_channels=512),
            ConvBlock(in_channels=512, out_channels=512))

        # Skip from skip_layers[1], upsampled 4x.
        self.deconv_conv2 = nn.Sequential(
            DeconvBlock(in_channels=embed_dim, out_channels=256),
            ConvBlock(in_channels=256, out_channels=256),
            DeconvBlock(in_channels=256, out_channels=256),
            ConvBlock(in_channels=256, out_channels=256))

        # Skip from skip_layers[0], upsampled 8x.
        self.deconv_conv3 = nn.Sequential(
            DeconvBlock(in_channels=embed_dim, out_channels=128),
            ConvBlock(in_channels=128, out_channels=128),
            DeconvBlock(in_channels=128, out_channels=128),
            ConvBlock(in_channels=128, out_channels=128),
            DeconvBlock(in_channels=128, out_channels=128),
            ConvBlock(in_channels=128, out_channels=128))

        # Fusion stages: input channels are the sum of the two concatenated maps.
        self.conv_conv_deconv1 = nn.Sequential(
            ConvBlock(in_channels=1024, out_channels=256),   # 512 + 512
            ConvBlock(in_channels=256, out_channels=256),
            DeconvBlock(in_channels=256, out_channels=256))

        self.conv_conv_deconv2 = nn.Sequential(
            ConvBlock(in_channels=512, out_channels=128),    # 256 + 256
            ConvBlock(in_channels=128, out_channels=128),
            DeconvBlock(in_channels=128, out_channels=128))

        self.conv_conv_deconv3 = nn.Sequential(
            ConvBlock(in_channels=256, out_channels=64),     # 128 + 128
            ConvBlock(in_channels=64, out_channels=64),
            DeconvBlock(in_channels=64, out_channels=64))

        # Full-resolution branch on the raw image; must accept in_chans, not a hard-coded 3.
        self.conv_conv1 = nn.Sequential(
            ConvBlock(in_channels=in_chans, out_channels=64),
            ConvBlock(in_channels=64, out_channels=64))

        self.final_conv = nn.Sequential(
            ConvBlock(in_channels=128, out_channels=32),     # 64 + 64
            ConvBlock(in_channels=32, out_channels=32),
            nn.Conv2d(in_channels=32, out_channels=num_classes, kernel_size=1, padding=0))

    def _tokens_to_grid(self, tokens: torch.Tensor, grid_h: int, grid_w: int) -> torch.Tensor:
        '''Reshape (B, grid_h*grid_w, embed_dim) row-major tokens into a (B, embed_dim, grid_h, grid_w) map.'''
        return tokens.reshape(tokens.shape[0], grid_h, grid_w, self.embed_dim).permute(0, 3, 1, 2)

    def _pos_embed_for(self, grid_h: int, grid_w: int) -> torch.Tensor:
        '''Positional embedding for a grid_h x grid_w patch grid, as (1, grid_h*grid_w, embed_dim).

        Returned as-is for the img_size grid; otherwise resized bilinearly.
        '''
        base = self.img_size // self.patch_size
        if grid_h == base and grid_w == base:
            return self.pos_embed
        grid = self.pos_embed.reshape(1, base, base, self.embed_dim).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(grid, size=(grid_h, grid_w), mode='bilinear', align_corners=False)
        return grid.flatten(2).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Segment x: (B, in_chans, H, W) -> (B, num_classes, H, W); H and W must be multiples of patch_size.'''
        _, _, H, W = x.shape
        if H % self.patch_size or W % self.patch_size:
            raise ValueError(f'input height/width ({H}, {W}) must be divisible by patch_size {self.patch_size}')
        grid_h, grid_w = H // self.patch_size, W // self.patch_size
        image = x

        tokens = self.patch_embed(x)
        tokens = self.pos_drop(tokens + self._pos_embed_for(grid_h, grid_w))

        skips = []
        for i, block in enumerate(self.blocks):
            tokens = block(tokens)
            if i in self.skip_layers:
                skips.append(tokens)   # ends up ordered shallow -> deep

        # Each stage doubles the resolution of the deepest path: 2x, 4x, 8x, 16x.
        x = self.deconv1(self._tokens_to_grid(tokens, grid_h, grid_w))
        skip_deep = self.deconv_conv1(self._tokens_to_grid(skips[2], grid_h, grid_w))
        x = self.conv_conv_deconv1(torch.cat([x, skip_deep], dim=1))

        skip_mid = self.deconv_conv2(self._tokens_to_grid(skips[1], grid_h, grid_w))
        x = self.conv_conv_deconv2(torch.cat([x, skip_mid], dim=1))

        skip_shallow = self.deconv_conv3(self._tokens_to_grid(skips[0], grid_h, grid_w))
        x = self.conv_conv_deconv3(torch.cat([x, skip_shallow], dim=1))

        x = torch.cat([x, self.conv_conv1(image)], dim=1)
        return self.final_conv(x)

In [24]:
# Smoke test: a random batch must come back at full resolution with one channel per class.
torch.manual_seed(0)  # reproducible weights and input across Restart & Run All
model = Unetr2D(img_size=256, patch_size=16)
x = torch.randn(5, 3, 256, 256)
with torch.no_grad():  # forward-only shape check; skip building the autograd graph
    y = model(x)
assert y.shape == (5, 1, 256, 256), y.shape
print(y.shape)

torch.Size([5, 1, 256, 256])
