# In [5]:
import math
import torch
import torch.nn as nn
from torch import Tensor

class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal positional encodings to a sequence-first tensor, then apply dropout.

  The encoding table is precomputed once for ``max_len`` positions and stored as the
  buffer ``pe`` with shape (max_len, 1, d_model), so it broadcasts over the batch axis.
  """

  def __init__(self, d_model : int, dropout : float = 0.1, max_len : int = 5000) -> None:
    super().__init__()

    self.dropout = nn.Dropout(p = dropout)
    pe = self.make_pe(d_model = d_model, max_len = max_len)
    # Buffer (not a Parameter): moves with the module across devices but is never trained.
    self.register_buffer("pe", pe)

  @staticmethod
  def make_pe(d_model : int, max_len : int) -> torch.Tensor:
    """Return the sinusoidal table of shape (max_len, 1, d_model).

    Even feature columns hold sin(pos * w_k) and odd columns hold cos(pos * w_k), with
    w_k = 10000 ** (-2k / d_model). Odd ``d_model`` is supported.
    """
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    # With odd d_model there is one fewer odd column than even column, so trim the frequencies.
    pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
    return pe.unsqueeze(1)  # (max_len, 1, d_model)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Add the first ``x.size(0)`` encodings to ``x`` and apply dropout.

    x : (S, B, d_model), sequence-first. S must not exceed ``max_len``.
    returns : (S, B, d_model)
    """
    assert x.shape[2] == self.pe.shape[2]  # type: ignore
    x = x + self.pe[: x.size(0)]  # type: ignore
    return self.dropout(x)

def generate_square_subsequent_mask(size : int) -> torch.Tensor:
  """Build a (size, size) float32 causal attention mask.

  Entry [i, j] is 0.0 when j <= i (position i may attend to position j) and -inf when
  j > i (future positions are blocked).
  """
  blocked = torch.full((size, size), float("-inf"), dtype=torch.float)
  # Keep -inf strictly above the main diagonal; triu zeroes the diagonal and everything below it.
  return torch.triu(blocked, diagonal=1)

# In [10]:
from typing import Any, Dict, Union, Tuple
import argparse
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# Common type hints: a single int, or a (height, width) pair, as passed to nn.Conv2d
# for kernel_size / stride / padding.
Param2D = Union[int, Tuple[int, int]]

# LineCNN defaults, used when the matching argparse option is not supplied.
CONV_DIM: int = 32  # channels after the first conv block
FC_DIM: int = 512  # per-window feature size produced by the last conv and used by the FC layers
WINDOW_WIDTH: int = 16  # sliding-window width, in input pixels
WINDOW_STRIDE: int = 8  # sliding-window stride, in input pixels


class ConvBlock(nn.Module):
    """
    One Conv2d layer followed by a ReLU activation.

    With the default kernel_size=3, stride=1, padding=1 the spatial size of the input is
    preserved; other values change it according to the usual Conv2d output-size arithmetic.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        kernel_size: Param2D = 3,
        stride: Param2D = 1,
        padding: Param2D = 1,
    ) -> None:
        super().__init__()
        # Submodule names `conv` and `relu` are kept so state_dict keys stay stable.
        self.conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x
            of dimensions (B, input_channels, H, W)
        Returns
        -------
        torch.Tensor
            of dimensions (B, output_channels, H', W'), with H' == H and W' == W under the defaults
        """
        return self.relu(self.conv(x))


