TODO

1. remove dropout from exported programs
2. вероятно, нужно сделать модули FullEncoder
3. проверить, игнорируются ли поля модели, которые не используются в forward (например, есть ли разница между текущей имплементацией `Decode` и имплементацией, где вся модель хранится как единственное поле Decode)

In [None]:
from typing import Dict, List, Union, Tuple
import os
import array

import torch
from torch import Tensor
from torch.export import export, ExportedProgram, Dim
from executorch.exir import EdgeProgramManager, to_edge, to_edge_transform_and_lower
from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

from model import MODEL_GETTERS_DICT
from feature_extractors import get_val_transform
from ns_tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from word_generators_v2 import BeamGenerator

In [None]:
# MODEL_NAME = 'v3_weighted_and_traj_transformer_bigger'
# CHECKPOINT_ROOT_PATH = '../../../checkpoints_for_executorch/my_weighted_features/'
# CHECKPOINT_PATH = os.path.join(CHECKPOINT_ROOT_PATH, 'weighted_transformer_bigger-default--epoch=90-val_loss=0.444-val_word_level_accuracy=0.873.ckpt')  # 'weighted_transformer_bigger-default--epoch=115-val_loss=0.440-val_word_level_accuracy=0.879.ckpt'


In [None]:
# A single raw dataset item, used as the example input for the sanity checks
# and for tracing. Layout: (three int16 arrays, grid name, target word).
# NOTE(review): the arrays look like x coordinates, y coordinates and
# monotonically increasing timestamps — confirm against the dataset reader.
RAW_DATASET_ITEM_EXAMPLE = (
    array.array('h', [567, 567, 507, 424, 380, 348, 337, 332, 330, 329, 327, 326, 326]),
    array.array('h', [66, 66, 101, 161, 196, 230, 240, 245, 247, 249, 251, 251, 251]),
    array.array('h', [0, 3, 24, 52, 75, 90, 106, 129, 145, 161, 177, 195, 209]),
    'default',
    'на',
)

# JSON file describing the keyboard grids, keyed by grid name.
GRIDNAME_TO_GRID_PATH = '../data/data_separated_grid/gridname_to_grid.json'

In [None]:
# COMMAND LINE ARGUMENTS EMULATION
# Hard-coded stand-ins for values that a script version would take from argv.

MODEL_NAME = 'v3_nearest_and_traj_transformer_bigger'
CHECKPOINT_ROOT_PATH = '../../../checkpoints_for_executorch/my_nearest_features/'
CHECKPOINT_PATH = os.path.join(CHECKPOINT_ROOT_PATH, 'v3_nearest_and_traj_transformer_bigger-default--epoch=73-val_loss=0.444-val_word_level_accuracy=0.872.ckpt')
TRANSFORM_NAME = "traj_feats_and_nearest_key"

DATA_ROOT = '../data/data_separated_grid'

# Both the grid description and the vocabulary come from the same data root.
gridname_to_grid_path = os.path.join(DATA_ROOT, "gridname_to_grid.json")
voc_path = os.path.join(DATA_ROOT, "voc.txt")
char_tokenizer = CharLevelTokenizerv2(voc_path)
kb_tokenizer = KeyboardTokenizerv1()

# Feature flags forwarded to the validation transform below.
USE_TIME = False
USE_VELOCITY = True
USE_ACCELERATION = True


transform = get_val_transform(
    # Use the path derived from DATA_ROOT in this cell so the grid file is
    # always taken from the same data root as the vocabulary, instead of
    # depending on a constant defined in a different cell.
    gridname_to_grid_path=gridname_to_grid_path,
    grid_names=['default'],
    transform_name=TRANSFORM_NAME,
    char_tokenizer=char_tokenizer,
    uniform_noise_range=0,
    include_time=USE_TIME,
    include_velocities=USE_VELOCITY,
    include_accelerations=USE_ACCELERATION,
    dist_weights_func=None,  # Fill if weighted version is used
    ds_paths_list=[],
)

In [None]:
def remove_prefix(s: str, prefix: str) -> str:
    """Return *s* with one leading *prefix* removed.

    If *s* does not start with *prefix*, it is returned unchanged.
    """
    return s[len(prefix):] if s.startswith(prefix) else s


