In [None]:
from __future__ import annotations

from collections import defaultdict
from pathlib import Path

import torch
from torch import Tensor, nn, optim
from tqdm import tqdm

from compilation import Assembly, C, Compiler
from compression_filter import CompressionFilter
from datapoint import DataPoint, Label
from dataset_iterator import DatasetIterator

# model definition

In [None]:
class Sunbird(nn.Module):
    """Sequence-to-sequence transformer over paired (type, text) token ids.

    Every source and target token is described by two ids: a *type* id and a
    *text* id. Both are embedded, the two embeddings are concatenated and
    projected back to ``d_model``, a learned position embedding is added, and
    the result is fed to an ``nn.Transformer`` built with ``batch_first=True``.
    Two linear heads map the decoder output to type logits and text logits.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        src_type_vocab_size: int,
        src_text_vocab_size: int,
        trg_type_vocab_size: int,
        trg_text_vocab_size: int,
        src_pad_idx: int,
        max_datapoint_len: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1,
    ) -> None:
        """Create the embedding tables, the transformer and the output heads.

        Args:
            src_type_vocab_size: Number of rows in the source type embedding.
            src_text_vocab_size: Number of rows in the source text embedding.
            trg_type_vocab_size: Number of rows in the target type embedding
                and width of the type logits.
            trg_text_vocab_size: Number of rows in the target text embedding
                and width of the text logits.
            src_pad_idx: Id in ``src_type_ids`` that marks padding; such
                positions are hidden from the encoder's self-attention.
            max_datapoint_len: Number of rows in both position embedding
                tables, i.e. the longest sequence that can be embedded.
            d_model: Width of all embeddings and of the transformer.
            num_heads: Attention heads per transformer layer.
            num_layers: Layer count, used for both encoder and decoder.
            d_ff: Width of the transformer feed-forward sublayers.
            dropout: Dropout probability for the embeddings and transformer.
        """
        super().__init__()

        self.input_type_embedding = nn.Embedding(src_type_vocab_size, d_model)
        self.input_text_embedding = nn.Embedding(src_text_vocab_size, d_model)
        self.input_position_embedding = nn.Embedding(max_datapoint_len, d_model)

        self.output_type_embedding = nn.Embedding(trg_type_vocab_size, d_model)
        self.output_text_embedding = nn.Embedding(trg_text_vocab_size, d_model)
        self.output_position_embedding = nn.Embedding(max_datapoint_len, d_model)

        self.src_pad_idx = src_pad_idx
        self.dropout = nn.Dropout(dropout)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
        )

        # Project the concatenated (type, text) embeddings back to d_model.
        self.fc_in_src = nn.Linear(2 * d_model, d_model)
        self.fc_in_trg = nn.Linear(2 * d_model, d_model)

        self.fc_type_out = nn.Linear(d_model, trg_type_vocab_size)
        self.fc_text_out = nn.Linear(d_model, trg_text_vocab_size)

    @property
    def device(self) -> torch.device:
        """Device the parameters currently live on.

        Derived from the parameters instead of being captured from a global at
        construction time, so it stays correct after ``.to(...)``.
        """
        return next(self.parameters()).device

    def forward(
        self,
        src_type_ids: Tensor,
        src_text_ids: Tensor,
        trg_type_ids: Tensor,
        trg_text_ids: Tensor,
    ) -> tuple[Tensor, Tensor]:
        """Run the encoder-decoder and return per-position logits.

        Args:
            src_type_ids: ``(batch, src_len)`` source type ids.
            src_text_ids: ``(batch, src_len)`` source text ids.
            trg_type_ids: ``(batch, trg_len)`` target type ids.
            trg_text_ids: ``(batch, trg_len)`` target text ids.

        Returns:
            ``(type_logits, text_logits)`` of shape
            ``(batch, trg_len, trg_type_vocab_size)`` and
            ``(batch, trg_len, trg_text_vocab_size)``.

        Both lengths must fit in the position tables (``max_datapoint_len``).
        """
        device = src_type_ids.device
        src_len = src_type_ids.shape[1]
        # Size the target side from the target itself so that source and
        # target sequences are allowed to differ in length.
        trg_len = trg_type_ids.shape[1]

        src_type_embedded = self.dropout(self.input_type_embedding(src_type_ids))
        src_text_embedded = self.dropout(self.input_text_embedding(src_text_ids))
        src_positions = self.input_position_embedding(
            torch.arange(src_len, device=device),
        )
        src_embedded = self.dropout(
            self.fc_in_src(torch.cat((src_type_embedded, src_text_embedded), dim=-1))
            + src_positions,
        )

        trg_type_embedded = self.dropout(self.output_type_embedding(trg_type_ids))
        trg_text_embedded = self.dropout(self.output_text_embedding(trg_text_ids))
        trg_positions = self.output_position_embedding(
            torch.arange(trg_len, device=device),
        )
        trg_embedded = self.dropout(
            self.fc_in_trg(torch.cat((trg_type_embedded, trg_text_embedded), dim=-1))
            + trg_positions,
        )

        # (batch, src_len) boolean mask, True where the source token is padding.
        src_padding_mask = src_type_ids == self.src_pad_idx
        # Causal mask; moved explicitly so it always sits on the input device.
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_len).to(device)

        # NOTE(review): no memory_key_padding_mask / tgt_key_padding_mask is
        # passed, so cross-attention can still look at padded source
        # positions. Adding them changes the outputs of already trained
        # checkpoints - decide deliberately.
        transformer_output = self.transformer(
            src_embedded,
            trg_embedded,
            src_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )
        return (
            self.fc_type_out(transformer_output),
            self.fc_text_out(transformer_output),
        )

# generation

In [None]:
def generate(  # noqa: PLR0913
    model: Sunbird,
    asm_code: str,
    vocab: Vocab,
    max_generation_length: int,
    max_inp_len: int,
    device: torch.device,
) -> tuple[list[str], list[str]]:
    """Sample a target (type, text) token sequence for one assembly snippet.

    The assembly is wrapped in a ``DataPoint``, tokenised through
    ``CompressionFilter`` and mapped to ids with ``vocab``. Starting from a
    single ``bos`` token, one type id and one text id are sampled per step
    from the softmax of the model's logits, until either stream produces
    ``eos`` or the step budget is used up.

    Args:
        model: Model to sample from; it is switched to eval mode.
        asm_code: Assembly source to translate.
        vocab: Provides the source/target id tables and special indices.
        max_generation_length: Upper bound on the number of sampled tokens.
        max_inp_len: Length every id sequence is padded or truncated to.
        device: Device on which the input tensors are created.

    Returns:
        ``(types, texts)``: the generated sequences as strings, including
        ``"<bos>"`` (and ``"<eos>"`` if it was sampled) with ``"<pad>"``
        entries removed. Ids missing from the target tables map to
        ``"<unk>"``.
    """
    dp = DataPoint(Label(0), C(""))
    dp.asm.append(Assembly(asm_code, Compiler(), 0))
    token_pairs = CompressionFilter(dp).all_tokens()[0][0]

    # NOTE(review): if the vocab tables grow on lookup (as the defaultdicts in
    # Vocab do), tokens unseen during training receive fresh ids here, which
    # may fall outside the model's embedding tables - confirm this is intended.
    type_ids = [vocab.src_type[typ] for typ, _ in token_pairs]
    text_ids = [vocab.src_text[txt] for _, txt in token_pairs]

    def pad_or_truncate(vec: list[int]) -> list[int]:
        """Return ``vec`` right-padded or cut to exactly ``max_inp_len`` ids."""
        if len(vec) < max_inp_len:
            return vec + [vocab.src_pad_idx] * (max_inp_len - len(vec))
        return vec[:max_inp_len]

    def as_batch(vec: list[int]) -> Tensor:
        """Turn a list of ids into a ``(1, max_inp_len)`` tensor on ``device``."""
        return torch.tensor(pad_or_truncate(vec), device=device).unsqueeze(0)

    type_tensor = as_batch(type_ids)
    text_tensor = as_batch(text_ids)

    # Output buffers: bos at position 0, padding everywhere else.
    generated_types = as_batch([vocab.bos_idx])
    generated_texts = as_batch([vocab.bos_idx])

    # Position 0 holds bos, so at most max_inp_len - 1 tokens fit in the
    # buffers; never write past their end.
    num_steps = min(max_generation_length, max_inp_len - 1)

    model.eval()

    with torch.no_grad():
        for position in range(1, num_steps + 1):
            o_type, o_text = model(
                type_tensor,
                text_tensor,
                generated_types,
                generated_texts,
            )

            # The decoder is causally masked, so the logits at `position - 1`
            # depend only on the tokens generated so far. The last index of
            # the output belongs to a padding slot, not to the newest token.
            next_type_probs = nn.functional.softmax(o_type[:, position - 1, :], dim=-1)
            next_text_probs = nn.functional.softmax(o_text[:, position - 1, :], dim=-1)

            next_type = torch.multinomial(next_type_probs, num_samples=1).item()
            next_text = torch.multinomial(next_text_probs, num_samples=1).item()

            generated_types[0, position] = next_type
            generated_texts[0, position] = next_text

            if vocab.eos_idx in (next_type, next_text):
                break

    inv_type = {v: k for k, v in vocab.trg_type.items()}
    inv_text = {v: k for k, v in vocab.trg_text.items()}

    # The special ids are not stored in the target tables; add them here.
    for inverse in (inv_type, inv_text):
        inverse[vocab.src_pad_idx] = "<pad>"
        inverse[vocab.bos_idx] = "<bos>"
        inverse[vocab.eos_idx] = "<eos>"

    str_types = [inv_type.get(x, "<unk>") for x in generated_types[0].tolist()]
    str_texts = [inv_text.get(x, "<unk>") for x in generated_texts[0].tolist()]

    return (
        [x for x in str_types if x != "<pad>"],
        [x for x in str_texts if x != "<pad>"],
    )

# vocab

In [None]:
class Vocab:
    """Lazily growing token-to-id tables for the source and target sides.

    Looking up an unseen key in any of the four tables stores and returns the
    next free id. Source ids start at 1 (0 is padding); target ids start at 3
    (0 is padding, 1 is bos, 2 is eos).
    """

    def __init__(self) -> None:
        self.src_pad_idx = 0
        self.bos_idx = 1
        self.eos_idx = 2

        def next_free_id(table_name: str, first_id: int):
            """Build a factory returning the next unused id of one table."""
            return lambda: len(getattr(self, table_name)) + first_id

        # source side: only [0] is reserved (pad)
        self.src_type: dict[str, int] = defaultdict(next_free_id("src_type", 1))
        self.src_text: dict[str, int] = defaultdict(next_free_id("src_text", 1))

        # target side: [0] pad, [1] bos and [2] eos are reserved
        self.trg_type: dict[str, int] = defaultdict(next_free_id("trg_type", 3))
        self.trg_text: dict[str, int] = defaultdict(next_free_id("trg_text", 3))

# batch processing setup

In [None]:
class BatchBuilder:
    """Turns datapoints from a dataset CSV into padded id tensors.

    Datapoints are pulled from a ``DatasetIterator`` in chunks, expanded into
    (asm tokens, C tokens) pairs by ``CompressionFilter`` and buffered in
    ``self.compiled``. ``next_batch`` consumes that buffer from the front.
    """

    def __init__(  # noqa: PLR0913
        self,
        dataset_csv_path: str,
        batch_size: int,
        max_seq_len: int,
        vocab: Vocab,
        device: torch.device,
        compile_chunk_size: int = 1024,
    ) -> None:
        """Open the dataset and remember the batching parameters.

        Args:
            dataset_csv_path: Path handed to ``DatasetIterator``.
            batch_size: Number of sequences per batch.
            max_seq_len: Length every id sequence is padded/truncated to.
            vocab: Id tables and special indices used for encoding.
            device: Device on which the batch tensors are created.
            compile_chunk_size: Datapoints requested per buffer refill.
        """
        self.iterator = DatasetIterator(dataset_csv_path)
        # NOTE(review): this take(1) runs on the same iterator that next_batch
        # reads from; whether it consumes the first datapoint depends on
        # DatasetIterator.take - confirm.
        self.num_asm_per_c = len(
            CompressionFilter(self.iterator.take(1)[0]).all_tokens(),
        )
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.vocab = vocab
        self.device = device
        self.compile_chunk_size = compile_chunk_size

        # Buffer of (asm_pairs, c_pairs); each pair list holds (type, text).
        self.compiled: list[tuple[list[tuple[str, str]], list[tuple[str, str]]]] = []

    def __pad_or_truncate(self, vec: list[int]) -> list[int]:
        """Return ``vec`` right-padded or cut to exactly ``self.max_seq_len``."""
        if len(vec) < self.max_seq_len:
            # Use the instance's own length, not a module-level global.
            return vec + [self.vocab.src_pad_idx] * (self.max_seq_len - len(vec))
        return vec[: self.max_seq_len]

    def next_batch(self) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return the next batch of source/target type and text ids.

        Returns:
            ``(src_types, src_texts, trg_types, trg_texts)``, each a
            ``(batch_size, max_seq_len)`` tensor on ``self.device``. Target
            sequences are wrapped in bos/eos before padding or truncation.

        Raises:
            IndexError: If the buffer cannot be filled with ``batch_size``
                pairs because the iterator yields no further token pairs.
        """
        # Keep refilling until a full batch is buffered; a single chunk may
        # produce fewer pairs than batch_size.
        while len(self.compiled) < self.batch_size:
            fetched = [
                pair
                for dp in self.iterator.take(self.compile_chunk_size)
                for pair in CompressionFilter(dp).all_tokens()
            ]
            if not fetched:
                break
            self.compiled.extend(fetched)

        src_types: list[list[int]] = []
        src_texts: list[list[int]] = []
        trg_types: list[list[int]] = []
        trg_texts: list[list[int]] = []

        for _ in range(self.batch_size):
            asm_pairs, c_pairs = self.compiled.pop(0)

            src_types.append(
                self.__pad_or_truncate(
                    [self.vocab.src_type[pair[0]] for pair in asm_pairs],
                ),
            )
            src_texts.append(
                self.__pad_or_truncate(
                    [self.vocab.src_text[pair[1]] for pair in asm_pairs],
                ),
            )
            trg_types.append(
                self.__pad_or_truncate(
                    [self.vocab.bos_idx]
                    + [self.vocab.trg_type[pair[0]] for pair in c_pairs]
                    + [self.vocab.eos_idx],
                ),
            )
            trg_texts.append(
                self.__pad_or_truncate(
                    [self.vocab.bos_idx]
                    + [self.vocab.trg_text[pair[1]] for pair in c_pairs]
                    + [self.vocab.eos_idx],
                ),
            )

        return (
            torch.tensor(src_types, device=self.device),
            torch.tensor(src_texts, device=self.device),
            torch.tensor(trg_types, device=self.device),
            torch.tensor(trg_texts, device=self.device),
        )