class LineCNN(nn.Module):
    """
    Model that uses a simple CNN to process an image of a line of characters with a window, outputs a sequence of logits.

    Three stride-2 conv blocks downsample the image by 8x. A final conv with kernel
    (H // 8, window_width // 8) and stride (H // 8, window_stride // 8) then spans the remaining
    height and slides along the width, giving one feature vector per window position.
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.args = vars(args) if args is not None else {}
        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]

        _C, H, _W = data_config["input_dims"]
        conv_dim = self.args.get("conv_dim", CONV_DIM)
        fc_dim = self.args.get("fc_dim", FC_DIM)
        self.WW = self.args.get("window_width", WINDOW_WIDTH)
        self.WS = self.args.get("window_stride", WINDOW_STRIDE)
        self.limit_output_length = self.args.get("limit_output_length", False)

        # Input is (1, H, W)
        self.convs = nn.Sequential(
            ConvBlock(1, conv_dim),
            ConvBlock(conv_dim, conv_dim),
            ConvBlock(conv_dim, conv_dim, stride=2),
            ConvBlock(conv_dim, conv_dim),
            ConvBlock(conv_dim, conv_dim * 2, stride=2),
            ConvBlock(conv_dim * 2, conv_dim * 2),
            ConvBlock(conv_dim * 2, conv_dim * 4, stride=2),
            ConvBlock(conv_dim * 4, conv_dim * 4),
            # Window conv: assumes H is a multiple of 8 so the output height is 1 (forward squeezes dim 2).
            ConvBlock(
                conv_dim * 4, fc_dim, kernel_size=(H // 8, self.WW // 8), stride=(H // 8, self.WS // 8), padding=0
            ),
        )
        self.fc1 = nn.Linear(fc_dim, fc_dim)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(fc_dim, self.num_classes)

        self._init_weights()

    def _init_weights(self):
        """
        Initialize weights in a better way than default.

        Conv/linear weights get Kaiming-normal init (fan_out mode, ReLU gain). Biases are drawn
        uniformly from [-1/sqrt(fan_out), 1/sqrt(fan_out)].
        See https://github.com/pytorch/pytorch/issues/18182
        """
        for m in self.modules():
            if type(m) in {
                nn.Conv2d,
                nn.Conv3d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
                nn.Linear,
            }:
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(  # pylint: disable=protected-access
                        m.weight.data
                    )
                    bound = 1 / math.sqrt(fan_out)
                    # uniform_(tensor, a, b) samples from U(a, b). normal_ would read (-bound, bound)
                    # as (mean, std) and give every bias a negative offset.
                    nn.init.uniform_(m.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x
            (B, 1, H, W) input image
        Returns
        -------
        torch.Tensor
            (B, C, S) logits, where S is the number of window positions and C is the number of classes
            S is determined by W, self.WW (window width) and self.WS (window stride)
            C is self.num_classes
            If self.limit_output_length is set, S is truncated to self.output_length.
        """
        _B, _C, _H, _W = x.shape
        x = self.convs(x)  # (B, fc_dim, 1, S)
        x = x.squeeze(2).permute(0, 2, 1)  # (B, S, fc_dim)
        x = F.relu(self.fc1(x))  # -> (B, S, fc_dim)
        x = self.dropout(x)
        x = self.fc2(x)  # (B, S, C)
        x = x.permute(0, 2, 1)  # -> (B, C, S)
        if self.limit_output_length:
            x = x[:, :, : self.output_length]
        return x

    @staticmethod
    def add_to_argparse(parser):
        """Register LineCNN's command-line options on ``parser`` and return it."""
        parser.add_argument("--conv_dim", type=int, default=CONV_DIM)
        parser.add_argument("--fc_dim", type=int, default=FC_DIM)
        parser.add_argument(
            "--window_width",
            type=int,
            default=WINDOW_WIDTH,
            help="Width of the window that will slide over the input image.",
        )
        parser.add_argument(
            "--window_stride",
            type=int,
            default=WINDOW_STRIDE,
            help="Stride of the window that will slide over the input image.",
        )
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser

# In [13]:
# LineCNNTransformer defaults, used when the matching --tf_* option is not supplied.
TF_DIM: int = 256  # decoder d_model; also the number of LineCNN output features
TF_FC_DIM: int = 256  # dim_feedforward of each decoder layer
TF_DROPOUT: float = 0.4  # dropout inside each decoder layer
TF_LAYERS: int = 4  # number of stacked decoder layers
TF_NHEAD: int = 4  # attention heads per decoder layer

