In [29]:
import torch.nn as nn
import torch
from timm.models.layers import DropPath, trunc_normal_
class DWConv(nn.Module):
    """Depthwise 3x3 convolution that preserves channel count and spatial size.

    ``groups=dim`` gives one 3x3 filter per channel; stride 1 with padding 1
    keeps H and W unchanged.

    Args:
        dim: number of input and output channels (default 768).
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=dim,
        )

    def forward(self, x):
        return self.dwconv(x)

class Mlp(nn.Module):
    """Convolutional feed-forward network for 2-D feature maps.

    Layer order: 1x1 conv (``fc1``) -> depthwise 3x3 conv (``dwconv``) ->
    activation -> dropout -> 1x1 conv (``fc2``) -> dropout. The same
    ``nn.Dropout`` module is used at both dropout positions.

    Args:
        in_features: input channels.
        hidden_features: channels after ``fc1``; a falsy value falls back to ``in_features``.
        out_features: channels after ``fc2``; a falsy value falls back to ``in_features``.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        n_out = out_features or in_features
        n_hidden = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, n_hidden, kernel_size=1)
        self.dwconv = DWConv(n_hidden)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(n_hidden, n_out, kernel_size=1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        pipeline = (self.fc1, self.dwconv, self.act, self.drop, self.fc2, self.drop)
        for layer in pipeline:
            x = layer(x)
        return x


class LKA(nn.Module):
    """Large-kernel attention: a conv-derived map that multiplies the input elementwise.

    The attention map is built from a 5x5 depthwise conv, then a 7x7 depthwise
    conv with dilation 3 (effective extent 19, padding 9, so H and W are kept),
    then a 1x1 conv that mixes channels.

    Args:
        dim: number of input and output channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, kernel_size=5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, kernel_size=7, stride=1, padding=9, groups=dim, dilation=3
        )
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        gate = x
        for conv in (self.conv0, self.conv_spatial, self.conv1):
            gate = conv(gate)
        # The convolutions never modify x in place, so x can be reused directly.
        return x * gate


class Attention(nn.Module):
    """Gated spatial attention with an identity shortcut.

    Computes ``proj_2(LKA(GELU(proj_1(x)))) + x``. Both projections are 1x1
    convs, and the inner width is ``dim // ratio``.

    Args:
        dim: number of input and output channels.
        ratio: channel reduction factor for the inner LKA branch.
    """

    def __init__(self, dim, ratio=1):
        super().__init__()
        inner = dim // ratio
        self.proj_1 = nn.Conv2d(dim, inner, kernel_size=1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LKA(inner)
        self.proj_2 = nn.Conv2d(inner, dim, kernel_size=1)

    def forward(self, x):
        residual = x
        hidden = self.activation(self.proj_1(x))
        hidden = self.spatial_gating_unit(hidden)
        return self.proj_2(hidden) + residual


class LKA_Block(nn.Module):
    """Conv stem followed by layer-scaled attention and MLP residual branches.

    Pipeline:
        x = double_conv(x)        # (3x3 conv -> BatchNorm -> PReLU) x 2, in -> out channels
        x = x + drop_path(gamma_1 * Attention(x))
        x = x + drop_path(gamma_2 * Mlp(BatchNorm(x)))
    ``gamma_1`` and ``gamma_2`` are learnable per-channel scales, initialised to 1e-2.

    Note: ``Attention`` adds its own identity shortcut, so the first branch
    contributes ``gamma_1 * x`` on top of the outer residual sum.

    Args:
        in_channel: channels of the input.
        out_channel: channels produced by the stem and kept through the block.
        mlp_ratio: MLP hidden width as a multiple of ``out_channel``.
        drop: dropout probability inside the MLP.
        drop_path: rate passed to timm's ``DropPath``; 0 selects ``nn.Identity``.
        act_layer: activation class used inside the MLP.
    """

    def __init__(self, in_channel, out_channel, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU):
        super().__init__()
        stem_layers = [
            nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.PReLU(),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.PReLU(),
        ]
        self.double_conv = nn.Sequential(*stem_layers)

        self.attn = Attention(out_channel)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = nn.BatchNorm2d(out_channel)
        hidden_dim = int(out_channel * mlp_ratio)
        self.mlp = Mlp(in_features=out_channel, hidden_features=hidden_dim, act_layer=act_layer, drop=drop)

        init_scale = 1e-2
        self.layer_scale_1 = nn.Parameter(init_scale * torch.ones(out_channel), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(init_scale * torch.ones(out_channel), requires_grad=True)

    def forward(self, x):
        x = self.double_conv(x)
        # Reshape the per-channel scales to (C, 1, 1) so they broadcast over H and W.
        attn_scale = self.layer_scale_1[:, None, None]
        x = x + self.drop_path(attn_scale * self.attn(x))
        mlp_scale = self.layer_scale_2[:, None, None]
        x = x + self.drop_path(mlp_scale * self.mlp(self.norm2(x)))
        return x
class PCA(nn.Module):
    """Global channel reweighting from pooled features via three parallel Conv1d layers.

    The input is global-average-pooled to (B, C, 1, 1) and squeezed to (B, C, 1).
    The outputs of Conv1d layers with kernel sizes 1, 3 and 5 (channels -> channels)
    are summed, and the result multiplies the input per channel. No sigmoid is
    applied, so the channel weights are unbounded.

    NOTE(review): the pooled sequence has length 1. With padding 1 and 2, the k=3
    and k=5 kernels only ever see real data at their centre tap, so all three
    branches act as 1x1 convs. The multi-scale design is therefore ineffective,
    but changing it would alter parameter shapes and existing checkpoints.

    Args:
        input_channels: number of channels of the input feature map.
    """

    def __init__(self, input_channels):
        super().__init__()
        self.GAP = nn.AdaptiveAvgPool2d(1)

        # Parallel 1D convolutions with different kernel sizes.
        self.conv1d_1 = nn.Conv1d(input_channels, input_channels, kernel_size=1)
        self.conv1d_3 = nn.Conv1d(input_channels, input_channels, kernel_size=3, padding=1)
        self.conv1d_5 = nn.Conv1d(input_channels, input_channels, kernel_size=5, padding=2)

    def forward(self, x):
        # (B, C, 1, 1) -> (B, C, 1): drop the trailing spatial axis for Conv1d.
        pooled = self.GAP(x).squeeze(-1)
        channel_weights = self.conv1d_1(pooled) + self.conv1d_3(pooled) + self.conv1d_5(pooled)
        return x * channel_weights.unsqueeze(-1)
    

class CA(nn.Module):
    """Coordinate attention: per-row and per-column sigmoid gates on the input.

    The input (B, C, H, W) is average-pooled over H (a per-column profile,
    length W) and over W (a per-row profile, length H). Both profiles are
    encoded together by a shared 1x1 conv -> BatchNorm -> SiLU, then split back
    at their true lengths:
      * column features -> ``conv_1`` -> sigmoid -> gate of shape (B, C, 1, W)
      * row features    -> ``conv_2`` -> sigmoid -> gate of shape (B, C, H, 1)
    The output is ``x * column_gate * row_gate``.

    Args:
        dim: number of input and output channels.
        ratio: channel reduction factor for the shared encoder (``dim // ratio``).
    """

    def __init__(self, dim, ratio=1):
        super(CA, self).__init__()
        self.avgpool_x = nn.AdaptiveAvgPool2d((1, None))  # mean over H -> (B, C, 1, W)
        self.avgpool_y = nn.AdaptiveAvgPool2d((None, 1))  # mean over W -> (B, C, H, 1)

        self.conv = nn.Conv2d(dim, dim // ratio, 1)
        self.bn = nn.BatchNorm2d(dim // ratio)
        self.siLU = nn.SiLU(inplace=True)

        self.conv_1 = nn.Conv2d(dim // ratio, dim, 1)  # column (W) gate
        self.conv_2 = nn.Conv2d(dim // ratio, dim, 1)  # row (H) gate

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        height, width = x.shape[-2:]
        col_profile = self.avgpool_x(x)                       # (B, C, 1, W)
        row_profile = self.avgpool_y(x).permute(0, 1, 3, 2)   # (B, C, 1, H)
        # Stack both profiles along one axis so they share the encoder: (B, C, W + H, 1).
        combined = torch.cat([col_profile, row_profile], dim=3).permute(0, 1, 3, 2)
        combined = self.siLU(self.bn(self.conv(combined)))

        # Split at the real W / H lengths. Halving W + H breaks for non-square inputs.
        col_feat, row_feat = torch.split(combined, [width, height], dim=2)
        # The column gate must vary along W: (B, C, W, 1) -> (B, C, 1, W).
        col_gate = self.sigmoid(self.conv_1(col_feat)).permute(0, 1, 3, 2)
        # The row gate already has shape (B, C, H, 1) and varies along H.
        row_gate = self.sigmoid(self.conv_2(row_feat))

        return x * col_gate * row_gate
    
class PCCA(nn.Module):
    """Sum of the PCA (channel reweighting) and CA (coordinate attention) branches.

    Both branches receive the same input, and their outputs are added elementwise.

    Args:
        dim: number of input and output channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.PCA = PCA(dim)
        self.CA = CA(dim)

    def forward(self, x):
        channel_branch = self.PCA(x)
        coord_branch = self.CA(x)
        return channel_branch + coord_branch
    


In [30]:

import torch.nn as nn

class AP(nn.Module):
    """One U-Net stage: optional 2x2 max-pool, then LKA_Block, then PCCA.

    Args:
        in_channel: channels of the input.
        out_channel: channels produced by the stage.
        first: when True, skip the max-pool (used for the first encoder stage
            and for decoder stages, where resolution must be kept).
    """

    def __init__(self, in_channel, out_channel, first=False):
        super().__init__()
        self.lka_block = LKA_Block(in_channel, out_channel)
        self.pcca = PCCA(out_channel)
        self.pool2d = nn.MaxPool2d(kernel_size=2)
        self.first = first

    def forward(self, x):
        features = x if self.first else self.pool2d(x)
        return self.pcca(self.lka_block(features))


class Upsample(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, then an AP stage.

    The transposed conv maps ``in_channel`` -> ``out_channel`` and doubles H and W.
    The skip tensor is concatenated *before* the upsampled tensor along channels,
    and the result feeds ``AP(in_channel, out_channel, first=True)``. The skip
    must therefore have ``in_channel - out_channel`` channels and match the
    upsampled spatial size.

    Args:
        in_channel: channels of the incoming decoder tensor.
        out_channel: channels produced by this stage.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Transposed convolution that doubles the spatial resolution.
        self.deconv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2)
        self.ap = AP(in_channel, out_channel, first=True)
        # NOTE(review): not used in forward(); left in place so state_dict keys are unchanged.
        self.prelu = nn.PReLU()

    def forward(self, x, skip_connection):
        upsampled = self.deconv(x)
        merged = torch.cat([skip_connection, upsampled], dim=1)
        return self.ap(merged)
    

class Out(nn.Module):
    """Final decoder stage producing ``num_clases`` output channels.

    A transposed conv maps ``in_channel`` -> ``in_channel // 2`` and doubles H and W.
    The skip tensor is concatenated before it along channels, and the result feeds
    ``AP(in_channel, num_classes, first=True)``. The skip must therefore have
    ``in_channel - in_channel // 2`` channels.

    Args:
        in_channel: channels of the incoming decoder tensor.
        num_clases: number of output channels (spelling kept for caller compatibility).
    """

    def __init__(self, in_channel, num_clases):
        super().__init__()
        self.num_classes = num_clases

        # Transposed convolution that doubles the spatial resolution.
        self.deconv = nn.ConvTranspose2d(in_channel, in_channel // 2, kernel_size=2, stride=2)
        self.ap = AP(in_channel, self.num_classes, first=True)
        # NOTE(review): not used in forward(); left in place so state_dict keys are unchanged.
        self.prelu = nn.PReLU()

    def forward(self, x, skip_connection):
        upsampled = self.deconv(x)
        merged = torch.cat([skip_connection, upsampled], dim=1)
        return self.ap(merged)




class AP_UNet(nn.Module):
    """U-Net built from AP stages (LKA_Block + PCCA) with channel widths 32..512.

    Encoder: five AP stages; stages 2-5 each start with a 2x2 max-pool, so the
    deepest features are at 1/16 of the input resolution. Decoder: three
    Upsample stages plus ``Out``, each doubling the resolution and concatenating
    the matching encoder output. Input H and W should be divisible by 16 so the
    skip tensors line up with the upsampled ones.

    The per-stage shape comments below assume a 512x512 input.

    Args:
        in_channel: channels of the input image.
        num_classes: channels of the output map.
    """

    def __init__(self, in_channel, num_classes):
        super().__init__()
        features = [32, 64, 128, 256, 512]
        self.in_channel = in_channel
        self.num_classes = num_classes

        self.enc1 = AP(in_channel, features[0], first=True)  # 32 x 512 x 512
        self.enc2 = AP(features[0], features[1])             # 64 x 256 x 256
        self.enc3 = AP(features[1], features[2])             # 128 x 128 x 128
        self.enc4 = AP(features[2], features[3])             # 256 x 64 x 64
        self.enc5 = AP(features[3], features[4])             # 512 x 32 x 32

        self.up1 = Upsample(features[4], features[3])
        self.up2 = Upsample(features[3], features[2])
        self.up3 = Upsample(features[2], features[1])

        self.out = Out(features[1], num_classes)

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        skip4 = self.enc4(skip3)
        bottleneck = self.enc5(skip4)

        decoded = self.up1(bottleneck, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        return self.out(decoded, skip1)




In [31]:
import torch

# Shape smoke test: a 1-channel 512x512 input should give a 1-class 512x512 output.
# Gradients aren't needed for a shape check, so no_grad avoids keeping every
# intermediate activation of the full-resolution forward pass for autograd.
torch.manual_seed(0)
sample = torch.randn((1, 1, 512, 512))

ap_unet = AP_UNet(1, 1)

with torch.no_grad():
    prediction = ap_unet(sample)

assert prediction.shape == (1, 1, 512, 512)
prediction.shape

torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 128, 128]) torch.Size([1, 64, 256, 256])


torch.Size([1, 1, 512, 512])