# Infer

First, make paragraphs from lines (unwrap), then wrap to the limits of the network

In [35]:
import sys, os, datetime
import json
import torch
import random
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TranslationPipeline
from collections import defaultdict



In [36]:
import cdli
import languages
import oracc

In [7]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
src_langs = set(["akk", "sux"])
tgt_langs = set(["en"])

In [4]:
model_id = "praeclarum/cuneiform"
model_revision = "1ba74c8dcf6d1839b0a56589a53dfb5c20ca84f2"

In [5]:
batch_size = 8
device = "cpu"

## Load Existing Translations

In [6]:
output_json_path = "../data/ml_translations.json"

In [7]:
old_translations = json.loads(str(open(output_json_path, "rb").read(), "utf8"))
old_translations.keys()

dict_keys(['model_id', 'model_revision', 'akk_to_en', 'sux_to_en'])

In [8]:
new_translations = dict(old_translations)
new_translations.keys()

dict_keys(['model_id', 'model_revision', 'akk_to_en', 'sux_to_en'])

In [9]:
def sample_translations():
    for s_lang in src_langs:
        for t_lang in tgt_langs:
            st_key = f"{s_lang}_to_{t_lang}"
            if st_key in new_translations:
                translations = new_translations[st_key]
                print(len(translations), f"{st_key} translations")
                print([(x, translations[x]) for x in translations][:10])

In [10]:
sample_translations()

50000 sux_to_en translations
[('" [...] = %a [at]-ta', '..'), ('" [...] ~ [...] = %a [u4]-du-ru-u2', '..'), ('" [...] ~ [...] = %a szu-ub-tum', '..'), ('" [...] ~ [...] = %a {d}en-lil2', 'of Enlil'), ('" [...] ~ |SU.LU.SZE3|# = %a lu-up-pu-um', '..'), ('" [...] ~ |X.X.LAL| = %a na-du-u2-um', '..'), ('" [...] ~ |X.X.SI| = %a ku-ur-ku-u2-um', '..'), ('" [ba]-e# = %a za-zum', '..'), ('" [ga2-e] = %a [a]-na#?-ku-u2', 'I shall make secure'), ('" [za-e] = %a at#-ta', 'you')]
380004 akk_to_en translations
[('# A2    " x#        ~ A2', '..'), ('# GI    " x#-x#     ~ GI', '..'), ('# GI4   " [...]     ~ GI4', '..'), ('# a2    " [x]-a     ~ A2', '..'), ('# da    " da-a      ~ DA', '..'), ('# gi    " gi        ~ GI', '..'), ('# gilim " gi-li!-im ~ |GI%GI|', 'the reed, the reed, the ...'), ('# kab   " ka-ab#    ~ KAB', '..'), ('# ki2   " ki        ~ GI', '..'), ('# si2   " si        ~ ZI', '..')]


## Load the Model

In [11]:
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision, device=device)
model_max_length = tokenizer.model_max_length
model_max_length

Downloading tokenizer_config.json:   0%|          | 0.00/1.88k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

Downloading special_tokens_map.json:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

256

In [12]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, revision=model_revision, max_length=tokenizer.model_max_length)
model = model.to(device)
model

Downloading config.json:   0%|          | 0.00/1.39k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/850M [00:00<?, ?B/s]

In [13]:
pipeline = TranslationPipeline(model=model, tokenizer=tokenizer, device=0)

In [14]:
print(pipeline("translate Akkadian to English: 1(disz){d}szul3-ma-nu-_sag man gal?_-u2 _man_ dan-nu _man kisz_"))
print(pipeline("translate Akkadian to English: ra-bi-isz e-pu-usz"))

## Load Transliterations to Translate

## Load ORACC

In [20]:
# oracc_dir = os.path.abspath(f"/Volumes/FrankDisk/oracc_zips")
oracc_dir = os.path.abspath(f"/home/fak/nn/Data/oracc_zips")
oracc_dir

'/home/fak/nn/Data/oracc_zips'

In [22]:
project_zips = oracc.get_all_project_zips(oracc_dir, verbose=False, tqdm=tqdm)

  0%|          | 0/139 [00:00<?, ?it/s]

