In [22]:
from doctest import master

import torch

In [23]:
class Embedding(torch.nn.Module):
    """Sum of a learned symbol embedding and a learned absolute position embedding.

    Args:
        sym_len: Number of rows in the symbol table (valid ids are 0..sym_len-1).
        max_seq_len: Number of rows in the position table, i.e. the longest
            sequence `forward` can accept.
        emb_dim: Width of every embedding vector.
    """

    def __init__(self, sym_len, max_seq_len, emb_dim=512):
        super().__init__()

        self.sym_len = sym_len
        self.max_seq_len = max_seq_len

        self.sym_embedding = torch.nn.Embedding(num_embeddings=sym_len, embedding_dim=emb_dim)
        self.pos_embedding = torch.nn.Embedding(num_embeddings=max_seq_len, embedding_dim=emb_dim)

    def forward(self, x):
        """Embed a batch of symbol ids.

        Args:
            x: Integer tensor of shape (B, L) holding symbol ids.

        Returns:
            Tensor of shape (B, L, D): symbol embedding plus the embedding of
            each position 0..L-1 (broadcast over the batch).

        Raises:
            ValueError: If `x` is not 2-dimensional (from the shape unpacking).
            IndexError: If L is larger than `max_seq_len`.
        """
        _, L = x.shape

        # Fail early with a readable message: without this check the position
        # lookup below fails with an opaque "index out of range" on CPU, or a
        # device-side assert on CUDA.
        if L > self.max_seq_len:
            raise IndexError(
                f"sequence length {L} exceeds max_seq_len={self.max_seq_len}"
            )

        pos = torch.arange(L, device=x.device).unsqueeze(0) # (1, L)

        sym_emb = self.sym_embedding(x)   # (B, L, D)
        pos_emb = self.pos_embedding(pos) # (1, L, D)

        return sym_emb + pos_emb


Quick check

In [24]:
# Smoke test for Embedding: random symbol ids in, one vector per symbol out.
batch_size = 8
sym_len = 90
max_sam_len = 50  # sample sequence length; also used as the position-table size
emb_dim = 512

sample = torch.randint(sym_len, (batch_size, max_sam_len)).cpu()
print(sample.shape) # (8, 50)

embedding = Embedding(sym_len=sym_len, max_seq_len=max_sam_len, emb_dim=emb_dim).cpu()

x = embedding(sample)
print(x.shape) # Expected (8, 50, 512)

# Make the check self-verifying instead of relying on reading the printout.
assert x.shape == (batch_size, max_sam_len, emb_dim), x.shape

torch.Size([8, 50])
torch.Size([8, 50, 512])


In [25]:
class Encoder(torch.nn.Module):
    """Stack of pre-norm Transformer encoder layers closed by a LayerNorm.

    Args:
        emb_dim: Model width (`d_model`) of every layer and of the final norm.
        num_heads: Attention heads per layer.
        ff_dim: Hidden width of each layer's feed-forward sub-block.
        dropout: Dropout probability used inside each layer.
        num_layers: How many copies of the layer are stacked.
    """

    def __init__(self, emb_dim=512, num_heads=8, ff_dim=2048, dropout=0.1, num_layers=4):
        super().__init__()

        layer_options = dict(
            d_model=emb_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,  # tensors are (batch, seq, feature)
            norm_first=True,   # pre-norm layout, chosen for more stable training
        )

        # The nested-tensor path is switched off explicitly because the layers
        # are built with norm_first=True.
        self.encoder = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(**layer_options),
            num_layers,
            enable_nested_tensor=False,
        )

        self.final_norm = torch.nn.LayerNorm(emb_dim)

    def forward(self, x, sample_mask):
        """Encode `x` and normalise the result.

        Args:
            x: Input passed to the encoder stack; with batch_first=True this is
                laid out as (batch, seq, emb_dim).
            sample_mask: Forwarded as `src_key_padding_mask`.

        Returns:
            The encoder output after the final LayerNorm.
        """
        encoded = self.encoder(x, src_key_padding_mask=sample_mask)
        return self.final_norm(encoded)

Quick check

In [26]:
# Smoke test for Encoder on the embedded sample batch.
encoder = Encoder().cpu()

# Random boolean key-padding mask with the same (batch, seq) shape as `sample`.
sample_mask = torch.randint(2, (batch_size, max_sam_len), dtype=torch.bool).cpu()

