We need different heads for different outputs:
- **Sequence** (_CE Loss_; the only head that uses the decoder output)
    - name
- **Classification** (_CE Loss_)
    - unit
    - tax category
- **Regression** (_MSE Loss_)
    - amount
    - quantity
    - price
    - total price

In [1]:
import torch

In [2]:
class SeqHead(torch.nn.Module):
    """Sequence head: a single linear projection of the decoder output.

    Args:
        sym_len: Output width of the projection (number of logits produced
            per input vector).
        emb_dim: Size of the last dimension of the incoming tensor.
    """

    def __init__(self, sym_len: int, emb_dim: int = 512) -> None:
        super().__init__()
        # Attribute name `linear` is kept so state_dict keys stay the same.
        self.linear = torch.nn.Linear(emb_dim, sym_len)

    def forward(self, decoder_output: torch.Tensor) -> torch.Tensor:
        """Project ``decoder_output`` (last dim ``emb_dim``) to ``sym_len`` logits.

        The projection acts on the last dimension only; all leading
        dimensions are passed through unchanged.
        """
        return self.linear(decoder_output)

In [None]:
class CatHead(torch.nn.Module):
    """Classification head: linear layer applied to the first position.

    Args:
        cat_len: Number of logits produced (one per category).
        emb_dim: Size of the last dimension of the incoming tensor.
    """

    def __init__(self, cat_len: int, emb_dim: int = 512) -> None:
        super().__init__()
        # Attribute name `linear` is kept so state_dict keys stay the same.
        self.linear = torch.nn.Linear(emb_dim, cat_len)

    def forward(self, encoder_output: torch.Tensor) -> torch.Tensor:
        """Return ``cat_len`` logits computed from index 0 along dim 1.

        NOTE(review): assumes ``encoder_output`` is laid out as
        (batch, seq_len, emb_dim) - confirm against the encoder.
        """
        # Only the slice at position 0 of the second dimension is used;
        # every other position is ignored by this head.
        return self.linear(encoder_output[:, 0, :])

In [None]:
class RegHead(torch.nn.Module):
    """Regression head: two-layer MLP applied to the first position.

    The MLP is ``Linear(emb_dim, emb_dim) -> ReLU -> Linear(emb_dim, 1)``.

    Args:
        emb_dim: Size of the last dimension of the incoming tensor; also
            the hidden width of the MLP.
    """

    def __init__(self, emb_dim: int = 512) -> None:
        super().__init__()
        # Layers are created in hidden -> output order and stored in a
        # Sequential so parameter names remain mlp.0.* and mlp.2.*.
        hidden_layer = torch.nn.Linear(emb_dim, emb_dim)
        output_layer = torch.nn.Linear(emb_dim, 1)
        self.mlp = torch.nn.Sequential(hidden_layer, torch.nn.ReLU(), output_layer)

    def forward(self, encoder_output: torch.Tensor) -> torch.Tensor:
        """Return one scalar prediction per row, with a trailing dim of size 1.

        NOTE(review): assumes ``encoder_output`` is laid out as
        (batch, seq_len, emb_dim) - confirm against the encoder.
        """
        # Only the slice at position 0 of the second dimension feeds the MLP.
        return self.mlp(encoder_output[:, 0, :])

In [3]:
class MultiHead(torch.nn.Module):
    """Bundle of all output heads, evaluated together in one forward pass.

    Holds one ``SeqHead`` (name), two ``CatHead`` instances (unit, tax) and
    four ``RegHead`` instances (amount, quantity, price, total).

    Args:
        sym_len: Output width of the name (sequence) head.
        unit_cat_len: Number of logits of the unit head.
        tax_cat_len: Number of logits of the tax head.
        emb_dim: Embedding width shared by every head.
    """

    def __init__(self, sym_len: int, unit_cat_len: int, tax_cat_len: int,
                 emb_dim: int = 512) -> None:
        super().__init__()

        # Sequence head (fed with the decoder output in forward()).
        self.name_head = SeqHead(sym_len=sym_len, emb_dim=emb_dim)

        # Classification heads.
        self.unit_head = CatHead(cat_len=unit_cat_len, emb_dim=emb_dim)
        self.tax_head = CatHead(cat_len=tax_cat_len, emb_dim=emb_dim)

        # Regression heads.
        self.amount_head = RegHead(emb_dim=emb_dim)
        self.quantity_head = RegHead(emb_dim=emb_dim)
        self.price_head = RegHead(emb_dim=emb_dim)
        self.total_head = RegHead(emb_dim=emb_dim)

    def forward(self, encoder_output: torch.Tensor,
                decoder_output: torch.Tensor) -> dict:
        """Run every head and collect the results in a dict.

        Args:
            encoder_output: Input passed to the classification and
                regression heads.
            decoder_output: Input passed to the name (sequence) head only.

        Returns:
            Dict with keys ``name_logits``, ``unit_logits``, ``tax_logits``,
            ``amount_pred``, ``quantity_pred``, ``price_pred`` and
            ``total_pred``, in that order.
        """
        # The name head is the only consumer of the decoder output.
        outputs = {"name_logits": self.name_head(decoder_output)}

        # All remaining heads read the encoder output; the tuple order
        # fixes both evaluation order and dict key order.
        encoder_heads = (
            ("unit_logits", self.unit_head),
            ("tax_logits", self.tax_head),
            ("amount_pred", self.amount_head),
            ("quantity_pred", self.quantity_head),
            ("price_pred", self.price_head),
            ("total_pred", self.total_head),
        )
        for key, head in encoder_heads:
            outputs[key] = head(encoder_output)

        return outputs