# Infer

In [1]:
import sys, os, datetime
import json
import torch
import random
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TranslationPipeline
import cdli
import languages

In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
src_langs = set(["akk", "sux"])
tgt_langs = set(["en"])

In [4]:
model_id = "praeclarum/cuneiform"
model_revision = "1ba74c8dcf6d1839b0a56589a53dfb5c20ca84f2"

In [5]:
batch_size = 8
device = "cuda"

## Load Existing Translations

In [6]:
output_json_path = "../data/ml_translations.json"

In [7]:
old_translations = json.loads(str(open(output_json_path, "rb").read(), "utf8"))
old_translations.keys()

dict_keys(['model_id', 'model_revision', 'akk_to_en', 'sux_to_en'])

In [8]:
new_translations = dict(old_translations)
new_translations.keys()

dict_keys(['model_id', 'model_revision', 'akk_to_en', 'sux_to_en'])

In [9]:
def sample_translations():
    for s_lang in src_langs:
        for t_lang in tgt_langs:
            st_key = f"{s_lang}_to_{t_lang}"
            if st_key in new_translations:
                translations = new_translations[st_key]
                print(len(translations), f"{st_key} translations")
                print([(x, translations[x]) for x in translations][:10])

In [10]:
sample_translations()

50000 sux_to_en translations
[('" [...] = %a [at]-ta', '..'), ('" [...] ~ [...] = %a [u4]-du-ru-u2', '..'), ('" [...] ~ [...] = %a szu-ub-tum', '..'), ('" [...] ~ [...] = %a {d}en-lil2', 'of Enlil'), ('" [...] ~ |SU.LU.SZE3|# = %a lu-up-pu-um', '..'), ('" [...] ~ |X.X.LAL| = %a na-du-u2-um', '..'), ('" [...] ~ |X.X.SI| = %a ku-ur-ku-u2-um', '..'), ('" [ba]-e# = %a za-zum', '..'), ('" [ga2-e] = %a [a]-na#?-ku-u2', 'I shall make secure'), ('" [za-e] = %a at#-ta', 'you')]
380004 akk_to_en translations
[('# A2    " x#        ~ A2', '..'), ('# GI    " x#-x#     ~ GI', '..'), ('# GI4   " [...]     ~ GI4', '..'), ('# a2    " [x]-a     ~ A2', '..'), ('# da    " da-a      ~ DA', '..'), ('# gi    " gi        ~ GI', '..'), ('# gilim " gi-li!-im ~ |GI%GI|', 'the reed, the reed, the ...'), ('# kab   " ka-ab#    ~ KAB', '..'), ('# ki2   " ki        ~ GI', '..'), ('# si2   " si        ~ ZI', '..')]


## Load the Model

In [11]:
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision, device=device)
model_max_length = tokenizer.model_max_length
model_max_length

256

In [12]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, revision=model_revision, max_length=tokenizer.model_max_length)
model = model.to(device)
model

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dr

In [13]:
pipeline = TranslationPipeline(model=model, tokenizer=tokenizer, device=0)

In [14]:
print(pipeline("translate Akkadian to English: 1(disz){d}szul3-ma-nu-_sag man gal?_-u2 _man_ dan-nu _man kisz_"))
print(pipeline("translate Akkadian to English: ra-bi-isz e-pu-usz"))

[{'translation_text': 'Shulmanu-sag, great king, strong king, king of the world'}]
[{'translation_text': 'I built in a grand manner'}]


## Load Transliterations to Translate

In [15]:
publications = cdli.get_atf()
len(publications)

Downloading https://github.com/cdli-gh/data/raw/master/cdliatf_unblocked.atf
Parsing atf


134710

In [16]:
publications[5000]

Publication('P005633', 'qpc', [TextArea('obverse', []), TextArea('column 1', [TextLine('1.', '1(N14) , GAR', {}), TextLine('2.', '1(N01@f) , KASZ~a SZE~a', {}), TextLine('3.', ', TU~b GU4 A', {})])])