# training

In [None]:
def training_loop(  # noqa: PLR0913
    *,
    model: nn.Module,
    num_epochs: int,
    src_pad_idx: int,
    trg_type_vocab_size: int,
    trg_text_vocab_size: int,
    batch_builder: BatchBuilder,
    learning_rate: float = 3e-4,
    save_interval: int = 5,
    num_batches_per_epoch: int = 100,
    save_dir: str = "checkpoints",
) -> None:
    """Train ``model`` with Adam on batches drawn from ``batch_builder``.

    The loss is the sum of two cross-entropy terms, one for the type logits
    and one for the text logits. Each is computed as next-token prediction:
    the logits at position ``i`` are scored against the target at ``i + 1``.

    Args:
        model: Module called as ``model(src_types, src_texts, trg_types,
            trg_texts)`` and returning ``(type_logits, text_logits)``.
        num_epochs: Number of epochs to run.
        src_pad_idx: Id ignored by both loss terms.
        trg_type_vocab_size: Width of the type logits.
        trg_text_vocab_size: Width of the text logits.
        batch_builder: Source of batches via ``next_batch()``.
        learning_rate: Adam learning rate.
        save_interval: A checkpoint is written every this many epochs.
        num_batches_per_epoch: Batches drawn per epoch.
        save_dir: Directory for checkpoints; created if missing.
    """
    # import error fixed in pytorch upstream fb1c580
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # type: ignore[attr-defined]
    criterion_type = nn.CrossEntropyLoss(ignore_index=src_pad_idx)
    criterion_text = nn.CrossEntropyLoss(ignore_index=src_pad_idx)

    checkpoint_dir = Path(save_dir)
    # parents=True so a nested save_dir does not fail on a missing parent.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        progress_bar = tqdm(
            range(num_batches_per_epoch),
            desc=f"epoch {epoch+1}/{num_epochs}",
        )

        for batch_idx in progress_bar:
            src_type_ids, src_text_ids, trg_type_ids, trg_text_ids = (
                batch_builder.next_batch()
            )

            optimizer.zero_grad()

            type_output, text_output = model(
                src_type_ids,
                src_text_ids,
                trg_type_ids,
                trg_text_ids,
            )

            # Next-token objective. The decoder's causal mask lets position i
            # see token i, so scoring output[i] against target[i] would only
            # teach the model to copy its input. Score output[i] against
            # target[i + 1] instead. reshape (not view) because the slices
            # are not contiguous.
            loss_type = criterion_type(
                type_output[:, :-1].reshape(-1, trg_type_vocab_size),
                trg_type_ids[:, 1:].reshape(-1),
            )
            loss_text = criterion_text(
                text_output[:, :-1].reshape(-1, trg_text_vocab_size),
                trg_text_ids[:, 1:].reshape(-1),
            )
            loss = loss_type + loss_text

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            progress_bar.set_postfix({"loss": total_loss / (batch_idx + 1)})

        # max(..., 1) keeps an epoch with zero batches from dividing by zero.
        avg_loss = total_loss / max(num_batches_per_epoch, 1)
        print(f"epoch {epoch+1}/{num_epochs}, average loss: {avg_loss:.4f}")

        if (epoch + 1) % save_interval == 0:
            checkpoint_path = checkpoint_dir / f"sunbird_model_epoch_{epoch+1}.pt"

            torch.save(
                {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": avg_loss,
                },
                checkpoint_path,
            )
            print(f"model saved to {checkpoint_path}")

