In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, IterableDataset

# from model import SwipeCurveEncoderTransformer

In [192]:
from typing import List, Optional

class CharLevelTokenizerv1:
    def __init__(self, vocab_path):
        self.char_to_idx = {}
        self.idx_to_char = {}
        self.max_word_len = None  # is set in _build_vocab
        self._build_vocab(vocab_path)

    def _build_vocab(self, vocab_path):
        self.max_word_len = 0
        unique_chars = set({"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3})
        with open(vocab_path, "r", encoding="utf-8") as f:
            vocab = f.read().split("\n")
            for word in vocab:
                self.max_word_len = max(self.max_word_len, len(word) + 2)
                for char in word:
                    unique_chars.add(char)
        self.char_to_idx = {char: idx for idx, char in enumerate(unique_chars)}
        self.idx_to_char = {idx: char for idx, char in enumerate(unique_chars)}

    def _tokenize_word(self, word):
        """
        Tokenizes a word into a list of integers.
        """
        tokenized_word = []
        tokenized_word.append(self.char_to_idx["<sos>"])
        for char in word:
            tokenized_word.append(self.char_to_idx[char])
        tokenized_word.append(self.char_to_idx["<eos>"])
        return tokenized_word
    
    def _pad_word(self, word):
        """
        Pads a word to the max_word_len.
        """
        return word + [self.char_to_idx["<pad>"]] * (self.max_word_len - len(word))
    
    def tokenize(self, word):
        """
        Tokenizes a word and pads it to the max_word_len.
        """
        token_seq = torch.tensor(self._pad_word(self._tokenize_word(word)))
        mask = torch.zeros(self.max_word_len, dtype=torch.bool)
        mask[:len(word)+2] = True
        return token_seq, mask


Я вижу два решения:

Для простоты я бы сделал 2 класса датасета
Если нужно кодировать лишь последовательность букв, он и хранит последовательности букв сразу и не хранит коордианты

В обоих случаях декодер оперирует эмбеддингами букв текста

### 1. На вход энкодера x, y, t, dx/dt, dy/st, x'', y'', keybard_key_embedding
**Что делать, если ближайшая клавиша неалфавитная (пунктуация, клавиши-действия)?**
Добавлю для всех неалфавитных клавиш один специальный токен

**Где происходит инициализация токенизатора?**
я бы вынес токенезатор вне датасета и передавал бы его в конструктор датасета.


для каждой раскладки свои instance'ы датасета и модели.



### 2. На вход энкодера последовательность клавиш клавиатуры
Если ближайшая клавиша неалфавитная **пропускать**

**Где происходит инициализация токенизатора?**


один instance датасета и одна модель для всех раскладок.




Реализовывать ли для каждого варианта отдельный токенизатор:

У нас может быть различное количество токенов: в некотоорых раскладках отсутствует символ "ъ", например

Когда датасет содержит лишь одну раскладку, токенизатор должен учесть символы из одной раскладки. Когда датасет содержит несколько раскладок, токенизатор должен учесть символы из всех раскладок.

Кажется, что варьируется только наличие 'ъ' и 'ё'. Во-первых, не ясно нужны ли эти символы. Есть желание заменять 'ё' на 'е', а 'ъ' на 'ь'. 

In [193]:
class KeyboardTokenizerv1:
    
    i2t = ['а', 'б', 'в', 'г', 'д', 'е', 'ë', 'ж', 'з', 'и', 'й',
           'к', 'л', 'м', 'н', 'о', 'п', 'р', 'с', 'т', 'у', 'ф',
           'х', 'ц', 'ч', 'ш', 'щ', 'ъ', 'ы', 'ь', 'э', 'ю', 'я',
           '-', '<unk>', '<pad>']
    
    t2i = {t: i for i, t in enumerate(i2t)}

    def get_token(self, char):
        return self.t2i.get(char, self.t2i['<unk>'])

In [194]:
import json
from typing import Optional, List, Tuple, Dict
import array

from torch.utils.data import Dataset
from tqdm import tqdm


class NeuroSwipeDatasetv1(Dataset):
    """
    Dataset class for NeuroSwipe dataset.
    The dataset file weights over 3 GB and contains over 6 million swipe gestures.
    """

    def __init__(self,
                 data_path: str,
                 kb_keys: List[dict],  # keybard_keys
                 kb_tokenizer,
                 max_traj_len: int,
                 word_tokenizer,  # should contain max word len
                 include_velocities: bool = True,
                 include_accelerations: bool = True,
                 total: Optional[int] = None):
        """
        Args:
            data_path (string): Path to the NeuroSwipe dataset in JSON format.
                A custom version of the dataset is used:
                "grid" property is replaced with "grid_name" property.
        """
        if include_accelerations and not include_velocities:

            raise ValueError("Accelerations are supposed \
                             to be an addition to velocities. Add velocities.")

        self.max_traj_len = max_traj_len
        self.include_velocities = include_velocities
        self.include_accelerations = include_accelerations

        self.word_tokenizer = word_tokenizer

        self.data_list = []
        self._set_data(data_path, kb_keys, kb_tokenizer, self.data_list, total = total)
    

    def _get_key_center(self, hitbox: Dict[str, int]) -> Tuple[int, int]:
        x = hitbox['x'] + hitbox['w'] / 2
        y = hitbox['y'] + hitbox['h'] / 2
        return x, y

    def _coord_to_kb_label(self, x: int, y:int, keys: List[dict]) -> str:
        nearest_kb_label = None
        min_dist = float("inf")
        for key in keys:
            key_x, key_y = self._get_key_center(key['hitbox'])
            dist = (x - key_x)**2 + (y - key_y)**2
            if dist < min_dist:
                min_dist = dist
                if 'label' in key:
                    nearest_kb_label = key['label']
                elif 'action' in key:
                    nearest_kb_label = key['action']  # tokenizer will covert it to <unk>
                else:
                    raise ValueError("Key has no label or action")

        return nearest_kb_label
            

    def _set_data(self,
                  data_path: str,
                  kb_keys: str,
                  kb_tokenizer,
                  data_list: list,
                  total: Optional[int] = None):
        with open(data_path, "r", encoding="utf-8") as json_file:
            for line in tqdm(json_file, total = total):
                data_list.append(self._get_data_from_json_line(line, kb_keys, kb_tokenizer))


    def _get_dx_dt(self,
                   X: torch.tensor,
                   T: torch.tensor,
                   len: int) -> List[float]:
        """
        Calculates dx/dt for a list of x coordinates and a list of t coordinates.

        Arguments:
        ----------
        X : torch.tensor
            x (position) coordinates.
        T : torch.tensor
            T[i] = time from the beginning of the swipe corresponding to X[i].
        len : int
            Length of the swipe trajectory. Indexes greater than len are ignored.

        """
        dx_dt = torch.zeros_like(X)
        # dx_dt[1:-1] = (X[2:] - X[:-2]) / (T[2:] - T[:-2])
        dx_dt[1:len-1] = (X[2:len] - X[:len-2]) / (T[2:len] - T[:len-2])

        # Example:
        # x0 x1 x2 x3
        # t0 t1 t2 t3
        # dx_dt[0] = 0
        # dx_dt[1] = (x2 - x0) / (t2 - t0)
        # dx_dt[2] = (x3 - x1) / (t3 - t1)
        # dx_dt[3] = 0


        # if True in torch.isnan(dx_dt):
        #     print(dx_dt)
        #     raise ValueError("dx_dt contains NaNs")

        return dx_dt

    def _get_data_from_json_line(self, line, kb_keys, kb_tokenizer) -> Tuple[list, list, list, str]:
        """
        Parses a JSON line and returns a dictionary with data.
        """
        data = json.loads(line)
        word: str = data['word']

        X = array.array('h', data['curve']['x'])
        Y = array.array('h', data['curve']['y'])
        T = array.array('h', data['curve']['t'])        

        kb_labels = [self._coord_to_kb_label(x, y, kb_keys) for x,y in zip(X, Y)]
        kb_tokens = [kb_tokenizer.get_token(label) for label in kb_labels]
        kb_tokens += [kb_tokenizer.get_token('<pad>')] * (self.max_traj_len - len(kb_labels))
        kb_tokens = array.array('h', kb_tokens)

        return X, Y, T, word, kb_tokens

    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        X_list, Y_list, T_list, word, kb_tokens = self.data_list[idx]

        X = torch.zeros(self.max_traj_len, dtype=torch.float32)
        Y = torch.zeros(self.max_traj_len, dtype=torch.float32)
        T = torch.zeros(self.max_traj_len, dtype=torch.float32)
        
        X[:len(X_list)] = torch.tensor(X_list, dtype=torch.float32)
        Y[:len(Y_list)] = torch.tensor(Y_list, dtype=torch.float32)
        T[:len(T_list)] = torch.tensor(T_list, dtype=torch.float32)

        xyt = torch.cat(
            [
                X.reshape(-1, 1),
                Y.reshape(-1, 1),
                T.reshape(-1, 1)
            ],
            axis = 1
        )

        traj_len = len(X_list)

        if self.include_velocities:
            dx_dt = self._get_dx_dt(X, T, traj_len)
            dy_dt = self._get_dx_dt(Y, T, traj_len)
            xyt = torch.cat(
                [
                    xyt,
                    dx_dt.reshape(-1, 1),
                    dy_dt.reshape(-1, 1)
                ],
                axis = 1
            )

        if self.include_accelerations:
            d2x_dt2 = self._get_dx_dt(dx_dt, T, traj_len)
            d2y_dt2 = self._get_dx_dt(dy_dt, T, traj_len)
            xyt = torch.cat(
                [
                    xyt,
                    d2x_dt2.reshape(-1, 1),
                    d2y_dt2.reshape(-1, 1)
                ],
                axis = 1
            )

        traj_pad_mask = torch.zeros(self.max_traj_len, dtype=torch.bool)
        traj_pad_mask[:len(X_list)] = True

        char_seq, word_mask = self.word_tokenizer.tokenize(word)

        kb_tokens = torch.tensor(kb_tokens, dtype=torch.int64)
    
        return xyt, kb_tokens, traj_pad_mask, char_seq, word_mask

In [180]:
def get_kb_keys(grid_name: str, grids_path: str) -> dict:
    with open(grids_path, "r", encoding="utf-8") as f:
        grids = json.load(f)
        grid = grids[grid_name]
        return grid['keys']

In [207]:
sample_data = r"..\data\data_separated_grid\sample_deleteme__default_only.jsonl"
grid_path =  r"..\data\data_separated_grid\gridname_to_grid.json"
grid_name = "default"

kb_keys = get_kb_keys(grid_name, grid_path)
kb_tokenizer = KeyboardTokenizerv1()
word_tokenizer = CharLevelTokenizerv1("../data/data_separated_grid/voc.txt")


dataset = NeuroSwipeDatasetv1(
    data_path = sample_data,
    kb_keys = kb_keys,
    kb_tokenizer = kb_tokenizer,
    max_traj_len = 299,
    word_tokenizer = word_tokenizer,
    include_velocities = True,
    include_accelerations = True,
    total = 1000
)

100%|██████████| 1000/1000 [00:03<00:00, 280.27it/s]


In [209]:
i = 40
xyt, kb_tokens, traj_pad_mask, char_seq, word_mask = dataset[i]
print(xyt.shape, kb_tokens.shape, traj_pad_mask.shape, char_seq.shape, word_mask.shape)

torch.Size([299, 7]) torch.Size([299]) torch.Size([299]) torch.Size([36]) torch.Size([36])


In [175]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwipeCurveTransformerEncoderv1(nn.Module):
    """
    Transformer-based Curve encoder takes in a sequence of vectors and creates a representation
    of a swipe gesture on a samrtphone keyboard.
    Each vector contains information about finger trajectory at a time step.
    It contains:
    * x coordinate
    * y coordinate
    * Optionally: t
    * Optionally: dx/dt
    * Optionally: dy/dt
    * Optionally: keyboard key that has x and y coordinates within its boundaries
    """

    def __init__(self, input_size, d_model,
                 dim_feedforward, num_layers, num_heads_first, num_heads_other,
                 dropout = 0.1):
        """
        Arguments:
        ----------
        input_size: int
            Size of input vectors.
        d_model: int
            Size of the embeddings (output vectors).
            Should be equal to char embedding size of the decoder.
        dim_feedforward: int
        num_layers: int
            Number of encoder layers including the first layer.

        """
        super().__init__()
        # self.pos_encoder = PositionalEncoding(input_size, dropout)
        self.first_encoder_layer = nn.TransformerEncoderLayer(
            input_size, num_heads_first, dim_feedforward, dropout)
        self.liner = nn.Linear(input_size, d_model)  # to convert embedding to d_model size
        num_layer_after_first = num_layers - 1
        if num_layer_after_first > 0:
            encoder_layer = nn.TransformerEncoderLayer(d_model, num_heads_other, dim_feedforward, dropout)
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        else:
            self.transformer_encoder = None
    

    def forward(self, x, pad_mask: torch.tensor):
        # x = self.pos_encoder(x)
        x = self.first_encoder_layer(x, src_key_padding_mask=pad_mask)
        x = self.liner(x)
        if self.transformer_encoder:
            x = self.transformer_encoder(x, src_key_padding_mask=pad_mask)
        return x



class SwipeCurveTransformerDecoderv1(nn.Module):
    """
    Decodes a swipe gesture representation into a sequence of characters.

    Uses decoder transformer with masked attention to prevent the model from cheating.
    """

    def __init__(self, char_emb_size, nhead, num_decoder_layers,
                 dim_feedforward, dropout, activation = F.relu):
        super().__init__()

        self.decoder_layer = nn.TransformerDecoderLayer(char_emb_size, nhead, dim_feedforward, dropout, activation)
        self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_decoder_layers)
        self.out = nn.Linear(char_emb_size, char_emb_size)
        self.softmax = nn.LogSoftmax(dim=2)
    
    def forward(self, x, memory, tgt_mask, memory_key_padding_mask, tgt_key_padding_mask):
        x = self.transformer_decoder(x,
                                     memory,
                                     tgt_mask=tgt_mask,
                                     memory_key_padding_mask=memory_key_padding_mask,
                                     tgt_key_padding_mask=tgt_key_padding_mask)
        x = self.out(x)
        # x = self.softmax(x)
        return x


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len, dropout: float = 0.0):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class SwipeCurveTransformer(nn.Module):
    """
    SwipeCurveTransformer is a sequence-to-sequence model that encodes a sequence of vectors
    representing a swipe gesture into a sequence of characters.
    """

    def _get_mask(self, max_seq_len: int):
        """
        Returns a mask for the decoder transformer.
        """
        mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def __init__(self,
                 input_size,
                 char_emb_size,
                 char_vocab_size,
                 num_encoder_layers,
                 num_decoder_layers,
                 dim_feedforward,
                 num_heads,
                 dropout,
                 max_out_seq_len,
                 activation = F.relu):
        super().__init__()

        curv_emb_size = char_emb_size
        
        self.char_embedding = nn.Embedding(char_vocab_size, char_emb_size)

        self.encoder = SwipeCurveTransformerEncoderv1(
            input_size, curv_emb_size, dim_feedforward, num_encoder_layers, num_heads, dropout)
        self.pos_encoder = PositionalEncoding(char_emb_size, max_out_seq_len)
        self.decoder = SwipeCurveTransformerDecoderv1(
            char_emb_size, num_heads, num_decoder_layers, dim_feedforward, dropout, activation)
        self.out = nn.Linear(char_emb_size, char_emb_size)
        self.softmax = nn.LogSoftmax(dim=2)

        self.mask = self._get_mask(max_out_seq_len)

    def forward(self, x, y, x_pad_mask, y_pad_mask):
        x = self.encoder(x, x_pad_mask)
        y = self.char_embedding(y)
        y = self.pos_encoder(y)
        y = self.decoder(y, x, self.mask, x_pad_mask, y_pad_mask)
        y = self.out(y)
        y = self.softmax(y)
        return y

