In [2]:
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from torch.nn import functional as F
from timm.models.layers import DropPath, to_2tuple
import math

  from .autonotebook import tqdm as notebook_tqdm


# Base Codes

In [3]:

class DWConvLKA(nn.Module):
    def __init__(self, dim=768):
        super(DWConvLKA, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        x = self.dwconv(x)
        return x


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.dwconv = DWConvLKA(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.linear = linear
        if self.linear:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.fc1(x)
        if self.linear:
            x = self.relu(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class AttentionModule(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        return u * attn


class SpatialAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = AttentionModule(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)

    def forward(self, x):
        shorcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shorcut
        return x


class LKABlock(nn.Module):

    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 linear=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)  # build_norm_layer(norm_cfg, dim)[1]
        self.attn = SpatialAttention(dim)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = nn.LayerNorm(dim)  # build_norm_layer(norm_cfg, dim)[1]
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop, linear=linear)
        layer_scale_init_value = 1e-2
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        # B, N, C = x.shape
        # x = x.permute(0, 2, 1).view(B, C, H, W)
        y = x.permute(0, 2, 3, 1)  # b h w c, because norm requires this
        y = self.norm1(y)
        y = y.permute(0, 3, 1, 2)  # b c h w, because attn requieres this
        y = self.attn(y)
        y = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * y
        y = self.drop_path(y)
        x = x + y
        # x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
        #                       * self.attn(self.norm1(x)))

        y = x.permute(0, 2, 3, 1)  # b h w c, because norm requires this
        y = self.norm2(y)
        y = y.permute(0, 3, 1, 2)  # b c h w, because attn requieres this
        y = self.mlp(y)
        y = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * y
        y = self.drop_path(y)
        x = x + y
        # x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
        #                       * self.mlp(self.norm2(x)))
        # x = x.view(B, C, N).permute(0, 2, 1)
        # print("LKA return shape: {}".format(x.shape))
        return x



In [4]:
class PatchExpand(nn.Module):
    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        # print("x_shape-----",x.shape)
        H, W = self.input_resolution
        x = self.expand(x)

        B, L, C = x.shape
        # print(x.shape)
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        x = rearrange(x, "b h w (p1 p2 c)-> b (h p1) (w p2) c", p1=2, p2=2, c=C // 4)
        x = x.view(B, -1, C // 4)
        x = self.norm(x.clone())

        return x


class FinalPatchExpand_X4(nn.Module):
    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.dim_scale = dim_scale
        self.expand = nn.Linear(dim, 16 * dim, bias=False)
        self.output_dim = dim
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        x = rearrange(
            x, "b h w (p1 p2 c)-> b (h p1) (w p2) c", p1=self.dim_scale, p2=self.dim_scale, c=C // (self.dim_scale ** 2)
        )
        x = x.view(B, -1, self.output_dim)
        x = self.norm(x.clone())

        return x


# Prompt Module

In [16]:
class PromptGenBlock(nn.Module):
    def __init__(self,prompt_dim=48,prompt_len=5,prompt_size = 96,lin_dim = 192):
        super(PromptGenBlock,self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1,prompt_len,prompt_dim,prompt_size,prompt_size))
        self.linear_layer = nn.Linear(lin_dim,prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim,prompt_dim,kernel_size=3,stride=1,padding=1,bias=False)
        

    def forward(self,x):
        B,C,H,W = x.shape
        emb = x.mean(dim=(-2,-1))
        prompt_weights = F.softmax(self.linear_layer(emb),dim=1)
        prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B,1,1,1,1,1).squeeze(1)
        prompt = torch.sum(prompt,dim=1)
        prompt = F.interpolate(prompt,(H,W),mode="bilinear")
        prompt = self.conv3x3(prompt)
        

        return prompt

In [20]:
torch.randn(1,96,56,56).mean(dim=(-2,-1)).shape

torch.Size([1, 96])

In [53]:
class PromptGenBlockSina(nn.Module):
    def __init__(self,prompt_dim=48,prompt_len=5,prompt_size = 96,lin_dim = 192):
        super(PromptGenBlockSina,self).__init__()

        self.prompt_param = nn.Parameter(torch.rand(1,prompt_len,prompt_dim,prompt_size,prompt_size))
        self.linear_layer = nn.Linear(lin_dim,prompt_len)

        self.conv3x3 = nn.Conv2d(prompt_dim,prompt_dim,kernel_size=3,stride=1,padding=1,bias=False)
        

    def forward(self,x):
        B,C,H,W = x.shape
        emb = x.mean(dim=(-2,-1)) # B, C
        prompt_weights = F.softmax(self.linear_layer(emb),dim=1) # B, C
        
        prompt_ = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        print(prompt_.shape)
        
        prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B,1,1,1,1,1).squeeze(1)

        prompt__ = self.prompt_param.unsqueeze(0).repeat(B,1,1,1,1,1).squeeze(1)
        print(prompt__.shape)
        
        prompt = torch.sum(prompt,dim=1)
        print(prompt.shape)
        prompt = F.interpolate(prompt,(H,W),mode="bilinear")
        prompt = self.conv3x3(prompt)
        

        return prompt

In [None]:
reduction_ratio = [32, 16, 8, 4]

In [None]:
class PromptGenBlockSinaV1(nn.Module):
    def __init__(self,prompt_dim=48,prompt_ratio=4,input_size = 96,lin_dim = 192):
        super(PromptGenBlockSinaV1,self).__init__()

        prompt_len = prompt_ratio * prompt_dim 

        self.prompt_param = nn.Parameter(torch.rand(1,prompt_len, input_size, input_size))

        self.linear_layer = nn.Linear(lin_dim,prompt_len)

        self.conv3x3 = nn.Conv2d(prompt_len,prompt_dim,kernel_size=3,stride=1,padding=1,bias=False)
        

    def forward(self,x):
        B,C,H,W = x.shape

        emb = x.mean(dim=(-2,-1)) # B, C

        
        prompt_weights = F.softmax(self.linear_layer(emb),dim=1) # B, C

        
        prompt_weights = prompt_weights.unsqueeze(-1).unsqueeze(-1).expand_as(self.prompt_param)

        prompt = prompt_weights * self.prompt_param

        prompt = self.conv3x3(prompt)

        print(prompt.shape)
        

        return prompt

In [72]:
model = PromptGenBlockSinaV1(prompt_dim=768,
                       prompt_ratio= 4,
                       input_size=7,
                       lin_dim=768)

model(torch.randn(1,768, 7, 7)).shape, calculate_params_in_millions(model)

torch.Size([1, 768, 7, 7])


(torch.Size([1, 768, 7, 7]), 23.74656)

In [None]:
model = PromptGenBlockSina(prompt_dim=96,
                       prompt_len= 5,
                       input_size=96,
                       lin_dim=96)

model(torch.randn(1,96, 7, 7)).shape, calculate_params_in_millions(model)

In [7]:
class Attention_st(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention_st, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        


    def forward(self, x):
        b,c,h,w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q,k,v = qkv.chunk(3, dim=1)   



        
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)
        
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out

## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention_st_cross(nn.Module):
    def __init__(self, dim, num_heads, bias):
        dim = dim //2
        super(Attention_st_cross, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.kv = nn.Conv2d(dim, 2*dim*2, kernel_size=1, bias=bias)
        self.q = nn.Conv2d(dim, 2*dim, kernel_size=1, bias=bias)
        self.kv_dwconv = nn.Conv2d(2*dim*2, 2*dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
        self.project_out = nn.Conv2d(2*dim, 2*dim, kernel_size=1, bias=bias)
        

    def forward(self, x):
        b,c,h,w = x.shape

        kv = self.kv_dwconv(self.kv(x[:, :c//2]))
        k,v = kv.chunk(2, dim=1)   
        q = self.q(x[:, :c//2])

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)
        
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out

In [8]:
import numbers

class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias

In [9]:
def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')

def to_4d(x,h,w):
    return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)

class LayerNormst(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNormst, self).__init__()
        if LayerNorm_type =='BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)
    
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim*ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x

In [10]:
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, cross= False):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNormst(dim, LayerNorm_type)
        if not cross:
            self.attn = Attention_st(dim, num_heads, bias)
        else:
            self.attn = Attention_st_cross(dim, num_heads, bias)
        self.norm2 = LayerNormst(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))

        return x

# Base Decoder

In [11]:
class MyDecoderLayerLKA(nn.Module):
    def __init__(
            self, input_size: tuple, in_out_chan: tuple, n_class=9,
            norm_layer=nn.LayerNorm, is_last=False
    ):
        super().__init__()
        out_dim = in_out_chan[0]
        x1_dim = in_out_chan[1]
        
        if not is_last:
            self.x1_linear = nn.Linear(x1_dim, out_dim)
            #self.ag_attn = MultiScaleGatedAttn(dim=x1_dim)
            self.ag_attn_norm = nn.LayerNorm(out_dim)

            self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer)
            self.last_layer = None
        else:
            self.x1_linear = nn.Linear(x1_dim, out_dim)
            #self.ag_attn = MultiScaleGatedAttn(dim=x1_dim)
            self.ag_attn_norm = nn.LayerNorm(out_dim)

            self.layer_up = FinalPatchExpand_X4(
                input_resolution=input_size, dim=out_dim, dim_scale=4, norm_layer=norm_layer
            )
            self.last_layer = nn.Conv2d(out_dim, n_class, 1)

        
        self.layer_lka_1 = LKABlock(dim=out_dim)
        ## Prompt Module must be located here.

        self.layer_lka_2 = LKABlock(dim=out_dim)

        

        def init_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.LayerNorm):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Conv2d):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

        init_weights(self)

    def forward(self, x1, x2=None):
        if x2 is not None:  # skip connection exist
            x2 = x2.contiguous()
            # b, c, h, w = x1.shape
            b2, h2, w2, c2 = x2.shape  # e.g: 1 28 28 320, 1 56 56 128
            x2 = x2.view(b2, -1, c2)  # e.g: 1 784 320, 1 3136 128

            x1_expand = self.x1_linear(x1)  # e.g: 1 784 256 --> 1 784 320, 1 3136 160 --> 1 3136 128

            x2_new = x2.view(x2.size(0), x2.size(2), x2.size(1) // w2, x2.size(1) // h2) # B, C, H, W
            

            x1_expand = x1_expand.view(x2.size(0), x2.size(2), x2.size(1) // w2, x2.size(1) // h2) # B, C, H, W

            # print(f'the x1_expand shape is: {x1_expand.shape}\n\t the x2_new shape is: {x2_new.shape}')

            #attn_gate = self.ag_attn(x=x2_new, g=x1_expand)  # B C H W

            cat_linear_x = x1_expand + x2_new  # B C H W
            cat_linear_x = cat_linear_x.permute(0, 2, 3, 1)  # B H W C
            cat_linear_x = self.ag_attn_norm(cat_linear_x)  # B H W C

            cat_linear_x = cat_linear_x.permute(0, 3, 1, 2).contiguous()  # B C H W

            tran_layer_1 = self.layer_lka_1(cat_linear_x)
            # print(tran_layer_1.shape)
            tran_layer_2 = self.layer_lka_2(tran_layer_1)

            tran_layer_2 = tran_layer_2.view(tran_layer_2.size(0), tran_layer_2.size(3) * tran_layer_2.size(2),
                                             tran_layer_2.size(1))
            if self.last_layer:
                out = self.last_layer(
                    self.layer_up(tran_layer_2).view(b2, 4 * h2, 4 * w2, -1).permute(0, 3, 1, 2))  # 1 9 224 224
            else:
                out = self.layer_up(tran_layer_2)  # 1 3136 160
        else:
            out = self.layer_up(x1)
        return out

In [52]:
input_0 = torch.randn(1, 96, 56, 56).cuda() # skip 3
input_1 = torch.randn(1, 192, 28, 28).cuda() # skip 2
input_2 = torch.randn(1, 384, 14, 14).cuda() # skip 1
input_3 = torch.randn(1, 768, 7, 7).cuda() # X

In [53]:
decoder_3 = MyDecoderLayerLKA(input_size=(7,7),
                               in_out_chan=([768, 768])).cuda()

decoder_2 = MyDecoderLayerLKA(input_size=(14,14),
                               in_out_chan=([384, 384])).cuda()
decoder_1 = MyDecoderLayerLKA(input_size=(28,28),
                               in_out_chan=([192, 192])).cuda()
decoder_0 = MyDecoderLayerLKA(input_size=(56, 56), 
                              in_out_chan=([96, 96]),
                              is_last=True).cuda()

In [54]:
b, c, _, _ = input_3.shape

output_3 = decoder_3(input_3.permute(0, 2, 3, 1).view(b,-1, c))
output_2 = decoder_2(output_3, input_2.permute(0,2 , 3, 1))
output_1 = decoder_1(output_2, input_1.permute(0, 2, 3, 1))
output_0 = decoder_0(output_1, input_0.permute(0, 2, 3, 1))

In [55]:
print(f"the output 3 shape: {output_3.shape}\nthe output 2 shape: {output_2.shape} \nthe output 1 shape: {output_1.shape} \nthe output 0 shape: {output_0.shape}")

the output 3 shape: torch.Size([1, 196, 384])
the output 2 shape: torch.Size([1, 784, 192]) 
the output 1 shape: torch.Size([1, 3136, 96]) 
the output 0 shape: torch.Size([1, 9, 224, 224])


In [12]:
class MyDecoderLayerLKA_Prompt(nn.Module):
    def __init__(
            self, input_size: tuple, in_out_chan: tuple, n_class=9,
            norm_layer=nn.LayerNorm, is_last=False
    ):
        super().__init__()
        out_dim = in_out_chan[0]
        x1_dim = in_out_chan[1]
        
        if not is_last:
            self.x1_linear = nn.Linear(x1_dim, out_dim)
            #self.ag_attn = MultiScaleGatedAttn(dim=x1_dim)
            self.ag_attn_norm = nn.LayerNorm(out_dim)

            self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer)
            self.last_layer = None
        else:
            self.x1_linear = nn.Linear(x1_dim, out_dim)
            #self.ag_attn = MultiScaleGatedAttn(dim=x1_dim)
            self.ag_attn_norm = nn.LayerNorm(out_dim)

            self.layer_up = FinalPatchExpand_X4(
                input_resolution=input_size, dim=out_dim, dim_scale=4, norm_layer=norm_layer
            )
            self.last_layer = nn.Conv2d(out_dim, n_class, 1)

        
        self.layer_lka_1 = LKABlock(dim=out_dim)
        ## Prompt Module must be located here.

        #dim_p = int(out_dim * 0.75)
        dim_p = out_dim
        # self.prompt1 = PromptGenBlock(prompt_dim=dim_p,
        #                               prompt_len=5,
        #                               prompt_size = dim_p,
        #                               lin_dim = dim_p)
        
        # self.noise_level1 = TransformerBlock(dim=int(dim_p*2**1) ,
        #                                      num_heads=1, 
        #                                      ffn_expansion_factor=2.66, 
        #                                      bias=False, LayerNorm_type='WithBias')
        
        self.reduce_noise_level1 = nn.Conv2d(int(dim_p*2),int(dim_p*1),kernel_size=1,bias=False)



        self.layer_lka_2 = LKABlock(dim=out_dim)

        

        def init_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.LayerNorm):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Conv2d):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

        init_weights(self)

    def forward(self, x1, x2=None):
        if x2 is not None:  # skip connection exist
            x2 = x2.contiguous()
            # b, c, h, w = x1.shape
            b2, h2, w2, c2 = x2.shape  # e.g: 1 28 28 320, 1 56 56 128
            x2 = x2.view(b2, -1, c2)  # e.g: 1 784 320, 1 3136 128

            x1_expand = self.x1_linear(x1)  # e.g: 1 784 256 --> 1 784 320, 1 3136 160 --> 1 3136 128

            x2_new = x2.view(x2.size(0), x2.size(2), x2.size(1) // w2, x2.size(1) // h2) # B, C, H, W
            

            x1_expand = x1_expand.view(x2.size(0), x2.size(2), x2.size(1) // w2, x2.size(1) // h2) # B, C, H, W

            # print(f'the x1_expand shape is: {x1_expand.shape}\n\t the x2_new shape is: {x2_new.shape}')

            #attn_gate = self.ag_attn(x=x2_new, g=x1_expand)  # B C H W

            cat_linear_x = x1_expand + x2_new  # B C H W
            cat_linear_x = cat_linear_x.permute(0, 2, 3, 1)  # B H W C
            cat_linear_x = self.ag_attn_norm(cat_linear_x)  # B H W C

            cat_linear_x = cat_linear_x.permute(0, 3, 1, 2).contiguous()  # B C H W

            tran_layer_1 = self.layer_lka_1(cat_linear_x)
            
            prompt_layer_1 = self.prompt1(tran_layer_1)
            
            cat_input_prompt = torch.cat([tran_layer_1, prompt_layer_1], dim= 1)
            cat_input_prompt = self.noise_level1(cat_input_prompt)
            refined_feature = self.reduce_noise_level1(cat_input_prompt)

            tran_layer_2 = self.layer_lka_2(refined_feature)

            tran_layer_2 = tran_layer_2.view(tran_layer_2.size(0), tran_layer_2.size(3) * tran_layer_2.size(2),
                                             tran_layer_2.size(1))
            if self.last_layer:
                out = self.last_layer(
                    self.layer_up(tran_layer_2).view(b2, 4 * h2, 4 * w2, -1).permute(0, 3, 1, 2))  # 1 9 224 224
            else:
                out = self.layer_up(tran_layer_2)  # 1 3136 160
        else:
            out = self.layer_up(x1)
        return out

In [13]:
input_0 = torch.randn(1, 96, 56, 56).cuda() # skip 3
input_1 = torch.randn(1, 192, 28, 28).cuda() # skip 2
input_2 = torch.randn(1, 384, 14, 14).cuda() # skip 1
input_3 = torch.randn(1, 768, 7, 7).cuda() # X



decoder_3 = MyDecoderLayerLKA_Prompt(input_size=(7,7),
                               in_out_chan=([768, 768])).cuda()

# decoder_2 = MyDecoderLayerLKA_Prompt(input_size=(14,14),
#                                in_out_chan=([384, 384])).cuda()
# decoder_1 = MyDecoderLayerLKA_Prompt(input_size=(28,28),
#                                in_out_chan=([192, 192])).cuda()
# decoder_0 = MyDecoderLayerLKA_Prompt(input_size=(56, 56), 
#                               in_out_chan=([96, 96]),
#                               is_last=True).cuda()

b, c, _, _ = input_3.shape

# output_3 = decoder_3(input_3.permute(0, 2, 3, 1).view(b,-1, c))
# output_2 = decoder_2(output_3, input_2.permute(0, 2 , 3, 1))
# output_1 = decoder_1(output_2, input_1.permute(0, 2, 3, 1))
# output_0 = decoder_0(output_1, input_0.permute(0, 2, 3, 1))

In [14]:
def calculate_params_in_millions(model):
  """Calculates the number of parameters in a PyTorch model in millions.

  Args:
    model: A PyTorch model.

  Returns:
    The number of parameters in millions.
  """
  num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  return num_params / 1e6

In [15]:
calculate_params_in_millions(decoder_3)

16.128

In [15]:
print(f"the output 3 shape: {output_3.shape}\nthe output 2 shape: {output_2.shape}")

the output 3 shape: torch.Size([1, 196, 384])
the output 2 shape: torch.Size([1, 784, 192])