class LineCNNTransformer(nn.Module):
  """Line-image-to-text model: a LineCNN encoder followed by a Transformer decoder.

  The LineCNN emits one ``dim``-sized feature vector per window. These, plus positional
  encodings, form the decoder memory. Decoding runs over token indices into
  ``data_config["mapping"]``, which must contain the "<S>", "<E>" and "<P>" tokens.
  """

  def __init__(self, data_config:Dict[str, Any], args : argparse.Namespace = None,) ->None:
    super().__init__()
    self.data_config = data_config
    self.input_dims = data_config["input_dims"]
    self.num_classes = len(data_config["mapping"])
    inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
    self.start_token = inverse_mapping["<S>"]
    self.end_token = inverse_mapping["<E>"]
    self.padding_token = inverse_mapping["<P>"]
    self.max_output_length = data_config["output_dims"][0]
    self.args = vars(args) if args is not None else {}

    self.dim = self.args.get("tf_dim", TF_DIM)
    tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
    tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
    tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
    tf_layers = self.args.get("tf_layers", TF_LAYERS)

    # Reuse LineCNN as the encoder by giving it `dim` "classes", so it outputs
    # (B, dim, Sx) feature vectors instead of character logits.
    data_config_for_line_cnn = {**data_config}
    data_config_for_line_cnn["mapping"] = list(range(self.dim))
    self.line_cnn = LineCNN(data_config=data_config_for_line_cnn, args=args)

    self.embedding = nn.Embedding(self.num_classes, self.dim)
    self.fc = nn.Linear(self.dim, self.num_classes)

    self.pos_encoder = PositionalEncoding(d_model=self.dim)

    # Causal decoder mask. A non-persistent buffer follows the module across .to(device)
    # but is not added to the state_dict, so existing checkpoints still load.
    self.register_buffer(
        "y_mask", generate_square_subsequent_mask(self.max_output_length), persistent=False
    )

    self.transformer_decoder = nn.TransformerDecoder(
        nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
        num_layers=tf_layers,
    )

    self.init_weights()  # This is empirically important

  def init_weights(self):
    """Re-initialize the token embedding and output projection: weights ~ U(-0.1, 0.1), bias = 0."""
    initrange = 0.1
    self.embedding.weight.data.uniform_(-initrange, initrange)
    self.fc.bias.data.zero_()
    self.fc.weight.data.uniform_(-initrange, initrange)

  def encode(self, x : torch.Tensor) -> torch.Tensor:
    """
    x : (B, 1, H, W) image batch (the input format LineCNN.forward expects)
    returns : (Sx, B, E) positionally-encoded memory, with E == self.dim
    """
    x = self.line_cnn(x)  # (B, E, Sx) image features
    x = x * math.sqrt(self.dim)
    x = x.permute(2, 0, 1)  # (Sx, B, E)
    x = self.pos_encoder(x)
    return x

  def decode(self, x, y):
    """
    x : (Sx, B, E) encoder memory, as returned by encode()
    y : (B, Sy) token indices; positions equal to the padding token are masked as keys
    returns : (Sy, B, C) logits, where C is num_classes
    """
    y_padding_mask = y == self.padding_token
    y = y.permute(1, 0)  # (Sy, B)
    y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
    y = self.pos_encoder(y)
    Sy = y.shape[0]
    y_mask = self.y_mask[:Sy, :Sy].type_as(x)
    output = self.transformer_decoder(
        tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
    )  # (Sy, B, E)
    output = self.fc(output)  # (Sy, B, C)
    return output

  def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Parameters
    ----------
    x
        (B, 1, H, W) image
    y
        (B, Sy) with elements in [0, C-1] where C is num_classes
    Returns
    -------
    torch.Tensor
        (B, C, Sy) logits
    """
    x = self.encode(x)  # (Sx, B, E)
    output = self.decode(x, y)  # (Sy, B, C)
    return output.permute(1, 2, 0)  # (B, C, Sy)

  def predict(self, x: torch.Tensor) -> torch.Tensor:
    """Greedy autoregressive decoding.

    Parameters
    ----------
    x
        (B, 1, H, W) image
    Returns
    -------
    torch.Tensor
        (B, S) long tensor with elements in [0, C-1] where C is num_classes and
        S is self.max_output_length. Position 0 is the start token, and every position after
        the first end/padding token is the padding token.
    """
    B = x.shape[0]
    S = self.max_output_length  # maximum output sequence length
    x = self.encode(x)  # (Sx, B, E)

    # Every slot starts as padding; slot 0 holds the start token.
    output_tokens = torch.full((B, S), self.padding_token, dtype=torch.long, device=x.device)
    output_tokens[:, 0] = self.start_token
    finished = torch.zeros(B, dtype=torch.bool, device=x.device)
    for Sy in range(1, S):
      y = output_tokens[:, :Sy]  # (B, Sy)
      output = self.decode(x, y)  # (Sy, B, C)
      next_tokens = torch.argmax(output[-1], dim=-1)  # (B,) prediction for position Sy
      output_tokens[:, Sy] = next_tokens
      finished |= (next_tokens == self.end_token) | (next_tokens == self.padding_token)
      if bool(finished.all()):
        # Every row has emitted end/padding. The pass below pads everything after that point,
        # and the untouched slots already hold padding, so further decoding cannot change the result.
        break

    # Set all tokens after end token to be padding
    for Sy in range(1, S):
      ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
      output_tokens[ind, Sy] = self.padding_token

    return output_tokens  # (B, S)

  @staticmethod
  def add_to_argparse(parser):
    """Register LineCNN's and this model's --tf_* command-line options on ``parser``."""
    LineCNN.add_to_argparse(parser)
    parser.add_argument("--tf_dim", type=int, default=TF_DIM)
    parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM)
    parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT)
    parser.add_argument("--tf_layers", type=int, default=TF_LAYERS)
    parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD)
    return parser
