In [49]:
# imports
from datetime import datetime
from functools import wraps
import glob
import itertools
import logging
import os
import random
import time
import tomli
import tracemalloc
from typing import Iterator, Dict, List, Generator, Tuple, Optional, Any, TextIO

# import codecarbon
import datasets
import numpy as np
import torch
from torch.utils.data import DataLoader
import transformers

# TODO check these imports
from transformers import AutoTokenizer
from tokenizers.normalizers import Normalizer
from tokenizers import Tokenizer, Regex, NormalizedString, PreTokenizedString

# One shared seed for every random number generator used in this notebook
# (Python's `random`, NumPy and PyTorch).
RANDOM_SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(RANDOM_SEED)
del _seed_rng

def file_gen(path : str) -> Generator[TextIO, None, None]:
    """
    Yield open text-file handles for `path`.

    `path` can be a single file, a directory (every regular file directly
    inside it is yielded) or a glob pattern. Files are opened read-only as
    UTF-8 and yielded in sorted path order, so the result does not depend on
    the arbitrary order of `os.listdir` / `glob.glob`. Entries that are not
    regular files (e.g. sub-directories) are skipped.

    Each handle is closed as soon as the next one is requested, so a yielded
    file must be consumed before advancing the generator.
    """

    if os.path.isfile(path):
        filenames = [path]
    elif os.path.isdir(path):
        filenames = [os.path.join(path, name) for name in os.listdir(path)]
    else:
        # Neither an existing file nor a directory: treat as glob pattern
        filenames = glob.glob(path)

    for filename in sorted(filenames):
        # A directory cannot be opened as a text file
        if not os.path.isfile(filename):
            continue
        with open(filename, 'r', encoding='utf-8') as file:
            yield file

def load_dtaeval_as_dataset(path:str) -> datasets.Dataset:
    """
    Build a datasets.Dataset with the columns "orig" and "norm" from the
    file(s) under `path`.

    `path` is handed to `file_gen`, i.e. it can be a file, a directory or a
    glob pattern. Every file is parsed with `load_tsv` and must contain
    exactly two columns (first "orig", then "norm"). The tokens of each
    sentence are joined with single spaces; the sentences of all files are
    concatenated into one dataset.
    """

    # Parse each file while its handle is still open
    parsed_docs = []
    for file in file_gen(path):
        parsed_docs.append(load_tsv(file, keep_sentences=True))

    # Token lists -> whitespace-joined sentence strings, per column
    joined_docs = []
    for doc in parsed_docs:
        joined_docs.append(
            [[" ".join(tokens) for tokens in column] for column in doc]
        )

    sents = {"orig" : [], "norm" : []}
    for orig_sents, norm_sents in joined_docs:
        sents["orig"] += orig_sents
        sents["norm"] += norm_sents

    return datasets.Dataset.from_dict(sents)

def load_tsv(file_obj, keep_sentences=True):
    """
    Load a corpus in a tab-separated plain text file into lists

    `file_obj` : an open text file (anything with a `readline` method).
    The first non-empty line determines the number of columns; later
    lines with a different number of fields are skipped and reported
    on stdout.

    `keep_sentences` : if set to True, empty lines are interpreted
    as sentence breaks. Consecutive empty lines are ignored.

    Returns a list with one entry per column. With `keep_sentences`
    each column is a list of sentences (lists of tokens), otherwise a
    flat list of tokens. Note that an empty line after the last
    sentence leaves a trailing empty sentence in each column.
    """

    line_cnt = 0
    line = file_obj.readline()
    # Read up to the first non-empty line. The skipped lines are counted so
    # that line numbers in messages match the file. ("" at EOF is not
    # whitespace, so this loop also ends on an all-blank file.)
    while line.isspace():
        line = file_obj.readline()
        line_cnt += 1
    # Number of columns in text file
    n_columns = line.strip().count("\t") + 1
    # Initial empty columns with one empty sentence inside
    columns = [[[]] for _ in range(n_columns)]

    # Read file; the sentence being built is always the last one per column
    while line:
        # non-empty line
        if not line.isspace():
            line = line.strip()
            line_split = line.split("\t")

            # Catch/skip ill-formed lines
            if len(line_split) != n_columns:
                print(
                    f"Line {line_cnt+1} does not have length "
                    f"{n_columns} but {len(line_split)} skip line: '{line}'"
                )
            else:
                # build up sentences
                for column, token in zip(columns, line_split):
                    column[-1].append(token)

        # empty line: start a new sentence, unless the current one is still
        # empty (this is what makes consecutive empty lines collapse)
        elif columns[0][-1]:
            for column in columns:
                column.append([])

        # Move on
        line = file_obj.readline()
        line_cnt += 1

    # optional: flatten structure
    if not keep_sentences:
        columns = [list(itertools.chain.from_iterable(col)) for col in columns]

    return columns

