In [1]:
import os
import sys
import importnb
from torch import nn
import torch
import numpy as np

In [16]:
# Make the project root (the parent of the kernel's working directory, which
# for Jupyter is normally this notebook's folder) importable so the shared
# building blocks in ``utils.tools`` can be found. The membership check keeps
# the cell idempotent: re-running it does not append duplicate entries.
notebook_path = os.getcwd()
parent_dir = os.path.dirname(notebook_path)
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# ``importnb.Notebook()`` lets ``import`` statements load ``.ipynb`` files as
# modules; presumably ``utils/tools`` is a notebook — TODO confirm.
with importnb.Notebook():
    from utils.tools import (
        AddPositionalEncoding,
        MLPHead,
        MultiHeadAttention,
        Patch,
        TransformerFFN,
    )
    from utils.tools import MLPHead

In [17]:
class TransformerEncoderLayer(nn.Module):
    """A single post-norm Transformer encoder layer.

    The input flows through two residual sub-layers:

    1. self-attention -> dropout -> add input -> LayerNorm
    2. feed-forward   -> dropout -> add input -> LayerNorm

    Args:
        d_model: Feature size of each token (input and output width).
        d_ff: Hidden width handed to ``TransformerFFN``.
        num_head: Number of heads handed to ``MultiHeadAttention``.
        dropout_rate: Dropout probability applied to each sub-layer's output.
        layer_norm_eps: ``eps`` used by both ``nn.LayerNorm`` modules.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_head: int,
        dropout_rate: float,
        layer_norm_eps: float,
    ) -> None:
        super().__init__()
        # Attribute names (and creation order) are preserved on purpose: they
        # define the state_dict keys and the order in which parameters are
        # initialised from the global RNG.

        # --- attention sub-layer ---
        self.mha = MultiHeadAttention(num_head, d_model)
        self.layernorm_mha = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout_mha = nn.Dropout(dropout_rate)

        # --- feed-forward sub-layer ---
        self.ffn = TransformerFFN(d_model, d_ff)
        self.dropout_ffn = nn.Dropout(dropout_rate)
        self.layernorm_ffn = nn.LayerNorm(d_model, eps=layer_norm_eps)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Run both residual sub-layers over ``x``.

        Args:
            x: Token sequence; used as query, key and value for attention.
            mask: Optional attention mask forwarded unchanged to
                ``MultiHeadAttention`` (``None`` means no mask).

        Returns:
            A tensor shaped like ``x`` after attention and feed-forward.
        """
        attended = self.__attention_sublayer(x, mask)
        after_attention = self.layernorm_mha(attended + x)

        transformed = self.__feedforward_sublayer(after_attention)
        return self.layernorm_ffn(transformed + after_attention)

    def __attention_sublayer(
        self, x: torch.Tensor, mask: torch.Tensor = None
    ) -> torch.Tensor:
        # Self-attention: the same tensor serves as query, key and value.
        return self.dropout_mha(self.mha(x, x, x, mask))

    def __feedforward_sublayer(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # Position-wise feed-forward network followed by dropout.
        return self.dropout_ffn(self.ffn(x))

In [18]:
class TransformerEncoder(nn.Module):
    """ViT-style encoder: patch embedding, [CLS] token, positional encoding, N layers.

    Args:
        d_model: Token width produced by the patch embedding.
        d_ff: Hidden width of each layer's feed-forward network.
        num_head: Number of attention heads per layer.
        dropout_rate: Dropout probability inside each encoder layer.
        layer_norm_eps: ``eps`` for the LayerNorms inside each encoder layer.
        patch_num: Number of patch tokens per sample; the positional encoding
            is sized ``patch_num + 1`` to also cover the [CLS] token.
        patch_dim: Length of one flattened patch (input width of the
            embedding ``nn.Linear``).
        N: Number of stacked ``TransformerEncoderLayer`` modules.
        device: Forwarded to ``AddPositionalEncoding``.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_head: int,
        dropout_rate: float,
        layer_norm_eps: float,
        # Additions compared with the plain Transformer encoder:
        patch_num: int,
        patch_dim: int,
        N: int,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        super().__init__()
        # Learnable [CLS] token of shape (1, 1, d_model), prepended to every
        # sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        # Linear projection of each flattened patch to d_model.
        self.embedding = nn.Linear(patch_dim, d_model)
        # Positional encoding for patch_num patches + 1 [CLS] token.
        self.pos = AddPositionalEncoding(d_model, patch_num + 1, device)
        # Stack of N identical encoder layers.
        self.encoder_layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_model, d_ff, num_head, dropout_rate, layer_norm_eps
                )
                for _ in range(N)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Encode a batch of flattened patches.

        Args:
            x: 3-D tensor (batch, num_patches, patch_dim). It must be 3-D to
                concatenate with the [CLS] token. Presumably
                num_patches == patch_num; TODO confirm against ``Patch``.
            mask: Optional attention mask passed to every encoder layer. It
                applies to the sequence *including* the prepended [CLS] token.

        Returns:
            Tensor of shape (batch, num_patches + 1, d_model). Index 0 along
            dim 1 is the [CLS] token.
        """
        batch_size = x.size(0)
        x = self.embedding(x)
        # expand() broadcasts the (1, 1, d_model) token as a view instead of
        # copying it like repeat(). torch.cat materialises the result anyway,
        # and gradients still accumulate into cls_token.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = self.pos(x)
        for layer in self.encoder_layers:
            x = layer(x, mask)
        return x

In [26]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier: patchify -> encoder -> MLP head -> softmax.

    Args:
        d_model: Token width inside the encoder.
        d_ff: Hidden width of each encoder layer's feed-forward network.
        num_head: Attention heads per encoder layer.
        patch_size: Side length handed to ``Patch``.
        patch_num: Patches per image (sizes the positional encoding).
        patch_dim: Length of one flattened patch (embedding input width).
        out_dim: Number of output classes produced by ``MLPHead``.
        N: Number of encoder layers.
        dropout_rate: Dropout probability inside the encoder layers.
        layer_norm_eps: ``eps`` for the encoder LayerNorms.
        device: Forwarded to the encoder's positional encoding.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_head: int,
        patch_size: int,
        patch_num: int,
        patch_dim: int,
        out_dim: int,
        N: int,
        dropout_rate: float = 0.1,
        layer_norm_eps: float = 1e-5,
        device: torch.device = torch.device("cpu"),
    ):
        super().__init__()
        self.patch_and_flatten = Patch(
            patch_size=patch_size,
        )
        self.encoder = TransformerEncoder(
            d_model=d_model,
            d_ff=d_ff,
            num_head=num_head,
            dropout_rate=dropout_rate,
            layer_norm_eps=layer_norm_eps,
            patch_num=patch_num,
            patch_dim=patch_dim,
            N=N,
            device=device,
        )
        self.mlp_head = MLPHead(d_model=d_model, out_dim=out_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
        """Classify a batch of images.

        Args:
            x: Image batch accepted by ``Patch``; presumably
                (batch, channels, height, width). TODO confirm.
            return_logits: If True, skip the final softmax and return the raw
                ``MLPHead`` output. Use this when training with
                ``nn.CrossEntropyLoss``, which expects unnormalised logits.
                Feeding it probabilities would apply softmax twice.

        Returns:
            Softmax over dim 1 of the head output (the default), or the raw
            logits when ``return_logits`` is True. Presumably
            (batch, out_dim); TODO confirm against ``MLPHead``.
        """
        patches = self.patch_and_flatten(x)
        encoded = self.encoder(patches)
        logits = self.mlp_head(encoded)
        if return_logits:
            return logits
        return self.softmax(logits)

In [27]:
# Seed the global RNG so this sample batch, and any model initialised after
# it, are reproducible under Restart & Run All.
torch.manual_seed(0)
# Dummy batch: 2 samples of shape (3, 12, 12), presumably
# (channels, height, width) RGB images.
x = torch.randn(size=(2, 3, 12, 12))

In [28]:
# Toy configuration matching the 12x12, 3-channel sample input.
# patch_num and patch_dim are derived rather than hard-coded, so they cannot
# drift out of sync with PATCH_SIZE. This assumes ``Patch`` cuts
# non-overlapping PATCH_SIZE x PATCH_SIZE tiles; TODO confirm.
IMAGE_SIZE = 12
IN_CHANNELS = 3
PATCH_SIZE = 4
NUM_CLASSES = 10

vit = VisionTransformer(
    d_model=512,
    d_ff=1024,
    num_head=8,
    patch_size=PATCH_SIZE,
    patch_num=(IMAGE_SIZE // PATCH_SIZE) ** 2,  # (12 // 4) ** 2 = 9
    patch_dim=IN_CHANNELS * PATCH_SIZE**2,  # 3 * 4 * 4 = 48
    out_dim=NUM_CLASSES,
    N=6,
)

In [29]:
# Smoke-test a forward pass. Bind the result to a new name instead of
# overwriting ``x``, so this cell can be re-run. no_grad avoids building an
# autograd graph that nothing uses. The last expression displays the shape.
with torch.no_grad():
    probs = vit(x)
probs.shape

torch.Size([2, 10])