In [None]:
import torch
import torch.nn as nn
from spikingjelly.clock_driven.neuron import MultiStepLIFNode # MultiStepLIFNode: 다단계 LIF 뉴런
from timm.models.layers import DropPath, to_2tuple, trunc_normal_ # to_2tuple: 2-tuple로 변환하는 함수
from timm.models.registry import register_model # register_model: 모델을 등록하는 함수
from timm.models.vision_transformer import _cfg, Mlp, PatchEmbed, _create_vision_transformer # _cfg: 모델의 기본 설정을 반환하는 함수
import torch.nn.functional as F
from functools import partial # partial: 함수의 인자를 고정시키는 함수

In [None]:
class Spike_Attention(nn.Module):
    """Spiking self-attention over an input of shape ``(T, B, N, C)``.

    Queries, keys and values are each produced by Linear -> BatchNorm1d -> LIF.
    They are combined as ``((Q @ K^T) * scale) @ V`` with no softmax. The result
    goes through a LIF neuron (``v_threshold=0.5``) and is finally projected by
    Linear -> BatchNorm1d -> LIF.

    Args:
        dim: Channel dimension ``C`` of the input; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Accepted for timm-style API compatibility but currently
            ignored. The q/k/v ``nn.Linear`` layers use their default
            ``bias=True``, so honoring this would change existing checkpoints.
        qk_scale: Multiplier applied to ``Q @ K^T``. ``None`` (default) keeps
            the fixed value 0.125.
        attn_drop: Accepted for API compatibility; unused.
        proj_drop: Accepted for API compatibility; unused.
        sr_ratio: Accepted for API compatibility; unused.

    Shape:
        Input and output are both ``(T, B, N, C)``. ``T`` is the leading
        dimension handed to the multi-step LIF nodes (presumably time steps),
        and ``B``/``N`` presumably batch and tokens; TODO confirm with callers.

    NOTE(review): every LIF node is built with ``backend='cupy'``, which
    presumably needs CuPy and a CUDA device at construction/run time.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        # 0.125 is the historical fixed scale; an explicit qk_scale overrides it.
        self.scale = 0.125 if qk_scale is None else qk_scale

        # Submodules are created in the original order so that parameter
        # registration order, state_dict keys and RNG-driven init are unchanged.
        # Query branch: Linear -> BatchNorm1d -> LIF.
        self.q_linear = nn.Linear(dim, dim)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Key branch: Linear -> BatchNorm1d -> LIF.
        self.k_linear = nn.Linear(dim, dim)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Value branch: Linear -> BatchNorm1d -> LIF.
        self.v_linear = nn.Linear(dim, dim)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # LIF applied to the attention output; note the lower firing threshold.
        self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='cupy')

        # Output projection: Linear -> BatchNorm1d -> LIF.
        self.proj_linear = nn.Linear(dim, dim)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

    def _spiking_projection(self, x_flat, linear, bn, lif, T, B, N, C):
        """Map ``(T*B, N, C)`` through Linear -> BatchNorm1d -> LIF to ``(T, B, N, C)``."""
        y = linear(x_flat)
        # BatchNorm1d normalizes dim 1, so move channels there and back.
        y = bn(y.transpose(-1, -2)).transpose(-1, -2).reshape(T, B, N, C).contiguous()
        return lif(y)

    def _split_heads(self, y, T, B, N, C):
        """Reshape ``(T, B, N, C)`` into ``(T, B, num_heads, N, C // num_heads)``."""
        head_dim = C // self.num_heads
        return y.reshape(T, B, N, self.num_heads, head_dim).permute(0, 1, 3, 2, 4).contiguous()

    def forward(self, x):
        """Apply spiking self-attention.

        Args:
            x: Tensor of shape ``(T, B, N, C)``.

        Returns:
            Tensor of shape ``(T, B, N, C)``, the output of ``proj_lif``.
        """
        T, B, N, C = x.shape
        x_for_qkv = x.flatten(0, 1)  # (T*B, N, C)

        # LIF nodes are invoked in the order q, k, v, attn, proj.
        q = self._split_heads(
            self._spiking_projection(x_for_qkv, self.q_linear, self.q_bn, self.q_lif, T, B, N, C),
            T, B, N, C)
        k = self._split_heads(
            self._spiking_projection(x_for_qkv, self.k_linear, self.k_bn, self.k_lif, T, B, N, C),
            T, B, N, C)
        v = self._split_heads(
            self._spiking_projection(x_for_qkv, self.v_linear, self.v_bn, self.v_lif, T, B, N, C),
            T, B, N, C)

        # No softmax: the raw scaled Q K^T scores weight V directly.
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (T, B, num_heads, N, N)
        x = attn @ v  # (T, B, num_heads, N, C // num_heads)
        # Merge heads back into the channel dimension.
        x = x.transpose(2, 3).reshape(T, B, N, C).contiguous()  # (T, B, N, C)
        x = self.attn_lif(x)

        return self._spiking_projection(
            x.flatten(0, 1), self.proj_linear, self.proj_bn, self.proj_lif, T, B, N, C)