In [1]:
from typing import Union
import math

import torch
import torch.nn as nn


In [2]:
class TripleSigmoid(nn.Module):
    """Triple-Sigmoid activation function.

    Paper: https://ieeexplore.ieee.org/document/9833503/

    The activation is piecewise, built from three logistic curves::

        h <  alpha         : 1 / (1 + e^(-w1 * (h - b1) - b0))
        alpha <= h < beta  : 1 / (1 + e^(-w2 * (h - b2) - b0))
        beta  <= h         : t_beta + e^(-w3 * (h - b3) - b0) / (1 + e^(-w3 * (h - b3) - b0))

    where ``b0 = delta``, ``b1 = (w2 * gamma + (w1 - w2) * alpha) / w1``,
    ``b2 = gamma``, ``b3 = beta`` and
    ``t_beta = 1 / (1 + e^(-w2 * (beta - gamma) - delta)) - e^(-delta) / (1 + e^(-delta))``.

    All constants are stored as non-persistent buffers: they follow
    ``Module.to()`` / ``.double()`` / ``.half()`` but are not written to the
    ``state_dict``.
    """
    def __init__(
        self,
        w1: float = 0.005,
        w2: float = 0.1,
        w3: float = 0.001,
        alpha: float = 0,
        beta: float = 500,
        gamma: float = 0,
        delta: float = 1.5,
        dtype = torch.float32,
        device: Union[str, torch.device] = 'cpu',
    ):
        """Initialize Triple-Sigmoid function.

        Args:
            w1 (float, optional): Slope of the first segment (h < alpha). Must be
                non-zero because b1 is divided by it. Defaults to 0.005.
            w2 (float, optional): Slope of the second segment (alpha <= h < beta).
                Defaults to 0.1.
            w3 (float, optional): Slope of the third segment (beta <= h).
                Defaults to 0.001.
            alpha (float, optional): Boundary between the first and second segment.
                Defaults to 0.
            beta (float, optional): Boundary between the second and third segment.
                Defaults to 500.
            gamma (float, optional): Horizontal shift of the second segment (b2).
                Defaults to 0.
            delta (float, optional): Offset b0 shared by all three exponents.
                Defaults to 1.5.
            dtype (torch.dtype, optional): dtype the constants are created with.
                Defaults to torch.float32.
            device (Union[str, torch.device], optional): Device the constants are
                created on. Defaults to 'cpu'.
        """
        super().__init__()

        # Construction-time values only; they are not updated if the module is
        # later moved or cast with `.to()`.
        self.dtype = dtype
        self.device = torch.device(device)

        exp_neg_delta = math.exp(-delta)
        constants = {
            'w1': w1,
            'w2': w2,
            'w3': w3,
            'alpha': alpha,
            'beta': beta,
            'b0': delta,
            'b1': (w2 * gamma + (w1 - w2) * alpha) / w1,
            'b2': gamma,
            'b3': beta,
            't_beta': (
                1 / (1 + math.exp(-w2 * (beta - gamma) - delta))
                - exp_neg_delta / (1 + exp_neg_delta)
            ),
        }
        # Every constant gets the same dtype/device and is registered as a buffer
        # so that `Module.to()` keeps it next to the inputs. `persistent=False`
        # keeps the state_dict empty, exactly as with plain attributes.
        for name, value in constants.items():
            self.register_buffer(
                name,
                torch.tensor(value, dtype=self.dtype, device=self.device),
                persistent=False,
            )

    def _exponent(
        self,
        h: torch.Tensor,
        w: Union[float, torch.Tensor],
        b: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """Calculate the exponent `-w * (h - b) - b0`.

        Args:
            h (torch.Tensor): h
            w (Union[float, torch.Tensor]): w
            b (Union[float, torch.Tensor]): b

        Returns:
            torch.Tensor: Output of `-w * (h - b) - b0`
        """
        return -w * (h - b) - self.b0

    def _calc_part(
        self,
        h: torch.Tensor,
        w: Union[float, torch.Tensor],
        b: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """Calculate `e^(-w * (h - b) - b0)`.

        Kept for backward compatibility; the segment equations below use
        `torch.sigmoid` on the exponent instead, because this value overflows
        to `inf` for large exponents.

        Args:
            h (torch.Tensor): h
            w (Union[float, torch.Tensor]): w
            b (Union[float, torch.Tensor]): b

        Returns:
            torch.Tensor: Output of `e^(-w * (h - b) - b0)`
        """
        return torch.exp(self._exponent(h, w, b))

    def eq_1_1(self, h: torch.Tensor) -> torch.Tensor:
        """Calculate equation 1 when h < alpha.

        Args:
            h (torch.Tensor): h

        Returns:
            torch.Tensor: Output of t(h)
        """
        # 1 / (1 + e^z) == sigmoid(-z); the sigmoid form never produces inf/NaN.
        return torch.sigmoid(-self._exponent(h, self.w1, self.b1))

    def eq_1_2(self, h: torch.Tensor) -> torch.Tensor:
        """Calculate equation 2 when alpha <= h < beta.

        Args:
            h (torch.Tensor): h

        Returns:
            torch.Tensor: Output of t(h)
        """
        # 1 / (1 + e^z) == sigmoid(-z)
        return torch.sigmoid(-self._exponent(h, self.w2, self.b2))

    def eq_1_3(self, h: torch.Tensor) -> torch.Tensor:
        """Calculate equation 3 when beta <= h.

        Args:
            h (torch.Tensor): h

        Returns:
            torch.Tensor: Output of t(h)
        """
        # e^z / (1 + e^z) == sigmoid(z); avoids inf / inf = NaN for large z.
        return self.t_beta + torch.sigmoid(self._exponent(h, self.w3, self.b3))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Calculate Triple-Sigmoid.

        Every segment is evaluated for the whole input and then masked to its
        own interval, so the three masked results can simply be summed.

        Args:
            h (torch.Tensor): Input

        Returns:
            torch.Tensor: Output of Triple-Sigmoid
        """
        h1 = torch.where(h < self.alpha, self.eq_1_1(h), 0.)
        h2 = torch.where(torch.logical_and(self.alpha <= h, h < self.beta), self.eq_1_2(h), 0.)
        h3 = torch.where(self.beta <= h, self.eq_1_3(h), 0.)

        return h1 + h2 + h3

In [3]:
class Net1(nn.Module):
    """Small CNN classifier: three conv blocks, an MLP head and an output activation.

    The output activation is chosen by name and is applied to the
    `(batch, num_classes)` output of the head.
    """
    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        activation_layer: str,
        input_size: int,
        is_mnist: bool,
        device: Union[str, torch.device] = 'cpu'
    ):
        """Build the network.

        Args:
            in_channels (int): Number of input channels
            num_classes (int): Number of output units of the final linear layer
            activation_layer (str): One of "softmax", "sigmoid", "triple-sigmoid"
            input_size (int): Spatial size of the (square) input; 28 or 32
            is_mnist (bool): Set to True to build the conv blocks without batch
                normalization (forwarded as `exclude_batch_norm`)
            device (Union[str, torch.device], optional): Device forwarded to
                `TripleSigmoid`; only used when `activation_layer` is
                "triple-sigmoid". Defaults to 'cpu'.
        """
        super().__init__()

        assert input_size == 28 or input_size == 32, \
            f'Input size {input_size} is not supported. Acceptable values: 28, 32.'

        assert activation_layer in {'softmax', 'sigmoid', 'triple-sigmoid'}, \
            'Only supported the following activation layers: "softmax", "sigmoid", "triple-sigmoid".'

        # Each block keeps H x W in its convolutions (padding='same') and halves
        # it in MaxPool2d(2): 28 -> 14 -> 7 -> 3 and 32 -> 16 -> 8 -> 4. With 128
        # channels that gives 128 * 3 * 3 = 1152 or 128 * 4 * 4 = 2048 features.
        flat_features = 1152 if input_size == 28 else 2048

        self.block1 = Net1.build_block(in_channels, 32, 3, is_mnist)
        self.block2 = Net1.build_block(32, 64, 3, is_mnist)
        self.block3 = Net1.build_block(64, 128, 3, is_mnist)
        self.block4 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 128),
            nn.LeakyReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(),
            nn.Linear(128, num_classes)
        )

        if activation_layer == 'softmax':
            # The head output is (batch, num_classes), so normalise over dim 1.
            # Passing dim explicitly avoids the deprecated implicit-dim lookup,
            # which resolves to the same axis for 2-D input.
            self.act = nn.Softmax(dim=1)
        elif activation_layer == 'sigmoid':
            self.act = nn.Sigmoid()
        else:
            self.act = TripleSigmoid(device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the conv blocks, the head and the output activation.

        Args:
            x (torch.Tensor): Input batch

        Returns:
            torch.Tensor: Activated output of the head
        """
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return self.act(x)

    @staticmethod
    def build_block(
        in_channels: int,
        hidden_channels: int,
        kernel_size: int,
        exclude_batch_norm: bool
    ) -> nn.Module:
        """Build block for Net 1.

        The block is two `Conv2d -> LeakyReLU [-> BatchNorm2d]` stages followed
        by `MaxPool2d(2)` and `Dropout`.

        Args:
            in_channels (int): Number of input channels
            hidden_channels (int): Number of hidden channels
            kernel_size (int): Convolution kernel size
            exclude_batch_norm (bool): Set to True to exclude batch normalization layer in block

        Returns:
            nn.Module: Block
        """
        layers = []
        # First stage maps in_channels -> hidden_channels, second keeps the width.
        for stage_in_channels in (in_channels, hidden_channels):
            layers.append(nn.Conv2d(stage_in_channels, hidden_channels, kernel_size, padding='same'))
            layers.append(nn.LeakyReLU())
            if not exclude_batch_norm:
                layers.append(nn.BatchNorm2d(hidden_channels))
        layers.append(nn.MaxPool2d(2))
        layers.append(nn.Dropout())
        return nn.Sequential(*layers)

In [4]:
# model1: everything on the CPU.
model1 = Net1(
    in_channels=3,
    num_classes=1000,
    activation_layer='triple-sigmoid',
    input_size=32,
    is_mnist=False,
    device='cpu',
)
# model2: activation built with device='cuda', then the model is moved to CUDA.
model2 = Net1(
    in_channels=3,
    num_classes=1000,
    activation_layer='triple-sigmoid',
    input_size=32,
    is_mnist=False,
    device='cuda',
).to('cuda')
# model3: activation built with device='cpu', then the model is moved to CUDA.
model3 = Net1(
    in_channels=3,
    num_classes=1000,
    activation_layer='triple-sigmoid',
    input_size=32,
    is_mnist=False,
    device='cpu',
).to('cuda')

In [5]:
# Random input batch: 4 samples, 3 channels, 32 x 32 pixels.
x = torch.rand(4, 3, 32, 32)

In [6]:
%%timeit -r 10 -n 1000
# Timed body: one forward pass of model1 on the CPU batch, with autograd disabled.
with torch.no_grad():
    y = model1(x)

9.14 ms ± 676 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)


In [7]:
# Move the input batch to the GPU for the CUDA timings below.
x = x.cuda()

In [8]:
%%timeit -r 10 -n 1000
# Timed body: one forward pass of model2 (activation built with device='cuda')
# on the CUDA batch, with autograd disabled.
# NOTE(review): CUDA kernels are launched asynchronously and there is no
# torch.cuda.synchronize() in the timed body, so this may not capture the full
# GPU execution time - confirm before comparing against the CPU number.
with torch.no_grad():
    y = model2(x)

2.08 ms ± 480 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)


In [9]:
%%timeit -r 10 -n 1000
# Timed body: one forward pass of model3 (activation built with device='cpu',
# model then moved to CUDA) on the CUDA batch, with autograd disabled.
# NOTE(review): CUDA kernels are launched asynchronously and there is no
# torch.cuda.synchronize() in the timed body, so this may not capture the full
# GPU execution time - confirm before comparing against the other timings.
with torch.no_grad():
    y = model3(x)

1.94 ms ± 153 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)
