![Activation functions overview](https://vitalab.github.io/blog/images/activationfunctions/sc01.jpeg)

![(A) The Swish and GELU activation functions; (B) their derivatives](https://www.researchgate.net/publication/373857926/figure/fig10/AS:11431281255205102@1719410280046/A-The-activation-functions-of-the-Swish-and-the-GELU-B-The-derivatives-of-Swish-and.tif)


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [10]:
# Scratch tensor of shape (1, 2) drawn from a standard normal distribution.
t = torch.randn(1, 2)

In [13]:
# type_as returns the tensor cast to the type of its argument; passing t
# itself therefore leaves the type unchanged.
t.type_as(t)

tensor([[0.3643, 0.1344]])

In [None]:
# Functional form
def nonlinearity(x):
    """Apply the Swish activation element-wise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate

In [None]:
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    The last dimension of the input is split into two chunks; the first
    chunk is the value and the second chunk, passed through GELU, is the
    gate. The output is their element-wise product.
    """

    def forward(self, x):
        """Return ``gelu(gate) * value`` for the two chunks of ``x``."""
        value, gate = torch.chunk(x, 2, dim=-1)
        activated_gate = torch.nn.functional.gelu(gate)
        return activated_gate * value

In [None]:
class StarReLU(nn.Module):
    """Squared-ReLU activation with an affine output: ``scale * relu(x)**2 + bias``.

    Args:
        scale_value: Initial value of the single-element ``scale`` parameter.
        bias_value: Initial value of the single-element ``bias`` parameter.
        scale_learnable: Whether ``scale`` requires gradients.
        bias_learnable: Whether ``bias`` requires gradients.
        mode: Accepted for interface compatibility; not used by this class.
        inplace: Forwarded to the internal ``nn.ReLU``.
    """

    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)

        # Both parameters are one-element tensors that broadcast over the input.
        initial_scale = torch.ones(1) * scale_value
        self.scale = nn.Parameter(initial_scale, requires_grad=scale_learnable)
        initial_bias = torch.ones(1) * bias_value
        self.bias = nn.Parameter(initial_bias, requires_grad=bias_learnable)

    def forward(self, x):
        """Rectify ``x``, square it, then apply the scale and bias."""
        rectified = self.relu(x)
        squared = rectified.pow(2)
        return squared * self.scale + self.bias

In [None]:
# Smoke test: run a GEGLU feed-forward stack on a (3, 33, 44, 10) input.
# LayerNorm is referenced through `nn` because only `torch`, `nn` and `F`
# are imported; a bare `LayerNorm` name would raise NameError.
nn.Sequential(nn.LayerNorm(10),
              nn.Linear(10, 20),  # widen to 20 so GEGLU can split value/gate
              GEGLU(),            # halves the last dimension: 20 -> 10
              nn.Linear(10, 10)
              )(torch.randn(3, 33, 44, 10))

In [20]:
class SoftArgmax1D(nn.Module):
    """Differentiable approximation of ``argmax`` along dimension 1.

    Computes ``softmax(beta * x, dim=1)`` and returns the expected index
    under that distribution, i.e. ``sum_i i * softmax(beta * x)_i``.

    Args:
        beta: Scaling applied to the input before the softmax. The larger
            ``beta`` is, the closer the result gets to the hard argmax.
    """

    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        """Return the soft argmax of ``x`` over dimension 1.

        Args:
            x: Floating-point tensor with at least two dimensions,
                ``(B, L, ...)``.

        Returns:
            Tensor with dimension 1 reduced away, e.g. ``(B,)`` for a
            ``(B, L)`` input.
        """
        smax = F.softmax(self.beta * x, dim=1)
        length = x.size(1)
        # Create the index vector directly with x's dtype and device rather
        # than building it on CPU and converting afterwards.
        pos = torch.arange(length, dtype=x.dtype, device=x.device)
        # Shape (1, L, 1, ..., 1): aligns the indices with dimension 1 so that
        # inputs with trailing dimensions broadcast correctly.
        pos = pos.view(1, length, *([1] * (x.dim() - 2)))
        return torch.sum(smax * pos, dim=1)

In [21]:
# Demo input of shape (B, L) = (3, 5).
x = torch.tensor(
    [
        [0.5, 1.2, 3.0, 2.0, 0.1],  # first sample in the batch
        [2.0, 2.1, 2.2, 1.0, 0.5],  # second sample in the batch
        [0.0, 0.0, 0.0, 0.0, 1.0],  # third sample in the batch
    ]
)

# beta=10 scales the inputs before the softmax, sharpening it toward argmax.
softargmax_layer = SoftArgmax1D(beta=10.0)
y = softargmax_layer(x)
print("SoftArgmax 결과:", y)

SoftArgmax 결과: tensor([2.0000, 1.5752, 3.9995])