In [23]:
print(len(project_zips))
project_zips[:3]

118


['/home/fak/nn/Data/oracc_zips/adsd.zip',
 '/home/fak/nn/Data/oracc_zips/adsd-adart1.zip',
 '/home/fak/nn/Data/oracc_zips/adsd-adart2.zip']

In [24]:
all_corpus_object_ids = oracc.get_all_corpus_object_ids(project_zips[:], tqdm=tqdm)

  0%|          | 0/118 [00:00<?, ?it/s]

In [26]:
all_corpus_object_ids[:10]

[('adsd/adart1', 'X102611'),
 ('adsd/adart1', 'X102612'),
 ('adsd/adart1', 'X102613'),
 ('adsd/adart1', 'X102620'),
 ('adsd/adart1', 'X102630'),
 ('adsd/adart1', 'X102640'),
 ('adsd/adart1', 'X102661'),
 ('adsd/adart1', 'X102662'),
 ('adsd/adart1', 'X102701'),
 ('adsd/adart1', 'X102702')]

In [28]:
len(all_corpus_object_ids)

26150

In [29]:
oracc_pub_ids_and_langs, transliterated_oracc_corpi = oracc.load_all_project_pub_ids(oracc_dir, tqdm=tqdm)

  0%|          | 0/139 [00:00<?, ?it/s]

