In [6]:
from typing import Any, Dict, Union, Tuple
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

# Default hyper-parameters for LineCNN; each can be overridden through `args`.

# Channel count of the first convolutional blocks (later blocks use 2x / 4x).
CONV_DIM = 32
# Output channels of the final windowing convolution and width of the FC layer.
FC_DIM = 512
# Width of the window that slides over the input image.
WINDOW_WIDTH = 16
# Stride of the window that slides over the input image.
WINDOW_STRIDE = 8

class ConvBlock(nn.Module):
  """A single 2-D convolution followed by a ReLU non-linearity.

  With the default kernel_size=3, stride=1 and padding=1 the spatial size of
  the input is preserved; other values are forwarded unchanged to nn.Conv2d.
  """

  def __init__(self, input_channels : int,
               output_channels : int, kernel_size = 3, stride = 1, padding = 1,):
    super().__init__()
    # Gather the geometry options once so the Conv2d call stays short.
    conv_options = {
        "kernel_size": kernel_size,
        "stride": stride,
        "padding": padding,
    }
    self.conv = nn.Conv2d(input_channels, output_channels, **conv_options)
    self.relu = nn.ReLU()

  def forward(self, x : torch.Tensor) -> torch.Tensor:
    """
    Apply the convolution, then ReLU.

    x : (batch, input_channels, height, width)
    return : (batch, output_channels, height', width'), where height'/width'
             follow from the kernel size, stride and padding of the conv.
    """
    return self.relu(self.conv(x))

