In [None]:
from __future__ import annotations

import pickle
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Generator

import torch
from torch import Tensor, nn, optim
from tqdm import tqdm

if TYPE_CHECKING:
    from data import DataPoint

# model definition

In [None]:
class Sunbird(nn.Module):
    """Encoder-decoder transformer over sequences of (type, text) token pairs.

    Every source and target token is described by two ids: a *type* id and a
    *text* id. Each id is embedded separately, the two embeddings are
    concatenated and projected back to ``d_model``, and a learned position
    embedding is added before the sequences enter ``nn.Transformer``. The
    decoder output is projected to two sets of logits, one per target vocabulary.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        src_type_vocab_size: int,
        src_text_vocab_size: int,
        trg_type_vocab_size: int,
        trg_text_vocab_size: int,
        src_pad_idx: int,
        max_datapoint_len: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1,
    ) -> None:
        """Build the embeddings, the transformer and the output projections.

        Args:
            src_type_vocab_size: number of rows in the source type embedding.
            src_text_vocab_size: number of rows in the source text embedding.
            trg_type_vocab_size: number of rows in the target type embedding and
                width of the type logits.
            trg_text_vocab_size: number of rows in the target text embedding and
                width of the text logits.
            src_pad_idx: id marking padding in ``src_type_ids``.
            max_datapoint_len: number of learned positions, i.e. the longest
                source or target sequence the model accepts.
            d_model: embedding / transformer width.
            num_heads: attention heads per layer.
            num_layers: layer count, used for both the encoder and the decoder.
            d_ff: width of the transformer feed-forward sublayers.
            dropout: dropout probability for the embeddings and the transformer.
        """
        super().__init__()

        self.input_type_embedding = nn.Embedding(src_type_vocab_size, d_model)
        self.input_text_embedding = nn.Embedding(src_text_vocab_size, d_model)
        self.input_position_embedding = nn.Embedding(max_datapoint_len, d_model)

        self.output_type_embedding = nn.Embedding(trg_type_vocab_size, d_model)
        self.output_text_embedding = nn.Embedding(trg_text_vocab_size, d_model)
        self.output_position_embedding = nn.Embedding(max_datapoint_len, d_model)

        self.src_pad_idx = src_pad_idx
        self.dropout = nn.Dropout(dropout)

        self.transformer = nn.Transformer(
            d_model,
            num_heads,
            num_layers,
            num_layers,
            d_ff,
            dropout,
            batch_first=True,
        )

        self.fc_in_src = nn.Linear(2 * d_model, d_model)
        self.fc_in_trg = nn.Linear(2 * d_model, d_model)

        self.fc_type_out = nn.Linear(d_model, trg_type_vocab_size)
        self.fc_text_out = nn.Linear(d_model, trg_text_vocab_size)

    @property
    def device(self) -> torch.device:
        """Device the parameters currently live on.

        Derived from the parameters rather than stored, so it stays correct
        after ``.to(...)`` and does not depend on any notebook-level global.
        """
        return next(self.parameters()).device

    def forward(
        self,
        src_type_ids: Tensor,
        src_text_ids: Tensor,
        trg_type_ids: Tensor,
        trg_text_ids: Tensor,
    ) -> tuple[Tensor, Tensor]:
        """Run the encoder-decoder and return ``(type_logits, text_logits)``.

        Args:
            src_type_ids: ``(batch, src_len)`` source type ids; positions equal
                to ``src_pad_idx`` are treated as padding.
            src_text_ids: ``(batch, src_len)`` source text ids.
            trg_type_ids: ``(batch, trg_len)`` decoder-input type ids.
            trg_text_ids: ``(batch, trg_len)`` decoder-input text ids.

        Returns:
            Two tensors of shape ``(batch, trg_len, trg_type_vocab_size)`` and
            ``(batch, trg_len, trg_text_vocab_size)``.
        """
        device = src_type_ids.device
        src_len = src_type_ids.shape[1]
        # Size target positions / mask from the target itself so that source and
        # target are free to have different lengths.
        trg_len = trg_type_ids.shape[1]

        src_type_embedded = self.dropout(self.input_type_embedding(src_type_ids))
        src_text_embedded = self.dropout(self.input_text_embedding(src_text_ids))
        src_positions = self.input_position_embedding(
            torch.arange(src_len, device=device),
        )
        src_embedded = self.dropout(
            self.fc_in_src(torch.cat((src_type_embedded, src_text_embedded), dim=-1))
            + src_positions,
        )

        trg_type_embedded = self.dropout(self.output_type_embedding(trg_type_ids))
        trg_text_embedded = self.dropout(self.output_text_embedding(trg_text_ids))
        trg_positions = self.output_position_embedding(
            torch.arange(trg_len, device=device),
        )
        trg_embedded = self.dropout(
            self.fc_in_trg(torch.cat((trg_type_embedded, trg_text_embedded), dim=-1))
            + trg_positions,
        )

        # (batch, src_len) boolean mask, True where the source is padding.
        src_padding_mask = src_type_ids == self.src_pad_idx
        # The mask helper builds on the default device; move it next to the inputs.
        causal_mask = self.transformer.generate_square_subsequent_mask(trg_len).to(
            device,
        )

        transformer_output = self.transformer(
            src_embedded,
            trg_embedded,
            src_key_padding_mask=src_padding_mask,
            # Also hide padded encoder outputs from the decoder's cross-attention.
            memory_key_padding_mask=src_padding_mask,
            tgt_mask=causal_mask,
        )
        return (
            self.fc_type_out(transformer_output),
            self.fc_text_out(transformer_output),
        )

# dataset interface

In [None]:
class DatasetInterface:
    """Turns datapoints into fixed-length, padded id sequences for training.

    Each datapoint exposes ``c_code.as_tokens()`` and an iterable ``asm`` whose
    items expose ``as_tokens()``; both yield ``(type, text)`` pairs. Four
    vocabularies (asm type, asm text, C type, C text) are built from the dataset
    with ids starting at 1. ``pad_idx`` is used as filler up to ``max_inp_len``.
    """

    def __init__(
        self,
        *,
        dataset: list[DataPoint],
        max_inp_len: int,
        pad_idx: int,
    ) -> None:
        """Store the configuration, build the vocabularies and start the sample stream.

        Args:
            dataset: datapoints to draw (asm, C) pairs from.
            max_inp_len: padded length of every emitted sequence; pairs where
                either side has ``max_inp_len`` or more tokens are skipped.
            pad_idx: id appended to pad sequences up to ``max_inp_len``.
        """
        self.dataset: list[DataPoint] = dataset
        self.max_inp_len: int = max_inp_len
        self.pad_idx: int = pad_idx
        self.asm_type_map, self.asm_text_map, self.c_type_map, self.c_text_map = (
            self.build_vocabulary(dataset)
        )

        self.generator = self.filtered_dataset()

    def build_vocabulary(
        self,
        dataset: list[DataPoint],
    ) -> tuple[dict[str, int], dict[str, int], dict[str, int], dict[str, int]]:
        """Assign consecutive ids (starting at 1) to every distinct token in ``dataset``.

        Also rebinds ``self.dataset`` to ``dataset``.

        Returns:
            ``(asm_type_map, asm_text_map, c_type_map, c_text_map)`` as plain
            dicts, so a later lookup can never silently mint a new id after the
            vocabulary sizes have been read.
        """
        self.dataset = dataset
        asm_type_map: dict[str, int] = {}
        asm_text_map: dict[str, int] = {}
        c_type_map: dict[str, int] = {}
        c_text_map: dict[str, int] = {}

        for dp in dataset:
            for typ, txt in dp.c_code.as_tokens():
                c_type_map.setdefault(typ, len(c_type_map) + 1)
                c_text_map.setdefault(txt, len(c_text_map) + 1)
            for asm_code in dp.asm:
                for typ, txt in asm_code.as_tokens():
                    asm_type_map.setdefault(typ, len(asm_type_map) + 1)
                    asm_text_map.setdefault(txt, len(asm_text_map) + 1)

        return (asm_type_map, asm_text_map, c_type_map, c_text_map)

    def _encode(self, tokens: tuple[str, ...], vocab: dict[str, int]) -> list[int]:
        """Map ``tokens`` through ``vocab`` and right-pad the ids to ``max_inp_len``."""
        ids = [vocab[token] for token in tokens]
        return ids + [self.pad_idx] * (self.max_inp_len - len(ids))

    def _fits(self, tokens: list[tuple[str, str]]) -> bool:
        """Return True when ``tokens`` is non-empty and shorter than ``max_inp_len``."""
        return 0 < len(tokens) < self.max_inp_len

    def filtered_dataset(
        self,
    ) -> Generator[tuple[list[int], list[int], list[int], list[int]], None, None]:
        """Yield one padded sample per usable (asm, C) pair.

        A pair is skipped when either token sequence is empty or has
        ``max_inp_len`` or more tokens, so every yielded list has exactly
        ``max_inp_len`` entries.

        Yields:
            ``(asm_type_ids, asm_text_ids, c_type_ids, c_text_ids)``.
        """
        for dp in self.dataset:
            # The C side is identical for every asm variant of a datapoint, so
            # tokenize and encode it once rather than once per variant.
            c_tokens = list(dp.c_code.as_tokens())
            if not self._fits(c_tokens):
                continue
            c_types, c_text = zip(*c_tokens)
            c_type_ids = self._encode(c_types, self.c_type_map)
            c_text_ids = self._encode(c_text, self.c_text_map)

            for asm in dp.asm:
                asm_tokens = list(asm.as_tokens())
                if not self._fits(asm_tokens):
                    continue
                asm_types, asm_text = zip(*asm_tokens)
                yield (
                    self._encode(asm_types, self.asm_type_map),
                    self._encode(asm_text, self.asm_text_map),
                    # Fresh copies so callers cannot alias the cached C ids.
                    list(c_type_ids),
                    list(c_text_ids),
                )

    def _next_sample(self) -> tuple[list[int], list[int], list[int], list[int]]:
        """Return the next sample, restarting the stream once it is exhausted.

        Raises:
            ValueError: if a freshly restarted pass yields no sample at all.
        """
        try:
            return next(self.generator)
        except StopIteration:
            self.generator = self.filtered_dataset()
        try:
            return next(self.generator)
        except StopIteration:
            msg = "dataset contains no datapoint that passes the length filter"
            raise ValueError(msg) from None

    def take(self, n: int) -> Tensor:
        """Return the next ``n`` samples as one tensor of shape ``(4, n, max_inp_len)``.

        The leading axis is ordered (asm type, asm text, C type, C text). The
        stream wraps around to the start of the dataset when it runs out.
        """
        return torch.tensor(list(zip(*[self._next_sample() for _ in range(n)])))

# load generated data

In [None]:
def from_pickle(input_file: str) -> list[DataPoint]:
    """Unpickle ``input_file`` and return its contents as a list.

    NOTE: ``pickle.load`` can execute arbitrary code while deserializing, so
    only point this at files produced by a trusted pipeline.
    """
    source = Path(input_file)
    with source.open("rb") as handle:
        loaded = pickle.load(handle)
    return list(loaded)


# Locations of the pre-compiled splits, relative to the notebook's working directory.
TRAIN_PICKLE_PATH = "dataset/train_compiled.pkl"
TEST_PICKLE_PATH = "dataset/test_compiled.pkl"

train: list[DataPoint] = from_pickle(TRAIN_PICKLE_PATH)
test: list[DataPoint] = from_pickle(TEST_PICKLE_PATH)

# training

In [None]:
def _shift_right(ids: Tensor, start_idx: int) -> Tensor:
    """Shift ``ids`` one step right along dim 1, filling the first column with ``start_idx``.

    The last column is dropped so the result keeps the shape of ``ids``.
    """
    start_column = ids.new_full((ids.size(0), 1), start_idx)
    return torch.cat((start_column, ids[:, :-1]), dim=1)


def training_loop(  # noqa: PLR0913
    *,
    model: nn.Module,
    num_epochs: int,
    batch_size: int,
    src_pad_idx: int,
    interface: DatasetInterface,
    learning_rate: float = 3e-4,
    save_interval: int = 5,
    num_batches_per_epoch: int = 100,
    save_dir: str = "checkpoints",
    device: torch.device,
) -> None:
    """Train ``model`` with Adam on batches drawn from ``interface``.

    The loss is the sum of two cross-entropy terms (target type ids and target
    text ids); positions whose label equals ``src_pad_idx`` are ignored. A
    checkpoint with model state, optimizer state, epoch number and average loss
    is written every ``save_interval`` epochs.

    Args:
        model: called as ``model(src_type, src_text, trg_type, trg_text)`` and
            expected to return ``(type_logits, text_logits)``.
        num_epochs: number of epochs to run.
        batch_size: number of samples requested from ``interface.take`` per step.
        src_pad_idx: padding id; ignored by the loss and used as the decoder's
            start token.
        interface: batch source; ``take(batch_size)`` must return a tensor that
            unpacks into the four id tensors.
        learning_rate: Adam learning rate.
        save_interval: write a checkpoint every this many epochs.
        num_batches_per_epoch: optimizer steps per epoch.
        save_dir: directory for checkpoints; created if missing.
        device: device the batches are moved to.

    Raises:
        ValueError: if ``num_batches_per_epoch`` is not positive.
    """
    if num_batches_per_epoch <= 0:
        msg = "num_batches_per_epoch must be positive"
        raise ValueError(msg)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # type: ignore[attr-defined]
    # import error fixed in pytorch upstream fb1c580
    criterion_type = nn.CrossEntropyLoss(ignore_index=src_pad_idx)
    criterion_text = nn.CrossEntropyLoss(ignore_index=src_pad_idx)

    checkpoint_dir = Path(save_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        progress_bar = tqdm(
            range(num_batches_per_epoch),
            desc=f"epoch {epoch+1}/{num_epochs}",
        )

        for batch in progress_bar:
            src_type_ids, src_text_ids, trg_type_ids, trg_text_ids = interface.take(
                batch_size,
            ).to(device)

            optimizer.zero_grad()

            # Teacher forcing: the decoder must see tokens < i when predicting
            # token i. Feeding the unshifted targets would expose token i itself
            # through the causal mask and reduce the task to copying its input.
            # NOTE(review): assumes the final target position is padding, so
            # dropping it from the decoder input loses nothing — confirm.
            type_output, text_output = model(
                src_type_ids,
                src_text_ids,
                _shift_right(trg_type_ids, src_pad_idx),
                _shift_right(trg_text_ids, src_pad_idx),
            )

            # Read the class count from the logits rather than a notebook global.
            loss_type = criterion_type(
                type_output.reshape(-1, type_output.size(-1)),
                trg_type_ids.reshape(-1),
            )
            loss_text = criterion_text(
                text_output.reshape(-1, text_output.size(-1)),
                trg_text_ids.reshape(-1),
            )
            loss = loss_type + loss_text

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            progress_bar.set_postfix({"loss": total_loss / (batch + 1)})

        avg_loss = total_loss / num_batches_per_epoch
        print(f"epoch {epoch+1}/{num_epochs}, average loss: {avg_loss:.4f}")

        if (epoch + 1) % save_interval == 0:
            checkpoint_path = checkpoint_dir / f"sunbird_model_epoch_{epoch+1}.pt"

            torch.save(
                {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": avg_loss,
                },
                checkpoint_path,
            )
            print(f"model saved to {checkpoint_path}")

In [None]:
# Corpus size: one C snippet per datapoint, plus the total number of assembly
# snippets summed over every datapoint's `asm` collection.
num_c_snippets: int = len(train)  # 300m tokens
num_asm_snippets: int = sum(map(len, (point.asm for point in train)))  # 2.5b tokens

print(
    f"training with {num_c_snippets} C snippets compiled to {num_asm_snippets} assembly snippets",
)

In [None]:
# Run configuration: batch shape, padding id, dataset interface, vocab sizes, device.
batch_size: int = 1
max_seq_len: int = 2048  # ~= 95th percentile for tokenized asm code length
src_pad_idx: int = 0

interface = DatasetInterface(
    dataset=train,
    max_inp_len=max_seq_len,
    pad_idx=src_pad_idx,
)

# build_vocabulary numbers tokens 1..len(map) and id 0 is the padding id, so the
# largest id equals len(map). Embedding tables and output layers therefore need
# len(map) + 1 entries; without the +1 the highest token id is out of range.
src_type_vocab_size: int = len(interface.asm_type_map) + 1
src_text_vocab_size: int = len(interface.asm_text_map) + 1
trg_type_vocab_size: int = len(interface.c_type_map) + 1
trg_text_vocab_size: int = len(interface.c_text_map) + 1
epochs: int = 50
batches_per_epoch: int = 100
save_interval = 5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Only query the CUDA device name when actually on CUDA; the device selection
# falls back to CPU, where torch.cuda.current_device() would raise.
device_name = (
    torch.cuda.get_device_name(torch.cuda.current_device())
    if device.type == "cuda"
    else str(device)
)
print(f"training on {device_name}...")

model = Sunbird(
    src_type_vocab_size=src_type_vocab_size,
    src_text_vocab_size=src_text_vocab_size,
    trg_type_vocab_size=trg_type_vocab_size,
    trg_text_vocab_size=trg_text_vocab_size,
    src_pad_idx=src_pad_idx,
    max_datapoint_len=max_seq_len,
).to(device)

# Pass the configured schedule explicitly so editing `batches_per_epoch` or
# `save_interval` in the config cell actually takes effect.
training_loop(
    model=model,
    num_epochs=epochs,
    batch_size=batch_size,
    src_pad_idx=src_pad_idx,
    device=device,
    interface=interface,
    num_batches_per_epoch=batches_per_epoch,
    save_interval=save_interval,
)