In [106]:
# def get_mask(max_seq_len: int):
#     """
#     Returns a mask for the decoder transformer.
#     """
#     mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1)
#     mask = mask.masked_fill(mask == 1, float('-inf'))
#     return mask

# mask = get_mask(5)
# print(mask)

# >>>
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])

In [128]:
seq_len = 32
batch_size = 10
in_features = 40



encoder = SwipeCurveTransformerEncoderv1(
    input_size=in_features,
    d_model=128,
    dim_feedforward=128,
    num_layers=1,
    num_heads_first=2,
    num_heads_other=4,
    dropout=0.1)



pad_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
pad_mask[:10, :] = True

# print(pad_mask)

encoded = encoder(torch.rand(seq_len, batch_size, in_features), pad_mask)

encoded.transpose_(0,1)



tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        ...,

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]

In [108]:
decoder = SwipeCurveTransformerDecoderv1(char_emb_size=128, nhead=1, num_decoder_layers=1, dim_feedforward=128, dropout=0.1)

seq_len = 32
batch_size = 10
char_emb_size = 128

def get_mask(max_seq_len: int):
    """
    Returns a mask for the decoder transformer.
    """
    mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask

target_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)

decoder(
    torch.rand(batch_size, seq_len, char_emb_size),
    torch.rand(batch_size, seq_len, char_emb_size),
    tgt_mask=torch.rand(seq_len, seq_len).masked_fill(torch.rand(seq_len, seq_len) > 0.5, float('-inf'))).shape