def load_dtaeval_all(
    datadir: str = "/home/bracke/data/dta/dtaeval/split-v3.1/txt",
) -> datasets.DatasetDict:
    """
    Load the train, validation and test splits of the DTA EvalCorpus.

    `datadir` : directory containing the sub-directories "train", "dev"
    and "test". Defaults to the location that used to be hard-coded here,
    so existing calls without arguments behave as before.

    Returns a datasets.DatasetDict with the keys "train", "validation"
    (read from the "dev" sub-directory) and "test".
    """

    # split name in the returned DatasetDict -> sub-directory on disk
    split_subdirs = {"train": "train", "validation": "dev", "test": "test"}

    ds = datasets.DatasetDict()
    for split, subdir in split_subdirs.items():
        ds[split] = load_dtaeval_as_dataset(os.path.join(datadir, subdir))

    return ds

class CustomNormalizer:
    """
    Custom normalizer that replaces a handful of historic characters.

    Meant to be wrapped with `Normalizer.custom(...)` and installed as the
    normalizer of a tokenizer.
    """

    def normalize(self, normalized: "NormalizedString") -> None:
        """
        Rewrite `normalized` in place.

        The text is first decomposed (NFD), then the replacements are
        applied, and finally it is recomposed (NFC).
        """
        normalized.nfd() # unicode decomposition
        normalized.replace("ſ", "s")
        normalized.replace("ꝛ", "r")
        normalized.replace(chr(0x0303), "") # drop combining tilde
        normalized.replace(chr(0x0364), "e") # combining small letter e
        normalized.replace("æ", "ae")
        # NFD has already split "ů"/"Ů" into u/U + combining ring above
        # (U+030A), so the precomposed characters can no longer occur here.
        # Match the decomposed sequence instead.
        normalized.replace("u" + chr(0x030A), "ü")
        normalized.replace("U" + chr(0x030A), "Ü")
        normalized.nfc()


# GPU set-up
# Prefer the second GPU ("cuda:1"), but only when it actually exists: on a
# single-GPU machine "cuda:1" is not a valid device. Fall back to the first
# GPU and finally to the CPU.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

################# Load tokenizers #######################
# Multilingual BERT tokenizer, used without modification
tokenizer_bert = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")

# Historic multilingual BERT tokenizer whose built-in normalizer is swapped
# for the CustomNormalizer defined in this file
tokenizer_hmbert_custom = AutoTokenizer.from_pretrained(
    "dbmdz/bert-base-historic-multilingual-cased"
)
tokenizer_hmbert_custom.backend_tokenizer.normalizer = Normalizer.custom(
    CustomNormalizer()
)


##################  Load data ###########################
print("Loading the data ...")

dta_dataset = load_dtaeval_all()

# TODO: evaluate on the full validation split.
# For now keep a small random sample (at most 10 sentences). The shuffle is
# seeded explicitly so that re-running this cell selects the same sentences,
# and the sample size is capped by the split size so `select` cannot index
# past the end.
_n_validation_samples = min(10, len(dta_dataset["validation"]))
dta_dataset["validation"] = (
    dta_dataset["validation"]
    .shuffle(seed=RANDOM_SEED)
    .select(range(_n_validation_samples))
)