def get_state_dict_from_checkpoint(ckpt: dict) -> Dict[str, torch.Tensor]:
    """Extract ``ckpt['state_dict']`` with one leading ``'model.'`` stripped from each key.

    NOTE(review): the ``'model.'`` prefix presumably comes from a training
    wrapper that stores the network as ``self.model`` — confirm.
    """
    raw_state = ckpt['state_dict']
    cleaned: Dict[str, torch.Tensor] = {}
    for name, tensor in raw_state.items():
        cleaned[remove_prefix(name, 'model.')] = tensor
    return cleaned


def _prepare_encoder_input(encoder_in: Union[Tensor, Tuple[Tensor, ...], List[Tensor]],
                           device: str, batch_first: bool
                           ) -> Union[Tensor, List[Tensor]]:
    """Add a size-1 batch axis to the encoder input and move it to *device*.

    Args:
        encoder_in: a single tensor or a sequence of tensors. Assumes each
            tensor is unbatched with the sequence axis first — TODO confirm.
        device: target device passed to ``Tensor.to``.
        batch_first: if True the new batch axis stays at dim 0; otherwise
            dims 0 and 1 are swapped so the batch axis ends up at dim 1.

    Returns:
        A tensor if *encoder_in* was a tensor, otherwise a list with one
        prepared tensor per input element, in the same order.
    """
    is_tensor = isinstance(encoder_in, Tensor)
    tensors = [encoder_in] if is_tensor else list(encoder_in)

    # unsqueeze(0) prepends the batch axis: (L, ...) -> (1, L, ...).
    batched = [t.unsqueeze(0).to(device) for t in tensors]

    if not batch_first:
        # (1, L, ...) -> (L, 1, ...)
        batched = [t.transpose(0, 1) for t in batched]
    return batched[0] if is_tensor else batched

In [None]:
# Load the trained weights on CPU and build the model they belong to.
state_dict = get_state_dict_from_checkpoint(
    torch.load(CHECKPOINT_PATH, map_location='cpu', weights_only=True))


model = MODEL_GETTERS_DICT[MODEL_NAME]()
model.load_state_dict(state_dict)
# Switch to inference mode immediately. A freshly built nn.Module is in
# training mode, so the encode/decode sanity-check cells that follow would
# otherwise run with any dropout layers active and give non-deterministic output.
model.eval()

In [None]:
# Turn the raw example item into model inputs; these tensors also serve as
# the example inputs for torch.export further down.
(encoder_in, decoder_in), decoder_out_target = transform(RAW_DATASET_ITEM_EXAMPLE)
encoder_in = _prepare_encoder_input(encoder_in, 'cpu', batch_first=False)
# A multi-tensor encoder input comes back as a list; keep it as a tuple instead.
encoder_in = tuple(encoder_in) if isinstance(encoder_in, list) else encoder_in
# Add the batch axis at dim 1 to match the sequence-first layout used above.
decoder_in = decoder_in.unsqueeze(dim=1)

In [None]:
# Sanity check: run the full model's encoder on the example input.
# NOTE(review): the second argument is presumably a padding mask (None = no
# padding) — confirm in model.py.
encoded = model.encode(encoder_in, None)

In [None]:
# Sanity check: decode the example target prefix against the encoder output.
# NOTE(review): the two trailing Nones are presumably optional masks — confirm.
decoded = model.decode(
    decoder_in,
    encoded,
    None,
    None,
)

In [None]:
# Display the output of the decoder sanity check.
decoded

In [None]:
# Put the model into inference mode; train(False) is what eval() does, and
# it returns the same module instance.
model = model.train(False)

In [None]:
# Inspect the example decoder input's shape (batch axis was inserted at dim 1).
decoder_in.shape

In [None]:
# Inspect the example decoder input values (presumably character token ids).
decoder_in

In [None]:
class Encode(torch.nn.Module):
    """Export wrapper that exposes only the encoder half of *model*.

    ``model.enc_in_emb_model`` and ``model.encoder`` are registered as
    submodules of this wrapper (shared with *model*, not copied).
    """

    def __init__(self, model) -> None:
        super().__init__()
        self.enc_in_emb_model = model.enc_in_emb_model
        self.encoder = model.encoder

    def forward(self, encoder_in):
        """Embed *encoder_in* and encode it without a key padding mask."""
        embedded = self.enc_in_emb_model(encoder_in)
        return self.encoder(embedded, src_key_padding_mask=None)

