In [4]:
%run 5_model-modules.ipynb
%run 6_output-heads.ipynb

torch.Size([8, 50])
torch.Size([8, 50, 512])
torch.Size([8, 50, 512])
torch.Size([8, 25])
####################################################
torch.Size([8, 50, 512])
torch.Size([8, 50])
torch.Size([8, 25, 512])
torch.Size([8, 25])
torch.Size([8, 25, 512])


In [5]:
import torch

In [6]:
class Model(torch.nn.Module):
    """Encoder-decoder model mapping a sample token sequence to a name token sequence.

    Encoder and decoder inputs each go through their own ``Embedding``; the
    ``Encoder``/``Decoder`` stacks process them and a ``MultiHead`` module turns
    both stack outputs into the final outputs returned by ``forward``.

    Args:
        sym_len: vocabulary size, shared by the encoder and decoder embeddings
            and passed to ``MultiHead``.
        max_sam_len: maximum sample (encoder input) length; ``max_seq_len`` of
            the encoder embedding.
        max_nam_len: maximum name (decoder input) length; ``max_seq_len`` of
            the decoder embedding. Also caps ``generate_name``'s ``max_len``.
        unit_cat_len: passed to ``MultiHead`` (presumably the number of unit
            categories — confirm against MultiHead).
        tax_cat_len: passed to ``MultiHead`` (presumably the number of taxonomy
            categories — confirm against MultiHead).
        emb_dim: embedding/model width used by every sub-module.
        num_heads: passed to both ``Encoder`` and ``Decoder``.
        ff_dim: passed to both ``Encoder`` and ``Decoder``.
        dropout: passed to both ``Encoder`` and ``Decoder``.
        num_layers: passed to both ``Encoder`` and ``Decoder``.
    """

    def __init__(
        self,
        sym_len,     # vocabulary size
        max_sam_len, # max size of sample (input)
        max_nam_len, # max size of name (output)
        unit_cat_len,
        tax_cat_len,
        emb_dim=512,
        num_heads=8,
        ff_dim=2048,
        dropout=0.1,
        num_layers=4
    ):
        super().__init__()

        # Kept so generate_name can avoid feeding the decoder embedding a
        # prefix longer than it was built for.
        self.max_nam_len = max_nam_len

        self.encoder_embedding = Embedding(
            sym_len=sym_len,
            max_seq_len=max_sam_len,
            emb_dim=emb_dim
        )

        self.decoder_embedding = Embedding(
            sym_len=sym_len,
            max_seq_len=max_nam_len,
            emb_dim=emb_dim
        )

        self.encoder = Encoder(
            emb_dim=emb_dim,
            num_heads=num_heads,
            ff_dim=ff_dim,
            dropout=dropout,
            num_layers=num_layers
        )

        self.decoder = Decoder(
            emb_dim=emb_dim,
            num_heads=num_heads,
            ff_dim=ff_dim,
            dropout=dropout,
            num_layers=num_layers
        )

        self.multihead = MultiHead(
            sym_len=sym_len,
            unit_cat_len=unit_cat_len,
            tax_cat_len=tax_cat_len,
            emb_dim=emb_dim
        )

    def forward(
            self,
            encoder_tokens,
            decoder_tokens,
            encoder_mask,
            decoder_mask
    ):
        """Run the full model (teacher forcing on the decoder side).

        Args:
            encoder_tokens: sample token ids fed to the encoder embedding.
            decoder_tokens: name token ids fed to the decoder embedding.
            encoder_mask: mask for the encoder; also given to the decoder as
                ``sample_mask`` for cross-attention.
            decoder_mask: given to the decoder as ``name_mask``.

        Returns:
            Whatever ``MultiHead`` returns (dict-like: ``generate_name`` reads
            its ``"name_logits"`` entry).
        """
        encoder_embedding = self.encoder_embedding(encoder_tokens)
        decoder_embedding = self.decoder_embedding(decoder_tokens)

        encoder_output = self.encoder(encoder_embedding, encoder_mask)

        decoder_output = self.decoder(
            decoder_input=decoder_embedding,
            encoder_output=encoder_output,
            name_mask=decoder_mask,
            sample_mask=encoder_mask
        )

        outputs = self.multihead(
            encoder_output=encoder_output,
            decoder_output=decoder_output,
        )

        return outputs

    @torch.no_grad()
    def generate_name(
        self,
        encoder_output,
        sos_id,
        eos_id,
        max_len=50,
        device=None,
        encoder_mask=None,
    ):
        """Greedily decode a name for every item in the batch.

        Args:
            encoder_output: encoder stack output; its first two dims are used
                as (batch, enc_len), and it is presumably (B, enc_len, emb_dim).
            sos_id: token id that seeds every sequence.
            eos_id: token id that ends a sequence. Once a row has emitted it,
                every later position of that row is filled with ``eos_id``.
            max_len: maximum number of tokens generated after ``<SOS>``. It is
                capped at ``max_nam_len``, the ``max_seq_len`` the decoder
                embedding was built with.
            device: device for the generated tensors; defaults to
                ``encoder_output.device``.
            encoder_mask: optional mask for ``encoder_output``, passed to the
                decoder as ``sample_mask`` (as ``forward`` does). Defaults to
                an all-False mask, i.e. no encoder position masked.

        Returns:
            Long tensor of shape (B, 1 + n_generated) starting with ``sos_id``.
            Decoding stops early once every row has emitted ``eos_id``.
        """
        if device is None:
            device = encoder_output.device

        # The decoder embedding was built with max_seq_len=max_nam_len and the
        # longest prefix embedded below has `max_len` tokens; presumably longer
        # prefixes cannot be embedded, so never exceed it. getattr keeps
        # instances created before this attribute existed working.
        max_nam_len = getattr(self, "max_nam_len", None)
        if max_nam_len is not None:
            max_len = min(max_len, max_nam_len)

        batch_size = encoder_output.shape[0]

        # Encoder mask does not change between steps: build it once.
        if encoder_mask is None:
            encoder_mask = torch.zeros(
                encoder_output.shape[:2], dtype=torch.bool, device=device
            )
        else:
            encoder_mask = encoder_mask.to(device)

        # Start every row with <SOS>.
        generated = torch.full((batch_size, 1), sos_id, dtype=torch.long, device=device)
        # Rows that have already produced <EOS>.
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # Disable dropout during decoding so greedy output is deterministic;
        # restore the caller's mode afterwards.
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_len):
                dec_inp = self.decoder_embedding(generated)

                # All-False name mask: no decoder position is masked.
                # NOTE(review): if training passed a causal mask as
                # `decoder_mask` and Decoder does not apply one internally,
                # this differs from training — confirm.
                name_mask = torch.zeros(dec_inp.shape[:2], dtype=torch.bool, device=device)

                dec_out = self.decoder(
                    decoder_input=dec_inp,
                    encoder_output=encoder_output,
                    name_mask=name_mask,
                    sample_mask=encoder_mask,
                )

                # head -> logits; only the last position predicts the next token.
                out = self.multihead(encoder_output=encoder_output, decoder_output=dec_out)
                next_token = out["name_logits"][:, -1, :].argmax(dim=-1)  # (B,)

                # Finished rows keep emitting EOS so the batch can stop as soon
                # as every row has ended, even if they ended on different steps.
                next_token = next_token.masked_fill(finished, eos_id)
                generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)

                finished |= next_token == eos_id
                if finished.all():
                    break
        finally:
            self.train(was_training)

        return generated