TypeError: SwipeCurveTransformerDecoderv1.forward() missing 2 required positional arguments: 'memory_key_padding_mask' and 'tgt_key_padding_mask'

In [None]:
decoder = SwipeCurveTransformerDecoderv1()

In [184]:
word_char_tokenizer = CharLevelTokenizerv1("../data/data_separated_grid/voc.txt")

In [185]:
word_char_tokenizer.tokenize("троллейбус")

(tensor([ 6, 32, 11, 34, 10, 10,  9, 31, 29, 26,  5, 33, 14, 14, 14, 14, 14, 14,
         14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14]),
 tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False]))

In [186]:
print(word_char_tokenizer.char_to_idx.keys())

dict_keys(['э', '-', 'г', 'к', 'ф', 'с', '<sos>', 'ь', 'ш', 'е', 'л', 'р', 'ы', '<unk>', '<pad>', 'з', 'ж', 'ц', 'н', 'а', 'щ', 'ю', 'в', 'п', 'и', 'х', 'у', 'ъ', 'ч', 'б', 'м', 'й', 'т', '<eos>', 'о', 'д', 'я'])


In [187]:
len(word_char_tokenizer.tokenize('информационно-телекоммуникационной')[0])

36

In [188]:
len(word_char_tokenizer.tokenize('информационно')[0])