class LineCNN(nn.Module):
  """Convolutional model that produces per-window class scores for a line image.

  A stack of ConvBlocks (three of them with stride 2) is followed by a final
  convolution whose kernel is (H // 8, window_width // 8) and whose stride is
  (H // 8, window_stride // 8); this acts as a window sliding along the width
  of the down-sampled feature map. Each window position is then passed through
  fc1 -> ReLU -> dropout -> fc2 to obtain class scores.

  data_config must provide:
    "mapping"     : sized collection; its length is the number of classes
    "input_dims"  : (C, H, W) of the input images (only H is used here)
    "output_dims" : sequence whose first entry is the maximum output length

  args (optional argparse.Namespace) may override conv_dim, fc_dim,
  window_width, window_stride and limit_output_length; see add_to_argparse.
  """

  def __init__(self, data_config : Dict[str, Any], args : argparse.Namespace = None, ):
    super().__init__()
    self.data_config = data_config
    self.args = vars(args) if args is not None else {}
    self.num_classes = len(data_config["mapping"])
    self.output_length = data_config["output_dims"][0]

    _C, H, _W = data_config["input_dims"]
    conv_dim = self.args.get("conv_dim", CONV_DIM)
    fc_dim = self.args.get("fc_dim", FC_DIM)
    # The keys must match the argparse destinations created in add_to_argparse
    # ("--window_width" -> "window_width"). Keys containing spaces never match,
    # which silently discarded any value given on the command line.
    self.WW = self.args.get("window_width", WINDOW_WIDTH)
    self.WS = self.args.get("window_stride", WINDOW_STRIDE)
    self.limit_output_length = self.args.get("limit_output_length", False)

    self.convs = nn.Sequential(
        ConvBlock(1, conv_dim),
        ConvBlock(conv_dim, conv_dim),
        ConvBlock(conv_dim, conv_dim, stride=2),
        ConvBlock(conv_dim, conv_dim),
        ConvBlock(conv_dim, conv_dim * 2, stride=2),
        ConvBlock(conv_dim * 2, conv_dim * 2),
        ConvBlock(conv_dim * 2, conv_dim * 2),
        ConvBlock(conv_dim * 2, conv_dim * 4, stride=2),
        ConvBlock(conv_dim * 4, conv_dim * 4),
        # Sliding window: the kernel spans H // 8 rows of the feature map and
        # WW // 8 columns, moving WS // 8 columns at a time.
        ConvBlock(
            conv_dim * 4,
            fc_dim,
            kernel_size=(H // 8, self.WW // 8),
            stride=(H // 8, self.WS // 8),
            padding=0,
        ),
    )
    self.fc1 = nn.Linear(fc_dim, fc_dim)
    self.dropout = nn.Dropout(0.2)
    self.fc2 = nn.Linear(fc_dim, self.num_classes)

  def forward(self, x : torch.Tensor):
    """
    x : (batch, 1, h, w)
    return : (batch, classes, sequence len)

    If limit_output_length is set, the sequence axis is truncated to at most
    self.output_length entries.
    """
    x = self.convs(x)  # expected (B, fc_dim, 1, S): the last conv should collapse the height to 1
    x = x.squeeze(2).permute(0, 2, 1)  # (B, S, fc_dim)
    # fc1 was constructed in __init__ but never applied; run the hidden layer
    # before dropout so its parameters actually take part in the model.
    x = F.relu(self.fc1(x))  # (B, S, fc_dim)
    x = self.dropout(x)
    x = self.fc2(x)  # (B, S, classes)
    x = x.permute(0, 2, 1)  # (B, classes, S)
    if self.limit_output_length:
      x = x[:, :, : self.output_length]
    return x

  @staticmethod
  def add_to_argparse(parser):
    """Register LineCNN's command-line options on `parser` and return it."""
    parser.add_argument("--conv_dim", type=int, default=CONV_DIM)
    parser.add_argument("--fc_dim", type=int, default=FC_DIM)
    parser.add_argument(
        "--window_width",
        type=int,
        default=WINDOW_WIDTH,
        help="Width of the window that will slide over the input image.",
    )
    parser.add_argument(
        "--window_stride",
        type=int,
        default=WINDOW_STRIDE,
        help="Stride of the window that will slide over the input image.",
    )
    parser.add_argument("--limit_output_length", action="store_true", default=False)
    return parser

  

In [7]:
# Default hyper-parameters for LineCNNLSTM; each can be overridden through `args`.

# Hidden size of the LSTM (per direction).
LSTM_DIM = 512
# Number of stacked LSTM layers.
LSTM_LAYERS = 1
# Dropout value handed to nn.LSTM.
LSTM_DROPOUT = 0.2

class LineCNNLSTM(nn.Module):
  """LineCNN followed by a bidirectional LSTM and a linear classifier.

  The per-window class scores produced by LineCNN are treated as a sequence,
  run through a bidirectional LSTM, the two directions are summed, and a
  linear layer maps the result back to class scores.

  data_config and args are forwarded to LineCNN; args may additionally carry
  lstm_dim, lstm_layers and lstm_dropout (see add_to_argparse).
  """

  def __init__(self, data_config : Dict[str, Any], args : argparse.Namespace=None,):
    super().__init__()
    self.data_config = data_config
    self.args = vars(args) if args is not None else {}
    num_classes = len(data_config["mapping"])
    lstm_dim = self.args.get("lstm_dim", LSTM_DIM)
    lstm_layers = self.args.get("lstm_layers", LSTM_LAYERS)
    lstm_dropout = self.args.get("lstm_dropout", LSTM_DROPOUT)

    self.line_cnn = LineCNN(data_config=data_config, args=args)  # output : (B, C, S)

    self.lstm = nn.LSTM(
        input_size=num_classes,
        hidden_size=lstm_dim,
        num_layers=lstm_layers,
        # nn.LSTM only applies dropout between stacked layers; with a single
        # layer a non-zero value has no effect and merely triggers a warning.
        dropout=lstm_dropout if lstm_layers > 1 else 0.0,
        bidirectional=True,
    )
    self.fc = nn.Linear(lstm_dim, num_classes)

  def forward(self, x):
    """
    x : (B, 1, H, W) image batch, as expected by LineCNN
    output : (B, C, S) class scores per sequence position
    """
    x = self.line_cnn(x)  # (B, C, S)
    B, _C, S = x.shape
    # permute returns a new view rather than modifying x in place, so the
    # result has to be assigned; otherwise the LSTM sees (B, C, S).
    x = x.permute(2, 0, 1)  # (S, B, C)

    x, _ = self.lstm(x)  # (S, B, 2 * lstm_dim)
    # Split the last axis into (direction, lstm_dim) and add both directions.
    x = x.view(S, B, 2, -1).sum(dim=2)  # (S, B, lstm_dim)
    x = self.fc(x)  # (S, B, C)
    return x.permute(1, 2, 0)  # (B, C, S)

  @staticmethod
  def add_to_argparse(parser):
    """Register LineCNN's and the LSTM's command-line options on `parser`."""
    LineCNN.add_to_argparse(parser)
    parser.add_argument("--lstm_dim", type=int, default=LSTM_DIM)
    parser.add_argument("--lstm_layers", type=int, default=LSTM_LAYERS)
    parser.add_argument("--lstm_dropout", type=float, default=LSTM_DROPOUT)
    return parser