In [None]:
# --- runtime -----------------------------------------------------------------
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- batching ----------------------------------------------------------------
batch_size: int = 1
max_seq_len: int = 2048  # ~= 95th percentile for tokenized asm code length

# --- vocabulary sizes (calculated based off train.csv) -------------------------
src_type_vocab_size: int = 22
src_text_vocab_size: int = 110_000
trg_type_vocab_size: int = 112
trg_text_vocab_size: int = 40_000

# --- training schedule -------------------------------------------------------
epochs: int = 100
batches_per_epoch: int = 100
save_interval: int = 10

In [None]:
# Only query CUDA for a device name when training really runs on the GPU;
# on the CPU fallback torch.cuda.current_device() would raise.
device_name = (
    torch.cuda.get_device_name(torch.cuda.current_device())
    if device.type == "cuda"
    else "cpu"
)
print(f"training on {device_name}...")

vocab = Vocab()

batch_builder = BatchBuilder(
    "dataset/train.csv",
    batch_size,
    max_seq_len,
    vocab,
    device,
)

model = Sunbird(
    src_type_vocab_size=src_type_vocab_size,
    src_text_vocab_size=src_text_vocab_size,
    trg_type_vocab_size=trg_type_vocab_size,
    trg_text_vocab_size=trg_text_vocab_size,
    src_pad_idx=vocab.src_pad_idx,
    max_datapoint_len=max_seq_len,
).to(device)


training_loop(
    model=model,
    num_epochs=epochs,
    num_batches_per_epoch=batches_per_epoch,
    # Pass the configured interval; otherwise the function default is used
    # and the save_interval set in the config cell has no effect.
    save_interval=save_interval,
    src_pad_idx=vocab.src_pad_idx,
    trg_type_vocab_size=trg_type_vocab_size,
    trg_text_vocab_size=trg_text_vocab_size,
    batch_builder=batch_builder,
)