def _get_casual_mask(sz: int, device='cpu'):
    """Return a ``(sz, sz)`` boolean causal mask in PyTorch's attention convention.

    For ``torch.nn.MultiheadAttention`` / ``torch.nn.TransformerDecoder`` a
    boolean ``attn_mask`` / ``tgt_mask`` uses ``True`` for positions that may
    NOT be attended to. For causal decoding those are the future positions,
    i.e. everything strictly above the diagonal: ``mask[i, j]`` is True iff
    ``j > i``. Every row keeps at least its own position visible, so no row
    is fully masked (which would otherwise produce NaNs).

    The "casual" spelling of the name is kept so existing callers still work.
    """
    return torch.triu(
        torch.ones((sz, sz), dtype=torch.bool, device=device),
        diagonal=1,
    )


class Decode(torch.nn.Module):
    """Export wrapper that exposes one decoder pass of *model*.

    ``forward(decoder_in, x_encoded)`` embeds the target tokens, runs
    ``model.decoder`` against the encoder output with causal self-attention
    and projects the result with ``model.out``.
    """

    def __init__(self, model) -> None:
        super().__init__()
        self.dec_in_emb_model = model.dec_in_emb_model
        self.decoder = model.decoder
        # NOTE(review): not used by forward(); kept only so code that reads
        # Decode(model)._get_mask keeps working. It is presumably a bound
        # method of model, so it is not registered as a submodule, but it does
        # keep a reference to the whole model alive.
        self._get_mask = model._get_mask
        self.out = model.out

    def forward(self, decoder_in, x_encoded):
        y = self.dec_in_emb_model(decoder_in)
        # y is assumed sequence-first, so dim 0 is the target length — TODO confirm.
        # Build the mask on y's device so the wrapper also works off-CPU.
        tgt_mask = _get_casual_mask(y.size(0), device=y.device)
        dec_out = self.decoder(
            y,
            x_encoded,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=None,
            tgt_key_padding_mask=None,
            tgt_is_causal=True,
        )
        return self.out(dec_out)
    


MAX_SWIPE_LEN = 299
MAX_WORD_LEN = 35
dim_swipe_seq = Dim("dim_swipe_seq", min=1, max=MAX_SWIPE_LEN)
dim_char_seq = Dim("dim_char_seq", min=1, max=MAX_WORD_LEN)

encoder_dynamic_shapes = {"encoder_in": ({0: dim_swipe_seq}, {0: dim_swipe_seq})}
decoder_dynamic_shapes = {
    "x_encoded": {0: dim_swipe_seq},
    "decoder_in": {0: dim_char_seq}
}



# Export the encoder and the decoder as two methods of one ExecuTorch program.
# dynamic_shapes makes the sequence axes symbolic (bounded by MAX_SWIPE_LEN /
# MAX_WORD_LEN). Without it, torch.export treats every input shape as static,
# so the resulting .pte would only accept inputs of exactly the example's lengths.
aten_encode: ExportedProgram = export(
    Encode(model).eval(),
    (encoder_in,),
    dynamic_shapes=encoder_dynamic_shapes,
)
aten_decode: ExportedProgram = export(
    Decode(model).eval(),
    # encoded is only an example input; detach it so no autograd history from
    # the sanity-check cell is carried into tracing.
    (decoder_in, encoded.detach()),
    dynamic_shapes=decoder_dynamic_shapes,
)

# Convert both programs to the Edge dialect and delegate the supported
# subgraphs to the XNNPACK backend.
edge_xnnpack: EdgeProgramManager = to_edge_transform_and_lower(
    {"encode": aten_encode, "decode": aten_decode},
    partitioner=[XnnpackPartitioner()],
)


exec_prog_xnnpack = edge_xnnpack.to_executorch()

with open("xnnpack_my_nearest_feats.pte", "wb") as file:
    exec_prog_xnnpack.write_to_file(file)

In [None]:
# Read the vocabulary file: one entry per line, line terminators stripped.
with open(voc_path, encoding='utf-8') as voc_file:
    vocab = voc_file.read().splitlines()

In [None]:
# Beam-search word generator wrapping the full (non-exported) PyTorch model.
# NOTE(review): max_token_id=34 is hard-coded — presumably the largest token id
# the generator may emit; confirm it matches char_tokenizer's vocabulary.
word_generator = BeamGenerator(
    model,
    char_tokenizer,
    'cpu',
    vocab,
    max_token_id=34,
)

In [None]:
# Re-create the raw (not yet batched by _prepare_encoder_input) encoder input
# and run beam search on it. Presumably BeamGenerator adds the batch axis
# itself — confirm.
(encoder_in, decoder_in), decoder_out_target = transform(
    RAW_DATASET_ITEM_EXAMPLE
)
word_generator(encoder_in, max_steps_n=35)