# In [28]:
import torch 
import torch.nn as nn
from einops import rearrange

# In [29]:
class S_CEMBlock(nn.Module):
    """S-CEM block: parallel channel and spatial self-attention plus a gated FFN.

    The input is layer-normalised and projected to q/k/v by a 1x1 convolution
    followed by a 3x3 depthwise convolution. The same q/k/v feed two attention
    branches:

    * channel attention - a (c/head x c/head) attention map per head;
    * spatial attention - an (h*w x h*w) attention map per head.

    Each branch has its own 1x1 output projection and learnable per-channel
    scale (``beta`` / ``beta2``); both are added to the input as a residual.
    A second residual stage applies ``norm2 -> conv4 -> SimpleGate -> conv5``
    scaled by ``gamma``.

    ``beta``, ``beta2`` and ``gamma`` are initialised to zero, so a freshly
    constructed block returns its input unchanged.
    """

    def __init__(self, c, DW_Expand=2, num_heads=3, FFN_Expand=2, drop_out_rate=0.):
        """Build the block.

        Args:
            c (int): Number of input (and output) channels. Must be divisible
                by ``num_heads``.
            DW_Expand (int, optional): Currently unused; kept so existing
                callers that pass it keep working. Defaults to 2.
            num_heads (int, optional): Number of attention heads the channels
                are split into. Defaults to 3.
            FFN_Expand (int, optional): Channel expansion factor of the
                feed-forward stage (``conv4`` outputs ``FFN_Expand * c``
                channels). Defaults to 2.
            drop_out_rate (float, optional): Dropout probability applied after
                the attention projections and after the FFN. ``nn.Identity``
                is used when it is not greater than 0. Defaults to 0..

        Raises:
            ValueError: If ``num_heads`` is not positive or ``c`` is not
                divisible by ``num_heads``.
        """
        super().__init__()
        # Fail early with a readable message: otherwise the mismatch only
        # shows up as a shape error inside `rearrange` during forward().
        if num_heads < 1 or c % num_heads != 0:
            raise ValueError(
                f"c ({c}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads

        # q, k, v are produced jointly: 1x1 conv followed by a 3x3 depthwise conv.
        self.qkv = nn.Conv2d(c, c * 3, kernel_size=1)
        self.qkv_dwconv = nn.Conv2d(c * 3, c * 3, kernel_size=3, stride=1, padding=1, groups=c * 3)

        # Channel-attention branch: output projection and per-head temperature.
        self.project_out = nn.Conv2d(c, c, kernel_size=1)
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        # Spatial-attention branch: output projection and per-head temperature.
        self.project_out2 = nn.Conv2d(c, c, kernel_size=1)
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))
        # SimpleGate is defined elsewhere in the project; conv5 below takes
        # ffn_channel // 2 inputs, so the gate presumably halves the channel
        # count - confirm against its definition.
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1,
                               bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1,
                               groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        # Zero-initialised residual scales: the block starts as an identity map.
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.beta2 = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def _attend(self, q, k, v, temperature):
        """Return ``softmax(relu(norm(q) @ norm(k)^T * temperature)) @ v``.

        ``q`` and ``k`` are L2-normalised along their last dimension, so the
        scores are cosine similarities scaled by ``temperature``. Negative
        scores are clamped to zero by the ReLU before the softmax.
        """
        q = nn.functional.normalize(q, dim=-1)
        k = nn.functional.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * temperature
        attn = self.softmax(self.relu(attn))
        return attn @ v

    def forward(self, inp):
        """Apply the attention and FFN residual stages.

        Args:
            inp (torch.Tensor): Feature map of shape ``(b, c, h, w)``.

        Returns:
            torch.Tensor: Tensor with the same shape as ``inp``.
        """
        x = self.norm1(inp)
        h, w = x.shape[-2:]
        q, k, v = self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)  # split channels into q / k / v

        # (b, head*c', h, w) -> (b, head, c', h*w)
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        # Channel attention: attention map is (c' x c') per head.
        outc = self._attend(q, k, v, self.temperature)

        # Spatial attention: same tensors with the last two axes swapped, so the
        # attention map is (h*w x h*w) per head - memory grows quadratically
        # with the number of pixels. permute() returns a view and normalize()
        # is out-of-place, so no copy of q/k/v is required here.
        outs = self._attend(
            q.permute(0, 1, 3, 2),
            k.permute(0, 1, 3, 2),
            v.permute(0, 1, 3, 2),
            self.temperature2,
        )
        outs = outs.permute(0, 1, 3, 2)  # back to (b, head, c', h*w)

        outc = rearrange(outc, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        outs = rearrange(outs, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        xc = self.dropout1(self.project_out(outc))
        # dropout1 is used for both attention branches (as in the original
        # code); dropout2 is reserved for the FFN stage below.
        xs = self.dropout1(self.project_out2(outs))

        # First residual: input plus both scaled attention outputs.
        y = inp + xc * self.beta + xs * self.beta2

        # Second residual: gated feed-forward stage.
        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)
        x = self.dropout2(x)

        return y + x * self.gamma