36

In [213]:
transformer = SwipeCurveTransformer(
    input_size=7,
    char_emb_size=128,
    char_vocab_size=len(word_char_tokenizer.char_to_idx),
    num_encoder_layers=1,
    num_decoder_layers=2,
    dim_feedforward=128,
    num_heads=1,
    dropout=0.1,
    max_out_seq_len=36)

In [216]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10, shuffle=True)

for xyt, kb_tokens, traj_pad_mask, char_seq, word_pad_mask in loader:
    # (batch_size, seq_len, n_point_feats) to (seq_len, batch_size, n_point_feats)
    xyt = torch.transpose(xyt, 0, 1)
    char_seq = torch.transpose(char_seq, 0, 1)

    # print(xyt.shape)
    # print(traj_pad_mask.shape)
    # print(char_seq.shape)
    # print(word_pad_mask.shape)
    # print()

    char_seq_pred = transformer(xyt, char_seq, traj_pad_mask, word_pad_mask)
    break



In [217]:
char_seq_pred.transpose(0,1)[0].shape
char_seq_pred

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        ...,

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]

In [66]:
num_heads = 2
num_decoder_layers = 1
dim_feedforward = 128
dropout = 0.1
activation = F.relu
max_out_seq_len = 32

char_emb_size = 128
seq_len = 32
batch_size = 10

decoder = SwipeCurveTransformerDecoderv1(
    char_emb_size, num_heads, num_decoder_layers, dim_feedforward, dropout, activation)

tgt_mask = torch.triu(torch.ones(max_out_seq_len, max_out_seq_len), diagonal=1)
tgt_mask = tgt_mask.masked_fill(tgt_mask == 1, float('-inf'))

memory_pad_mask = torch.zeros(max_out_seq_len, dtype=torch.bool)
memory_pad_mask[:10] = True
memory_pad_mask = memory_pad_mask.unsqueeze(0).repeat(batch_size, 1)

tgt_pad_mask = torch.zeros(max_out_seq_len, dtype=torch.bool)
tgt_pad_mask[:10] = True
tgt_pad_mask = tgt_pad_mask.unsqueeze(0).repeat(batch_size, 1)

decoder(
    torch.rand(seq_len, batch_size, char_emb_size),
    torch.rand(seq_len, batch_size, char_emb_size),
    tgt_mask=tgt_mask,
    memory_key_padding_mask=memory_pad_mask,
    tgt_key_padding_mask=tgt_pad_mask
).shape



torch.Size([32, 10, 128])