# Recompute the encoder input from `embedding(sample)` instead of reading `x`:
# `x = encoder(x, ...)` would feed the previous encoder output back into the
# encoder every time this cell is re-run. `x` still ends up holding the
# encoder output for the cells below.
x = encoder(embedding(sample), sample_mask=sample_mask)
print(x.shape) # Expected (8, 50, 512)

assert x.shape == (batch_size, max_sam_len, 512), x.shape

torch.Size([8, 50, 512])


In [27]:
class Decoder(torch.nn.Module):
    """Stack of pre-norm Transformer decoder layers closed by a LayerNorm.

    Args:
        emb_dim: Model width (`d_model`) of every layer and of the final norm.
        num_heads: Attention heads per layer.
        ff_dim: Hidden width of each layer's feed-forward sub-block.
        dropout: Dropout probability used inside each layer.
        num_layers: How many copies of the layer are stacked.
    """

    def __init__(self, emb_dim=512, num_heads=8, ff_dim=2048, dropout=0.1, num_layers=4):
        super().__init__()

        layer = torch.nn.TransformerDecoderLayer(
            emb_dim,
            num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,  # tensors are (batch, seq, feature)
            norm_first=True,   # pre-norm layout
        )
        self.decoder = torch.nn.TransformerDecoder(layer, num_layers)

        self.final_norm = torch.nn.LayerNorm(emb_dim)

    def forward(self, decoder_input, encoder_output, name_mask, sample_mask):
        """Decode `decoder_input` against `encoder_output` under a causal mask.

        Args:
            decoder_input: Target-side input; its dimension 1 is taken as the
                target sequence length.
            encoder_output: Memory the decoder cross-attends to.
            name_mask: Forwarded as `tgt_key_padding_mask`.
            sample_mask: Forwarded as `memory_key_padding_mask`.

        Returns:
            The decoder output after the final LayerNorm.
        """
        steps = torch.arange(decoder_input.size(1), device=decoder_input.device)

        # Boolean (T, T) matrix, True strictly above the diagonal: entry (i, j)
        # is True when j > i, i.e. a later position than the querying one.
        future_positions = steps.unsqueeze(0) > steps.unsqueeze(1)

        decoded = self.decoder(
            tgt=decoder_input,
            memory=encoder_output,
            tgt_mask=future_positions,
            tgt_key_padding_mask=name_mask,
            memory_key_padding_mask=sample_mask,
        )

        return self.final_norm(decoded)

Quick check

In [28]:
# Smoke test for Decoder: embedded random name tokens attend to the encoder output `x`.
max_nam_len = 25

name_tokens = torch.randint(sym_len, (batch_size, max_nam_len)).cpu()
print(name_tokens.shape) # (8, 25)

decoder_embedding = Embedding(sym_len=sym_len, max_seq_len=max_nam_len)

decoder_input = decoder_embedding(name_tokens)

decoder = Decoder().cpu()

# Right-padding mask built from random lengths in [1, max_nam_len]: True marks
# the padded tail, and position 0 is never masked. A uniformly random mask would
# hide position 0 in about half the rows; combined with the causal mask that
# leaves the first query with no visible key, which makes self-attention
# ill-defined (NaN on some PyTorch versions) while the printed shape still
# looks correct.
name_lengths = torch.randint(1, max_nam_len + 1, (batch_size, 1))
name_mask = (torch.arange(max_nam_len).unsqueeze(0) >= name_lengths).cpu()

print("####################################################")
print(x.shape)             # (8, 50, 512)
print(sample_mask.shape)   # (8, 50)
print(decoder_input.shape) # (8, 25, 512)
print(name_mask.shape)     # (8, 25)

y = decoder(
    decoder_input=decoder_input,
    encoder_output=x,
    name_mask=name_mask,
    sample_mask=sample_mask,
)

print(y.shape)

# Verify the result instead of only eyeballing the printed shape.
assert y.shape == (batch_size, max_nam_len, 512), y.shape
assert torch.isfinite(y).all(), "decoder produced non-finite values"

torch.Size([8, 25])
####################################################
torch.Size([8, 50, 512])
torch.Size([8, 50])
torch.Size([8, 25, 512])
torch.Size([8, 25])
torch.Size([8, 25, 512])
