In [28]:
import transformers
from transformers import ByT5Tokenizer, T5ForConditionalGeneration, T5Config
import torch
import seaborn as sns
from tqdm import tqdm
import numpy as np
import math

from src.myt5_tokenizer import MyT5Tokenizer

## Sample sentences

In [29]:
# Parallel translations of one sentence about pilot Dilokrit Pattavee:
# Amharic, English, German, Lithuanian, (simplified) Chinese, Polish.
sentences = ["የአውሮፕላን አብራሪው የአየር ሀይል መሪ ዲሎክሪት ፓታቪ ሆኖ ተለይቷል።",
             "The pilot was identified as Squadron Leader Dilokrit Pattavee.",
             "Der Pilot wurde als Staffelführer Dilokrit Pattavee identifiziert.",
             "Išsiaiškinta, kad pilotas – eskadrilės vadas Dilokritas Pattavee.",
             "涉事飞行员是空军中队长迪罗里·帕塔维 (Dilokrit Pattavee)。",
             "Pilota zidentyfikowano jako Dilokrita Pattavee, dowódcę eskadry."]

# A second set of parallel translations: English, Amharic, Polish, (traditional) Chinese.
# NOTE(review): the English sentence lacks its opening quotation mark and the Amharic one
# ends with a trailing space; both are kept verbatim so earlier results stay comparable.
sentences_2 =  ['We now have 4-month-old mice that are non-diabetic that used to be diabetic," he added.',
                "አሁን የስኳር በሽተኛ ያልነበሩ አሁን ግን የሆኑ የ4-ወር-ዕድሜ ያላቸው አይጦች አሉን፣ አለ። ",
                "„Mamy teraz myszy w wieku 4 miesięcy, które miały cukrzycę, ale zostały z niej wyleczone” – dodał.",
                "「我們有 4 個月大曾經罹患糖尿病老鼠現在沒有糖尿病了」他補充道。"]

# Character-level split of each sentence (note: `evaluate_texts` splits on words instead).
PREFIX_CHARS = 5
prefixes = [sent[:PREFIX_CHARS] for sent in sentences]
suffixes = [sent[PREFIX_CHARS:] for sent in sentences]
sufixes = suffixes  # backward-compatible alias for the original misspelled name

## Load the byte-level (ByT5) and morphological (MyT5) models

In [30]:
# Byte-level baseline model (ByT5-small) and its tokenizer.
byt5_small = T5ForConditionalGeneration.from_pretrained(
    "hf_checkpoints/byt5_small_250000",
    use_safetensors=True,
)
by_tokenizer = ByT5Tokenizer()

In [31]:
# Morphological-byte model (MyT5-small) and its tokenizer, which is configured
# from the byte decompose/merge maps.
myt5_small = T5ForConditionalGeneration.from_pretrained(
    "hf_checkpoints/myt5_small_250000",
    use_safetensors=True,
)
my_tokenizer = MyT5Tokenizer(
    decompose_map="byte_maps/decompose_map.json",
    merge_map="byte_maps/merge_map.json",
)


## TODOs
- Check how the loss is computed
- Run the evaluation on FLORES
- Compare multiple models

## Evaluate sentence-level NLL

In [None]:
def evaluate_texts(text_dataset, model, tokenizer, batch_size=32, context=0):
    """Score the suffix of each text under an encoder-decoder model.

    Each text is split on single spaces. The first ``floor(context * n_words)``
    words (followed by a space) are fed to the encoder as context, and the
    remaining words form the decoder target whose tokens are scored. Padded
    target positions are masked out of every statistic.

    Scripts written without spaces between words (e.g. Chinese, Japanese) yield
    very few "words", so the context split is coarse for them.

    Args:
        text_dataset: sliceable sequence of strings (e.g. a list).
        model: encoder-decoder model called as ``model(**inputs, labels=ids)``;
            its output must expose ``.logits`` -- assumed to be
            (batch, target_len, vocab), TODO confirm for non-T5 models.
        tokenizer: callable returning ``input_ids`` and ``attention_mask``
            tensors for a batch of strings.
        batch_size: number of texts per forward pass.
        context: fraction of words used as encoder context; ``abs`` is taken
            and the result is capped at 1.0.

    Returns:
        ``(nlls, bpbs, compressions)``, three lists with one float per text:

        * nlls -- summed negative log-likelihood of the target tokens (nats).
        * bpbs -- ``exp(nll / n_bytes)``, i.e. a per-byte perplexity. Despite
          the name this is NOT bits-per-byte (that would be
          ``nll / (n_bytes * ln 2)``); kept as-is so earlier results remain
          comparable.
        * compressions -- number of target tokens per byte.

        ``n_bytes`` is the UTF-8 length of the target text plus one.
    """
    context = min(abs(context), 1.0)
    sentence_nlls, sentence_bpbs, sentence_compressions = [], [], []

    for start in tqdm(range(0, len(text_dataset), batch_size)):
        batch = text_dataset[start:start + batch_size]
        if len(batch) == 0:
            continue

        words_per_text = [sent.split(" ") for sent in batch]
        n_context_words = [math.floor(context * len(words)) for words in words_per_text]
        batch_prefixes = [" ".join(words[:n]) + " " for words, n in zip(words_per_text, n_context_words)]
        batch_suffixes = [" ".join(words[n:]) for words, n in zip(words_per_text, n_context_words)]
        # +1 presumably accounts for the EOS token appended to each target -- TODO confirm.
        byte_lengths = torch.tensor([len(suf.encode("utf-8")) + 1 for suf in batch_suffixes])

        inputs = tokenizer(batch_prefixes, padding="longest", return_tensors="pt")
        targets = tokenizer(batch_suffixes, padding="longest", return_tensors="pt")
        target_ids = targets.input_ids
        mask = targets.attention_mask

        # Pure evaluation: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            # `labels` is passed presumably so the model builds its decoder inputs
            # from the targets; `outputs.loss` is not used (it would also count
            # padded positions, which are not set to -100 here).
            outputs = model(**inputs, labels=target_ids)
            log_probs = torch.nn.functional.log_softmax(outputs.logits, dim=-1)
            # Log-probability the model assigns to each gold target token.
            token_log_probs = torch.gather(log_probs, -1, target_ids.unsqueeze(-1)).squeeze(-1)

            batch_nlls = -torch.sum(mask * token_log_probs, dim=-1)
            batch_bpbs = torch.exp(batch_nlls / byte_lengths)
            batch_compressions = torch.sum(mask, dim=-1) / byte_lengths

        sentence_nlls.extend(batch_nlls.tolist())
        sentence_bpbs.extend(batch_bpbs.tolist())
        sentence_compressions.extend(batch_compressions.tolist())

    return sentence_nlls, sentence_bpbs, sentence_compressions

In [None]:
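# NOTE: superseded draft of `evaluate_texts` that scored each text against itself
# (labels = the encoder input ids); kept for reference only and never executed.
# def evaluate_texts(text_dataset, model, tokenizer, batch_size=32):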

#     sentence_nlls = []
#     sentence_bpbs = []
#     #context = min(abs(context), 1.0)

#     for i in tqdm(range(0, len(text_dataset), batch_size)):
#         batch = text_dataset[i:i+batch_size]
#         # batch_contexts = [math.floor(context * len(sent)) for sent in batch]

#         # batch_prefixes = [sent[:bc] for sent, bc in zip(batch,batch_contexts) ]
#         # batch_suffixes = [sent[bc:] for sent, bc in zip(batch,batch_contexts) ]
#         byte_lengths = torch.tensor([len(sent.encode("utf-8")) for sent in batch])
#         # if len(batch_prefixes) == 0:
#         #     continue

#         inputs = tokenizer(
#             batch, padding="longest", return_tensors="pt", add_special_tokens = True
#         )
#         # targets = tokenizer(
#         #     batch_suffixes, padding="longest", return_tensors="pt"
#         # )
    
#         targets = inputs.input_ids.clone()

#         outputs = model(**inputs, labels=targets)
#         print(outputs.loss)
        
#         logits = outputs.logits[:, :-1, :]
#         print(logits.shape)
#         logits = torch.nn.functional.log_softmax(logits, dim=-1)
#         #probabilities = 
#         target_labels = targets[:,1:].unsqueeze(-1)
#         mask = inputs.attention_mask[:,:-1]

#         target_logits = torch.gather(logits, -1, target_labels).squeeze(-1)

#         #target_logits = torch.gather(logits, -1, target_labels).squeeze(-1)

#         batch_nlls = -torch.sum(mask * target_logits, axis=-1)
#         batch_bpbs = torch.exp(-torch.sum(mask * target_logits , axis=-1)/byte_lengths)

#         sentence_nlls.extend(batch_nlls.tolist())
#         sentence_bpbs.extend(batch_bpbs.tolist())

#     return outputs, sentence_nlls, sentence_bpbs

In [None]:
# Reuse `sentences` from the "Sample sentences" cell rather than re-declaring an
# identical copy here (two copies of the data could silently drift apart).
sentences

In [None]:
# Inspect the first sample sentence (Amharic).
first_sentence = sentences[0]
first_sentence

In [None]:
# Single-sentence smoke test: MyT5 on the Amharic sentence, no encoder context.
evaluate_texts(text_dataset=[sentences[0]], model=myt5_small, tokenizer=my_tokenizer, context=0.0)

In [None]:
# MyT5 on all sample sentences, no encoder context.
evaluate_texts(text_dataset=sentences, model=myt5_small, tokenizer=my_tokenizer, context=0.0)

In [None]:
# ByT5 on the sample sentences, first 75% of the words as encoder context.
evaluate_texts(text_dataset=sentences, model=byt5_small, tokenizer=by_tokenizer, context=0.75)

In [None]:
# MyT5 on the second sentence set, first 75% of the words as encoder context.
evaluate_texts(text_dataset=sentences_2, model=myt5_small, tokenizer=my_tokenizer, context=0.75)

In [None]:
# ByT5 on the second sentence set, first 75% of the words as encoder context.
evaluate_texts(text_dataset=sentences_2, model=byt5_small, tokenizer=by_tokenizer, context=0.75)

In [None]:
# ByT5 on the (traditional) Chinese sentence of the second set only.
evaluate_texts(text_dataset=[sentences_2[3]], model=byt5_small, tokenizer=by_tokenizer, context=0.75)

In [None]:
# ByT5 on the sample sentences with the default context (0 -> no encoder context).
evaluate_texts(text_dataset=sentences, model=byt5_small, tokenizer=by_tokenizer)

## Load the FLORES-200 devtest split for a sample of languages

In [None]:
# Sample of languages to evaluate (two-letter codes; each must be a key of
# `languages_flores`, which maps them to FLORES-200 file names).
languages = ['en', 'de', 'fr', 'ru', 'pl', 'ja', 'vi', 'ko', 'hy', 'kk', 'el', 'ta', 'te', 'am', 'sn', 'mt', 'sm', 'st']

# One viridis colour per sampled language (presumably for plotting results).
palette = sns.color_palette("viridis", len(languages))
languages_colors = dict(zip(languages, palette))
nice_colors = []  # NOTE(review): placeholder, not used anywhere in this notebook



# Two-letter language code -> FLORES-200 file code (<language>_<Script>).
# Latin, Corsican and Hawaiian are omitted because FLORES-200 does not include them.
languages_flores = {
    'en': 'eng_Latn',
    'ceb': 'ceb_Latn',
    'de': 'deu_Latn',
    'sv': 'swe_Latn',
    'fr': 'fra_Latn',
    'nl': 'nld_Latn',
    'ru': 'rus_Cyrl',
    'es': 'spa_Latn',
    'it': 'ita_Latn',
    'pl': 'pol_Latn',
    'ja': 'jpn_Jpan',
    'zh': 'zho_Hans',
    'uk': 'ukr_Cyrl',
    'vi': 'vie_Latn',
    'ar': 'arb_Arab',
    'pt': 'por_Latn',
    'fa': 'pes_Arab',
    'ca': 'cat_Latn',
    'sr': 'srp_Cyrl',
    'id': 'ind_Latn',
    'ko': 'kor_Hang',
    'no': 'nob_Latn',
    'fi': 'fin_Latn',
    'tr': 'tur_Latn',
    'cs': 'ces_Latn',
    'hu': 'hun_Latn',
    'ro': 'ron_Latn',
    'eu': 'eus_Latn',
    'ms': 'zsm_Latn',
    'eo': 'epo_Latn',
    'he': 'heb_Hebr',
    'hy': 'hye_Armn',
    'da': 'dan_Latn',
    'bg': 'bul_Cyrl',
    'cy': 'cym_Latn',
    'sk': 'slk_Latn',
    'uz': 'uzn_Latn',
    'et': 'est_Latn',
    'be': 'bel_Cyrl',
    'kk': 'kaz_Cyrl',
    'el': 'ell_Grek',
    'lt': 'lit_Latn',
    'gl': 'glg_Latn',
    'ur': 'urd_Arab',
    'az': 'azj_Latn',
    'sl': 'slv_Latn',
    'ka': 'kat_Geor',
    'hi': 'hin_Deva',
    'th': 'tha_Thai',
    'ta': 'tam_Taml',
    'bn': 'ben_Beng',
    'mk': 'mkd_Cyrl',
    'lv': 'lvs_Latn',
    'af': 'afr_Latn',
    'tg': 'tgk_Cyrl',
    'my': 'mya_Mymr',
    'mg': 'plt_Latn',
    'sq': 'als_Latn',
    'mr': 'mar_Deva',
    'te': 'tel_Telu',
    'ml': 'mal_Mlym',
    'ky': 'kir_Cyrl',
    'sw': 'swh_Latn',
    'jv': 'jav_Latn',
    'ht': 'hat_Latn',
    'lb': 'ltz_Latn',
    'su': 'sun_Latn',
    'ku': 'kmr_Latn',
    'ga': 'gle_Latn',
    'is': 'isl_Latn',
    # NOTE(review): 'fy' is Western Frisian but fao_Latn is Faroese; FLORES-200 seems
    # to have no Frisian -- confirm this substitution is intended.
    'fy': 'fao_Latn',
    'pa': 'pan_Guru',
    'yo': 'yor_Latn',
    'ne': 'npi_Deva',
    'ha': 'hau_Latn',
    'kn': 'kan_Knda',
    'gu': 'guj_Gujr',
    'mn': 'khk_Cyrl',
    'ig': 'ibo_Latn',
    'si': 'sin_Sinh',
    'ps': 'pbt_Arab',
    'gd': 'gla_Latn',
    'sd': 'snd_Arab',
    'yi': 'ydd_Hebr',
    'am': 'amh_Ethi',
    'sn': 'sna_Latn',
    'zu': 'zul_Latn',
    'km': 'khm_Khmr',
    'so': 'som_Latn',
    'mi': 'mri_Latn',
    'mt': 'mlt_Latn',
    'lo': 'lao_Laoo',
    'xh': 'xho_Latn',
    'sm': 'smo_Latn',
    'ny': 'nya_Latn',
    'st': 'sot_Latn',
}

# Location of the FLORES-200 devtest files and how many lines to keep per language
# (set N_FLORES_SENTENCES = None to keep the whole split).
FLORES_DEVTEST_DIR = "flores200_dataset/devtest"
N_FLORES_SENTENCES = 50

flores = {}

for lang in languages:
    flores_path = f"{FLORES_DEVTEST_DIR}/{languages_flores[lang]}.devtest"
    # The sampled languages include non-Latin scripts (e.g. 'am', 'ja', 'ta'), so decode
    # explicitly as UTF-8 instead of relying on the platform's default encoding.
    with open(flores_path, "r", encoding="utf-8") as f:
        flores[lang] = f.read().splitlines()[:N_FLORES_SENTENCES]

## Compute results

In [None]:
# Result stores keyed by model name, then by language code (no "_75" suffix,
# presumably the no-context runs).
bpbs = {name: {} for name in ("my_small", "by_small", "mixed_small")}
nlls = {name: {} for name in ("my_small", "by_small", "mixed_small")}
comps = {name: {} for name in ("my_small", "by_small", "mixed_small")}

In [None]:
# Result stores for the runs with 75% of the words as encoder context,
# keyed by model name, then by language code.
bpbs_75 = {name: {} for name in ("my_small", "by_small", "mixed_small")}
nlls_75 = {name: {} for name in ("my_small", "by_small", "mixed_small")}
comps_75 = {name: {} for name in ("my_small", "by_small", "mixed_small")}

In [None]:
lang = "en"
model = "my_small"

# evaluate_texts returns (nlls, bpbs, compressions) -- unpack in that order
# so NLLs and per-byte scores land in the right stores.
nlls_75[model][lang], bpbs_75[model][lang], comps_75[model][lang] = evaluate_texts(
    flores[lang], myt5_small, my_tokenizer, context=0.75
)

In [None]:
lang = "en"
model = "by_small"

# evaluate_texts returns (nlls, bpbs, compressions) -- unpack in that order
# so NLLs and per-byte scores land in the right stores.
nlls_75[model][lang], bpbs_75[model][lang], comps_75[model][lang] = evaluate_texts(
    flores[lang], byt5_small, by_tokenizer, context=0.75
)

In [None]:
# Median per-byte score of MyT5 on the English FLORES sample (75% context).
# `bpbs["my_small"]` is never filled in this notebook, so read the 0.75-context
# results computed by the cells above.
np.median(bpbs_75["my_small"]["en"])

In [None]:
# Mean per-sentence suffix NLL (nats) of MyT5 on the English FLORES sample, 75% context.
np.asarray(nlls_75["my_small"]["en"]).mean()

In [None]:
# Mean per-sentence suffix NLL (nats) of ByT5 on the English FLORES sample, 75% context.
np.asarray(nlls_75["by_small"]["en"]).mean()