In [17]:
def get_need_translation():
    pubs_to_translate = []
    for src_lang in src_langs:
        for tgt_lang in tgt_langs:
            for p in publications:
                if p.language == src_lang and any(tgt_lang in line.languages for a in p.text_areas for line in a.lines):
                    pubs_to_translate.append((src_lang, tgt_lang, p))


    batch = []
    need_translation = [batch]

    for (src_lang, tgt_lang, p) in pubs_to_translate:
        st_key = f"{src_lang}_to_{tgt_lang}"
        if st_key not in new_translations:
            translations = {}
            new_translations[st_key] = translations
        else:
            translations = new_translations[st_key]
        for a in p.text_areas:
            for l in a.lines:
                if len(l.text) > 0 and l.text not in translations:
                    batch.append((src_lang, tgt_lang, l.text))
                    if len(batch) == batch_size:
                        batch = []
                        need_translation.append(batch)
    return need_translation

need_translation = get_need_translation()

print(len(need_translation), "need translation")
print(new_translations.keys())

for batch in tqdm(need_translation):
    qs = [f"translate {languages.all_languages[src_lang]} to {languages.all_languages[tgt_lang]}: " + text for src_lang, tgt_lang, text in batch]
    r = pipeline(qs)
    for i, (src_lang, tgt_lang, s) in enumerate(batch):
        st_key = f"{src_lang}_to_{tgt_lang}"
        translations = new_translations[st_key]
        t = r[i]['translation_text']
#         print(i, s, t)
        translations[s] = t

1 need translation
dict_keys(['model_id', 'model_revision', 'akk_to_en', 'sux_to_en'])


  0%|          | 0/1 [00:00<?, ?it/s]

In [18]:
sample_translations()

50000 sux_to_en translations
[('" [...] = %a [at]-ta', '..'), ('" [...] ~ [...] = %a [u4]-du-ru-u2', '..'), ('" [...] ~ [...] = %a szu-ub-tum', '..'), ('" [...] ~ [...] = %a {d}en-lil2', 'of Enlil'), ('" [...] ~ |SU.LU.SZE3|# = %a lu-up-pu-um', '..'), ('" [...] ~ |X.X.LAL| = %a na-du-u2-um', '..'), ('" [...] ~ |X.X.SI| = %a ku-ur-ku-u2-um', '..'), ('" [ba]-e# = %a za-zum', '..'), ('" [ga2-e] = %a [a]-na#?-ku-u2', 'I shall make secure'), ('" [za-e] = %a at#-ta', 'you')]
380004 akk_to_en translations
[('# A2    " x#        ~ A2', '..'), ('# GI    " x#-x#     ~ GI', '..'), ('# GI4   " [...]     ~ GI4', '..'), ('# a2    " [x]-a     ~ A2', '..'), ('# da    " da-a      ~ DA', '..'), ('# gi    " gi        ~ GI', '..'), ('# gilim " gi-li!-im ~ |GI%GI|', 'the reed, the reed, the ...'), ('# kab   " ka-ab#    ~ KAB', '..'), ('# ki2   " ki        ~ GI', '..'), ('# si2   " si        ~ ZI', '..')]


## Save the Translations

In [19]:
def write_translations(f):
    f.write("{\n")
    f.write(f"\"model_id\":\"{model_id}\",\n")
    f.write(f"\"model_revision\":\"{model_revision}\"")
    for st_key in sorted([x for x in new_translations if not x.startswith("model_")]):
        f.write(f",\n\"{st_key}\":{{\n")
        translations = new_translations[st_key]
        head = ""
        for s in sorted(list(translations.keys())):
            f.write(head)
            f.write(json.dumps(s))
            f.write(": ")
            f.write(json.dumps(translations[s]))
            head = ",\n"
        f.write("}")
    f.write("}\n")
    
# write_translations(sys.stdout)

In [20]:
with open(output_json_path, "wt") as f:
    write_translations(f)

In [21]:
!ls -al ../data

total 63108
drwxrwxr-x 2 fak fak     4096 Jul 21 19:46 .
drwxrwxr-x 6 fak fak     4096 Jul 20 02:50 ..
-rw-rw-r-- 1 fak fak 29968338 Jul 21 19:47 ml_translations.json
-rw-rw-r-- 1 fak fak 34640969 Jul 19 18:54 translations.jsonl


In [22]:
new_json = json.load(open(output_json_path, "rt"))
print(len(new_json["akk_to_en"]))
print(len(new_json["sux_to_en"]))

380004
50000
