# Infer

In [1]:
import sys, os, datetime
import json
import torch
import random
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TranslationPipeline
import cdli
import languages

In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
src_langs = set(["akk", "sux"])
tgt_langs = set(["en"])

In [4]:
model_id = "praeclarum/cuneiform"
model_revision = "1ba74c8dcf6d1839b0a56589a53dfb5c20ca84f2"

In [5]:
batch_size = 8
device = "cuda"

## Load Existing Translations

In [34]:
output_json_path = "../data/ml_translations.json"

In [35]:
old_translations = json.loads(str(open(output_json_path, "rb").read(), "utf8"))
old_translations.keys()

dict_keys(['model_id', 'model_revision', 'akk_to_en'])

In [36]:
new_translations = dict(old_translations)
new_translations.keys()

dict_keys(['model_id', 'model_revision', 'akk_to_en'])

In [37]:
def sample_translations():
    for s_lang in src_langs:
        for t_lang in tgt_langs:
            st_key = f"{s_lang}_to_{t_lang}"
            if st_key in new_translations:
                translations = new_translations[st_key]
                print(len(translations), f"{st_key} translations")
                print([(x, translations[x]) for x in translations][:10])

In [38]:
sample_translations()

380004 akk_to_en translations
[('_1(u) 2(asz) gur sze_ hu-bu-ta-tum', '12 gur barley of the new-year'), ('u2-sze-ti-iq-ma', 'he caused to be evicted, and'), ('_1(asz) gur 1(barig) 4(ban2) masz2_ u2-s,a-ab', '1 gur 1 barig 4 ban2 of goat hair will be removed'), ('_ki_ {d}utu', 'with Shamash'), ('u3 gi-da-nu-um', 'and the scepter'), ('{disz}{d}be-el-gesztu-a-bu-szu', 'Bl-geshtu-abusshu'), ('_dumu_ i3-li2-ma-ra-x', 'son of Ili-mr...'), ('_szu ba#-an-ti#_', 'received'), ('a-na _masz-gan2_-nim#', 'for the harvest'), ('sze-am _i3-ag2-e_', 'the barley he shall weigh out')]


## Load the Model

In [11]:
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision, device=device)
model_max_length = tokenizer.model_max_length
model_max_length

256

In [12]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, revision=model_revision, max_length=tokenizer.model_max_length)
model = model.to(device)
model

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dr

In [13]:
pipeline = TranslationPipeline(model=model, tokenizer=tokenizer, device=0)

In [14]:
print(pipeline("translate Akkadian to English: 1(disz){d}szul3-ma-nu-_sag man gal?_-u2 _man_ dan-nu _man kisz_"))
print(pipeline("translate Akkadian to English: ra-bi-isz e-pu-usz"))

[{'translation_text': 'Shulmanu-sag, great king, strong king, king of the world'}]
[{'translation_text': 'I built in a grand manner'}]


## Load Transliterations to Translate

In [15]:
publications = cdli.get_atf()
len(publications)

Downloading https://github.com/cdli-gh/data/raw/master/cdliatf_unblocked.atf
Parsing atf


134710

In [16]:
publications[5000]

Publication('P005633', 'qpc', [TextArea('obverse', []), TextArea('column 1', [TextLine('1.', '1(N14) , GAR', {}), TextLine('2.', '1(N01@f) , KASZ~a SZE~a', {}), TextLine('3.', ', TU~b GU4 A', {})])])

In [None]:
def get_need_translation():
    pubs_to_translate = []
    for src_lang in src_langs:
        for tgt_lang in tgt_langs:
            for p in publications:
                if p.language == src_lang and any(tgt_lang in line.languages for a in p.text_areas for line in a.lines):
                    pubs_to_translate.append((src_lang, tgt_lang, p))


    batch = []
    need_translation = [batch]

    for (src_lang, tgt_lang, p) in pubs_to_translate:
        st_key = f"{src_lang}_to_{tgt_lang}"
        if st_key not in new_translations:
            translations = {}
            new_translations[st_key] = translations
        else:
            translations = new_translations[st_key]
        for a in p.text_areas:
            for l in a.lines:
                if len(l.text) > 0 and l.text not in translations:
                    batch.append((src_lang, tgt_lang, l.text))
                    if len(batch) == batch_size:
                        batch = []
                        need_translation.append(batch)
    return need_translation

need_translation = get_need_translation()

print(len(need_translation), "need translation")
print(new_translations.keys())

for batch in tqdm(need_translation):
    qs = [f"translate {languages.all_languages[src_lang]} to {languages.all_languages[tgt_lang]}: " + text for src_lang, tgt_lang, text in batch]
    r = pipeline(qs)
    for i, (src_lang, tgt_lang, s) in enumerate(batch):
        st_key = f"{src_lang}_to_{tgt_lang}"
        translations = new_translations[st_key]
        t = r[i]['translation_text']
#         print(i, s, t)
        translations[s] = t

9056 need translation
dict_keys(['model_id', 'model_revision', 'akk_to_en', 'sux_to_en'])


  0%|          | 0/9056 [00:00<?, ?it/s]

In [52]:
sample_translations()

40 sux_to_en translations
[('2(u@c) 2(asz@c) uruda ma-na', '22 mana copper'), ('sa10 GAN2', 'exchange value of the field'), ('1(esze3@c) 2(iku@c) GAN2-bi', '1 eshe3 2 iku its field'), ('1(u@c) 6(asz@c) uruda ma-na', '16 mana copper'), ('nig2-diri', 'additional payment'), ('2(asz@c) sze lid2-ga', '2 lidga (measures of) barley'), ('nig2-ba', 'gift'), ('1(asz@c) tug2 szu', '1 “shu” garment'), ('1(asz@c) me-gal2 tug2', '1 “big” garment'), ('tug2', 'textile')]
380004 akk_to_en translations
[('_1(u) 2(asz) gur sze_ hu-bu-ta-tum', '12 gur barley of the new-year'), ('u2-sze-ti-iq-ma', 'he caused to be evicted, and'), ('_1(asz) gur 1(barig) 4(ban2) masz2_ u2-s,a-ab', '1 gur 1 barig 4 ban2 of goat hair will be removed'), ('_ki_ {d}utu', 'with Shamash'), ('u3 gi-da-nu-um', 'and the scepter'), ('{disz}{d}be-el-gesztu-a-bu-szu', 'Bl-geshtu-abusshu'), ('_dumu_ i3-li2-ma-ra-x', 'son of Ili-mr...'), ('_szu ba#-an-ti#_', 'received'), ('a-na _masz-gan2_-nim#', 'for the harvest'), ('sze-am _i3-ag2-e_', '

## Save the Translations

In [49]:
output_json = json.dumps(new_translations)

In [50]:
with open(output_json_path, "wb") as f:
    f.write(bytes(output_json, "utf8"))