Error: (<class 'json.decoder.JSONDecodeError'>, JSONDecodeError('Expecting value: line 1 column 1 (char 0)'), <traceback object at 0x7fadca2888c0>)
Error: (<class 'json.decoder.JSONDecodeError'>, JSONDecodeError('Expecting value: line 1 column 1 (char 0)'), <traceback object at 0x7fac6d1f2bc0>)
Error: (<class 'json.decoder.JSONDecodeError'>, JSONDecodeError('Expecting value: line 1 column 1 (char 0)'), <traceback object at 0x7fac6cde5a80>)
Error: (<class 'json.decoder.JSONDecodeError'>, JSONDecodeError('Expecting value: line 1 column 1 (char 0)'), <traceback object at 0x7fac6cde11c0>)
Error: (<class 'json.decoder.JSONDecodeError'>, JSONDecodeError('Expecting value: line 1 column 1 (char 0)'), <traceback object at 0x7fac6cc5a440>)
Error: (<class 'json.decoder.JSONDecodeError'>, JSONDecodeError('Expecting value: line 1 column 1 (char 0)'), <traceback object at 0x7fac6cdf0b00>)
Error: (<class 'json.decoder.JSONDecodeError'>, JSONDecodeError('Expecting value: line 1 column 1 (char 0)'), <t

In [71]:
transliterated_oracc_pub_ids = set(transliterated_oracc_corpi.keys())

In [31]:
print(len(oracc_pub_ids_and_langs), "oracc pubs")
print(len(transliterated_oracc_pub_ids), "oracc transliterated pubs")

133088 oracc pubs
21852 oracc transliterated pubs


In [32]:
reported_translated_ids = oracc.get_all_translated_object_ids(project_zips, tqdm)
print(len(reported_translated_ids), "reported translations")

  0%|          | 0/118 [00:00<?, ?it/s]

16553 reported translations


In [82]:
list(reported_translated_ids)[0]

('adsd/adart1', 'X102611')

In [79]:
list(transliterated_oracc_pub_ids)[0]

'Q007950'

In [92]:
transliterated_oracc_pids_and_oids = \
    [(x["corpus"]["project"], x["corpus"]["textid"]) for x in transliterated_oracc_corpi.values()]

In [93]:
transliterated_oracc_pids_and_oids[0]

('adsd/adart1', 'X103813')

In [94]:
for pid, oid in tqdm(transliterated_oracc_pids_and_oids[:]):
    tpath = oracc.download_object_translation(oracc_dir, pid, oid)

  0%|          | 0/21852 [00:00<?, ?it/s]

In [38]:
oracc_translated_ids = set(oracc.get_all_object_html_paths(oracc_dir).keys())
len(oracc_translated_ids), "oracc translated pubs"

(16218, 'oracc translated pubs')

In [76]:
oracc_transliterated_pubs = dict()
oracc_langs = defaultdict(lambda: 0)

for pid in tqdm(transliterated_oracc_pub_ids):
    p = oracc.get_object_id_pub(pid, oracc_dir)
    oracc_transliterated_pubs[pid] = p
    if p.language is not None:
        oracc_langs[p.language] += 1
    
oracc_langs = sorted([(x, oracc_langs[x]) for x in oracc_langs.keys()], key=lambda x:-x[1])
oracc_langs

  0%|          | 0/21852 [00:00<?, ?it/s]

KeyError: 'P514585'

## Load CDLI

In [43]:
cdli_pubs = cdli.get_atf()
len(publications)

Downloading https://github.com/cdli-gh/data/raw/master/cdliatf_unblocked.atf
Parsing atf


135256

In [46]:
publications[5000]

Publication('P005633', 'qpc', [TextArea('obverse', [], []), TextArea('column 1', [TextLine('1.', '1(N14) , GAR', {}), TextLine('2.', '1(N01@f) , KASZ~a SZE~a', {}), TextLine('3.', ', TU~b GU4 A', {})], [])])

In [47]:
def cdli_text_area_is_translated(pub, text_area, tgt_lang):
    for line in text_area.lines:
        if tgt_lang in line.languages:
            return True
    return False

def cdli_pub_is_translated(pub, tgt_lang):
    return any(x for x in pub.text_areas if cdli_text_area_is_translated(pub, x, tgt_lang))

cdli_translated_pubs = {x.id: x for x in cdli_pubs if cdli_pub_is_translated(x, "en")}
len(cdli_translated_pubs), "cdli translated pubs"

(5369, 'cdli translated pubs')

In [50]:
def cdli_lines_to_paragraphs(self, src_lang, tgt_lang, max_length=128):
    self.paragraphs = list()
    plen = 0
    for iline, line in enumerate(self.lines):
        line_len = len(line.text)
        if cdli.looks_like_li(line.text, src_lang):
            self.paragraphs.append(cdli.TextParagraph(iline, iline+1, "li"))
            plen = line_len
        else:
            if len(self.paragraphs) > 0 and self.paragraphs[-1].tag == "p" and plen + line_len < max_length:
                p = self.paragraphs[-1]
                p.end_line_index += 1
                plen += line_len
            else:
                self.paragraphs.append(cdli.TextParagraph(iline, iline+1))
                plen = line_len
    if any(l for l in self.lines if tgt_lang in l.languages):
        for p in self.paragraphs:
            lines = self.lines[p.start_line_index:p.end_line_index]
            tlines = [(x.languages[tgt_lang] if tgt_lang in x.languages else "") for x in lines]
            p.languages[tgt_lang] = languages.remove_extraneous_space(" ".join(tlines))
    return self.paragraphs

## Find just the transliterations

In [61]:
def get_all_transliterated_pubs(src_lang):
    transliterated_cdli_index = {x.id: x for x in cdli_pubs if x.language == src_lang}
    transliterated_cdli_ids = set(transliterated_cdli_index.keys())
    transliterated_oracc_index = {x: transliterated_oracc_corpi[x] for x in transliterated_oracc_corpi.keys() if transliterated_oracc_corpi[x]["lang"] == src_lang}
    transliterated_oracc_ids = set(transliterated_oracc_index.keys())
    all_ids = set(transliterated_cdli_ids)
    all_ids.union(transliterated_oracc_ids)
    return all_ids, transliterated_cdli_index, transliterated_oracc_index

all_transliterated_pubs = {
    "sux": get_all_transliterated_pubs("sux"),
    "akk": get_all_transliterated_pubs("akk")
}
print("sux", len(all_transliterated_pubs["sux"][0]))
print("akk", len(all_transliterated_pubs["akk"][0]))

sux 99819
akk 21953


## Build Need Translation List

In [70]:
def get_needed_translations(src_lang, encoding="ascii", tgt_lang="en"):
    pub_ids, cdli_pubs, oracc_pubs = all_transliterated_pubs[src_lang]
    corpi = ["cdli", "oracc"]
    for pub_id in tqdm(pub_ids):
        for corpus_index, corpus in enumerate(corpi):
            pubs = cdli_pubs if corpus_index == 0 else oracc_pubs
            if pub_id not in pubs:
                continue
            pub = pubs[pub_id]
            if corpus_index == 1:
                print(type(pub))
                print(pub)
            for a in pub.text_areas:
                if (corpus == "cdli") and len(a.lines) > 0:
                    a.lines_to_paragraphs(src_lang, tgt_lang)
    
get_needed_translations("sux")

  0%|          | 0/99819 [00:00<?, ?it/s]

<class 'dict'>
{'corpus': {'type': 'cdl', 'project': 'dcclt', 'source': 'http://oracc.org/dcclt', 'license': 'This data is released under the CC0 license', 'license-url': 'https://creativecommons.org/publicdomain/zero/1.0/', 'more-info': 'http://oracc.org/doc/opendata/', 'UTC-timestamp': '2022-06-26T00:39:08', 'textid': 'P010567', 'cdl': [{'node': 'c', 'type': 'text', 'id': 'P010567.U0', 'cdl': [{'node': 'd', 'subtype': 'tablet', 'type': 'object', 'ref': 'P010567.x41.1', 'label': 'x41'}, {'node': 'd', 'subtype': 'obverse', 'type': 'surface', 'ref': 'P010567.o.2', 'label': 'o'}, {'node': 'd', 'subtype': 'column 1o', 'type': 'column', 'ref': 'P010567.o.i.3', 'n': '1', 'label': 'o i'}, {'node': 'c', 'type': 'discourse', 'subtype': 'body', 'id': 'P010567.U1', 'cdl': [{'node': 'c', 'type': 'sentence', 'implicit': 'yes', 'id': 'P010567.U2', 'label': 'o i 1', 'cdl': [{'node': 'd', 'type': 'line-start', 'ref': 'P010567.4', 'n': '1', 'label': 'o i 1'}, {'node': 'l', 'frag': '{d}sud₃-kur-ra-diri

AttributeError: 'dict' object has no attribute 'text_areas'

In [None]:
def get_corpus_need_translations(corpus_pubs, src_lang, encoding="ascii", tgt_lang="en"):
    srcs = set()
    translations = []

    for corpus, pubs in corpus_pubs:
        pubs = [pubs[x] for x in pubs.keys() if pubs[x].language == src_lang]
        print(f"{corpus} {src_lang} with {len(pubs)} translated publications")
        longest_line_len = 0
        for pub in tqdm(pubs):
            for a in pub.text_areas:
                if (corpus == "cdli") and len(a.lines) > 0:
                    a.lines_to_paragraphs(src_lang, tgt_lang)
                for p in a.paragraphs:
                    if tgt_lang in p.languages:
                        src_lines = [x.text for x in a.lines[p.start_line_index:p.end_line_index]]
                        src = " ".join(src_lines)
                        src = languages.remove_blanks(src)
                        src = languages.underline_sign_names(src)
                        src = languages.dashes_to_dots(src)
                        
                        src = languages.remove_extraneous_space(src)
                        tgt = p.languages[tgt_lang]
                        tgt = tgt.replace("[", "").replace("]", "").replace("(", "").replace(")", "")
                        tgt = tgt.replace("\n", " ").replace("\t", " ")
                        tgt = languages.remove_extraneous_space(tgt)
                        if len(src) > 0 and languages.target_ok(tgt) and src not in srcs:
                            line_len = len(src) + len(tgt)
                            longest_line_len = max(line_len, longest_line_len)
                            out_line = json.dumps({src_lang:src,tgt_lang:tgt})
                            translations.append(out_line)
                            srcs.add(src)
                            
corpi = [("cdli", cdli_pubs), ("oracc", oracc_translated_pubs)]
    
output_translations(corpi, "akk")
output_translations(corpi, "sux")


In [17]:
def get_need_translation():
    pubs_to_translate = []
    for src_lang in src_langs:
        for tgt_lang in tgt_langs:
            for p in publications:
                if p.language == src_lang and any(tgt_lang in line.languages for a in p.text_areas for line in a.lines):
                    pubs_to_translate.append((src_lang, tgt_lang, p))


    batch = []
    need_translation = [batch]

    for (src_lang, tgt_lang, p) in pubs_to_translate:
        st_key = f"{src_lang}_to_{tgt_lang}"
        if st_key not in new_translations:
            translations = {}
            new_translations[st_key] = translations
        else:
            translations = new_translations[st_key]
        for a in p.text_areas:
            for l in a.lines:
                if len(l.text) > 0 and l.text not in translations:
                    batch.append((src_lang, tgt_lang, l.text))
                    if len(batch) == batch_size:
                        batch = []
                        need_translation.append(batch)
    return need_translation

need_translation = get_need_translation()

print(len(need_translation), "need translation")
print(new_translations.keys())

for batch in tqdm(need_translation):
    qs = [f"translate {languages.all_languages[src_lang]} to {languages.all_languages[tgt_lang]}: " + text for src_lang, tgt_lang, text in batch]
    r = pipeline(qs)
    for i, (src_lang, tgt_lang, s) in enumerate(batch):
        st_key = f"{src_lang}_to_{tgt_lang}"
        translations = new_translations[st_key]
        t = r[i]['translation_text']
#         print(i, s, t)
        translations[s] = t

1 need translation
dict_keys(['model_id', 'model_revision', 'akk_to_en', 'sux_to_en'])


  0%|          | 0/1 [00:00<?, ?it/s]

In [18]:
sample_translations()

50000 sux_to_en translations
[('" [...] = %a [at]-ta', '..'), ('" [...] ~ [...] = %a [u4]-du-ru-u2', '..'), ('" [...] ~ [...] = %a szu-ub-tum', '..'), ('" [...] ~ [...] = %a {d}en-lil2', 'of Enlil'), ('" [...] ~ |SU.LU.SZE3|# = %a lu-up-pu-um', '..'), ('" [...] ~ |X.X.LAL| = %a na-du-u2-um', '..'), ('" [...] ~ |X.X.SI| = %a ku-ur-ku-u2-um', '..'), ('" [ba]-e# = %a za-zum', '..'), ('" [ga2-e] = %a [a]-na#?-ku-u2', 'I shall make secure'), ('" [za-e] = %a at#-ta', 'you')]
380004 akk_to_en translations
[('# A2    " x#        ~ A2', '..'), ('# GI    " x#-x#     ~ GI', '..'), ('# GI4   " [...]     ~ GI4', '..'), ('# a2    " [x]-a     ~ A2', '..'), ('# da    " da-a      ~ DA', '..'), ('# gi    " gi        ~ GI', '..'), ('# gilim " gi-li!-im ~ |GI%GI|', 'the reed, the reed, the ...'), ('# kab   " ka-ab#    ~ KAB', '..'), ('# ki2   " ki        ~ GI', '..'), ('# si2   " si        ~ ZI', '..')]


## Save the Translations

In [19]:
def write_translations(f):
    f.write("{\n")
    f.write(f"\"model_id\":\"{model_id}\",\n")
    f.write(f"\"model_revision\":\"{model_revision}\"")
    for st_key in sorted([x for x in new_translations if not x.startswith("model_")]):
        f.write(f",\n\"{st_key}\":{{\n")
        translations = new_translations[st_key]
        head = ""
        for s in sorted(list(translations.keys())):
            f.write(head)
            f.write(json.dumps(s))
            f.write(": ")
            f.write(json.dumps(translations[s]))
            head = ",\n"
        f.write("}")
    f.write("}\n")
    
# write_translations(sys.stdout)

In [20]:
with open(output_json_path, "wt") as f:
    write_translations(f)

In [21]:
!ls -al ../data

total 63108
drwxrwxr-x 2 fak fak     4096 Jul 21 19:46 .
drwxrwxr-x 6 fak fak     4096 Jul 20 02:50 ..
-rw-rw-r-- 1 fak fak 29968338 Jul 21 19:47 ml_translations.json
-rw-rw-r-- 1 fak fak 34640969 Jul 19 18:54 translations.jsonl


In [22]:
new_json = json.load(open(output_json_path, "rt"))
print(len(new_json["akk_to_en"]))
print(len(new_json["sux_to_en"]))

380004
50000