Loading the data ...


In [64]:
# ################## Load model ##########
# Encoder-decoder checkpoint used for generation below; it is loaded first
# and then moved to the selected device.
checkpoint = "models/models_2023-02-14_12-29/checkpoint-31000"
model = transformers.EncoderDecoderModel.from_pretrained(checkpoint)
model = model.to(device)


def generate_normalization(batch):
    """
    Add the model's predicted normalizations to a batch.

    The strings in `batch["orig"]` are encoded with `tokenizer_hmbert_custom`
    (padded/truncated to 128 tokens), passed to `model.generate` on `device`,
    and the generated ids are decoded with `tokenizer_bert`. The decoded
    strings are stored under `batch["norm_pred_str"]`; the same batch object
    is returned.
    """
    encoded = tokenizer_hmbert_custom(
        batch["orig"],
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )

    generated_ids = model.generate(
        encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
    )

    # Only the decoded strings are kept, not the generated ids
    batch["norm_pred_str"] = tokenizer_bert.batch_decode(
        generated_ids, skip_special_tokens=True
    )
    return batch


# Number of sentences passed to the model per call
batch_size = 4  # change to 64 for full evaluation

# Run generation over the validation sample; adds the "norm_pred_str" column
validation_set = dta_dataset["validation"]
results = validation_set.map(generate_normalization, batched=True, batch_size=batch_size)


100%|██████████| 3/3 [00:15<00:00,  5.30s/ba]


In [55]:
# Inspect one example row: original text ("orig"), gold normalization ("norm")
# and the model's prediction ("norm_pred_str")
print(results[7])

{'orig': 'Unaufhörlich rief er ſich jene Begebenheit zurück , welche einen unauslöſchlichen Eindruck auf ſein Gemüth gemacht hatte .', 'norm': 'Unaufhörlich rief er sich jene Begebenheit zurück , welche einen unauslöschlichen Eindruck auf sein Gemüt gemacht hatte .', 'norm_pred_str': 'Unaufhörlich rief er sich jene Begebenheit zurück, welche einen unauslöslichen Eindruck auf sein Gemüt gemacht hatte.'}


In [65]:
# Hand-picked sentences for a quick qualitative check of the model
test_sents = [
    "Jhm werden Licht und Feur anflammen das Geſicht.",
    "Der mit den Froͤſchen frech die Cedern-Baͤum’ anſchreyt.",
    "Jn einem Gespräch vorgestellet/ Wie nehmlich/ durch Göttlichen Beystand eine wohl-unterrichtete und geübte Wehe-Mutter/ Mit Verstand und geschickter Hand/ dergleichen verhüten/ oder wanns Noth ist/ das Kind wenden könne",
    "Gib mir bey Zeiten bescheid, daß es noch etwas zu küßen gibt.",
    "Gehts noch?",
]

# Encode all sentences at once, padded/truncated to 256 tokens.
# (To look at the tokenization of a sentence, call
# tokenizer_hmbert_custom.tokenize(...) on it.)
inputs = tokenizer_hmbert_custom(
    test_sents,
    padding="max_length",
    truncation=True,
    max_length=256,
    return_tensors="pt",
)

input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

# Generate at most 128 new tokens per sentence and decode them to text
outputs = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=128,
)
output_str = tokenizer_bert.batch_decode(outputs, skip_special_tokens=True)

print(output_str)

['Ihm werden Licht und Feuer anflammen das Gesicht.', 'Der mit den Fröschen frech die Zedernbäume anschreit.', 'In einem Gespräch vorgestellt... Wie nämlich, durch Göttlichen Beistand eine wohlunterrichtete und geübte Wehe - Mutter ( Mit Verstand und geschickter Hand ( dergleichen verhüten ) oder wann Not ist, das Kind wenden könne', 'Gib mir bei Zeiten bescheid, daß es noch etwas zu küssen gibt.', 'Geht es noch?']
