In [123]:
from datasets import load_dataset, DatasetDict
from numpy import random
import re

In [124]:
substitutions = [
    ("i", "y", "j"),
    ("ú", "ů", "u"),
    ("s", "z"),
    ("m", "n"),
    ("ý", "y", "ej"),
    ("je", "ě", "e"),
    ("mě", "mně", "mne"),
    ("h", "ch"),
    ("p", "b"),

    (" ", ", "),
    (".", ",", " "),
    ("?", "."),
]

defined_substitution_dict = {
    letter: tuple(set(letter_group) - set([letter])) for letter_group in substitutions for letter in letter_group 
}


# Insertions, deletions, substitutions, transpositions
# Small / capital letters (Který -> který, KTerý, KTERÝ)
# Diacritics -> no diacritics

diacritics = "á, č, ď, é, ě, í, ň, ó, ř, š, ť, ú, ů, ý, ž".split(", ")
without_diacritics = list("acdeeinorstuuyz")
alphabet = "a, á, b, c, č, d, ď, e, é, ě, f, g, h, ch, i, í, j, k, l, m, n, ň, o, ó, p, q, r, ř, s, š, t, ť, u, ú, ů, v, w, x, y, ý, z, ž,  ".split(", ")

# Double characters (r -> rr)

MAX_TRANSPOSITION_DISTANCE = 2


word_errors = {
    "bychom": ("by jsme", "bysme"),
    "byste": ("by jste",),
    "bys": ("by jsi",)
}

In [125]:
def copy_capitalizing(word, pattern_word):
    if pattern_word == pattern_word.lower():
        return word.lower()
    if pattern_word[0] == pattern_word[0].upper() and pattern_word[1:] == pattern_word[1:].lower():
        return word[0].upper() + word[1:]
    
    if pattern_word == pattern_word.upper():
        return word.upper()
    

    min_length = min(len(word), len(pattern_word))
    for i, char in enumerate(pattern_word[:min_length]):
        if char == char.lower():
            word = word[:i] + word[i].lower() + word[(i+1):]
        else:
            word = word[:i] + word[i].upper() + word[(i+1):]
    return word

In [126]:
def starts_with_capital_letter(word: str):
    return word[0] == word[0].upper()

def contains_diacritics(word: str):
    for letter in diacritics:
        if letter in word.lower():
            return True
    return False

def contains_special_words(word: str):
    for word_error in word_errors.keys():
        if word_error in word.lower():
            return True
    return False

def contains_letters_for_substitutions(word: str):
    for letter in defined_substitution_dict.keys():
        if letter in word.lower():
            return True
    return False

In [127]:
def change_capital_letters(word: str):
    probs = [0.7, 0.2, 0.1]  # (Který -> který, KTerý, KTERÝ)

    if len(word) < 2:
        probs = [1.0, 0.0, 0.0]

    lower_all = lambda x: x.lower()
    change_second = lambda x: x[:2].upper() + x[2:]
    upper_all = lambda x: x.upper()

    funcs = [lower_all, change_second, upper_all]

    used_func = random.choice(funcs, p=probs)
    return used_func(word)


def substitute_word(word: str):
    word_errors_list = list(word_errors.items())
    #random.shuffle(word_errors_list)
    
    for word_error, replacements in word_errors_list:
        if word_error not in word.lower():
            continue
        
        error_start_index = word.lower().index(word_error)
        if error_start_index >= 0:
            end_index = error_start_index + len(word_error)
            replacement = random.choice(replacements)

            original_substring = word[error_start_index:end_index]

            word = word[:error_start_index] + copy_capitalizing(replacement, original_substring) + word[end_index:]
            return word
    else:
        raise ValueError(f'No word errors could be substituted in "{word}"!')


def substitute_letter(word: str, substitution_dict: dict):
    letters_to_substitute = dict()
    regex_special_symbols = ["?", ".", "+", "*", ")", "(", "]", "["]
    for letter in substitution_dict:
        if letter in word.lower():
            letter_for_re = letter
            if letter in regex_special_symbols:
                letter_for_re = rf"\{letter}"

            letters_to_substitute[letter] = [m.start() for m in re.finditer(letter_for_re, word.lower())]

    letter_to_sub = random.choice(list(letters_to_substitute.keys()))
    letter_index = random.choice(letters_to_substitute[letter_to_sub])
    substitute_by = random.choice(substitution_dict[letter_to_sub])

    return copy_capitalizing(word[:letter_index] + substitute_by + word[letter_index+len(letter_to_sub):], word)


def substitute_defined_letters(word: str):
    return substitute_letter(word, defined_substitution_dict)

def substitute_diacritics(word: str):
    substitution_dict = {diac: [without_diac] for diac, without_diac in zip(diacritics, without_diacritics)}
    substitution_dict |= {value[0]: [key] for key, value in substitution_dict.items()}
    return substitute_letter(word, substitution_dict)


def substitute(word: str):
    if word == "":
        return word
    probs = [0.8, 0.3, 0.2, 0.1]  # word_sub_prob, letter_sub_prob, diacritics_sub_prob, random_substitution

    if not contains_special_words(word):
        probs[0] = 0.0
    if not contains_letters_for_substitutions(word):
        probs[1] = 0.0
    if not contains_diacritics(word):
        probs[2] = 0.0
    
    probs = [prob / sum(probs) for prob in probs]


    # TODO: use probabilities based on langugage histogram for random.choice(alphabet)
    random_substitution = lambda x: substitute_letter(
        x, 
        {
            random.choice(list(set(x.lower()))): [random.choice(alphabet)]
        }
    )

    funcs = [substitute_word, substitute_defined_letters, substitute_diacritics, random_substitution]

    used_func = random.choice(funcs, p=probs)
    return used_func(word)

def delete_letter(word: str):
    if word == "":
        return word
    index_to_delete = random.choice(list(range(len(word))))
    return word[:index_to_delete] + word[index_to_delete+1:]


def transpose(word: str):
    if len(word) <= 1:
        return word
    
    word_indices = list(range(len(word)))
    first_letter_index = random.choice(word_indices)

    # Find second index based on maximal transposition distance
    new_range = list(range(max(0, first_letter_index - MAX_TRANSPOSITION_DISTANCE), min(len(word), first_letter_index + MAX_TRANSPOSITION_DISTANCE)))
    del new_range[new_range.index(first_letter_index)]
    second_letter_index = random.choice(new_range)

    word_list = list(word)
    word_list[first_letter_index], word_list[second_letter_index] = word_list[second_letter_index], word_list[first_letter_index]
    return copy_capitalizing("".join(word_list), word)



def insert(word: str):
    probs = [0.8, 0.2] # Duplicate letter, random letter

    inserted_letter_random = random.choice(alphabet)
    if word == "":
        return inserted_letter_random

    inserted_index = random.choice(list(range(len(word))))
    inserted_letter_duplicate = word[inserted_index]

    inserted_letter = random.choice([inserted_letter_duplicate, inserted_letter_random], p=probs)

    changed_word = word[:inserted_index] + inserted_letter + word[inserted_index:]
    return copy_capitalizing(changed_word, word)


In [128]:
# insertion, deletion, substitution, transposition, capital_letters, no change
EDIT_PROBABILITIES = (0.1, 0.2, 0.3, 0.2, 0.1, 0.1)
assert sum(EDIT_PROBABILITIES) == 1.0

MAX_CHANGES_PER_STRING = 5


In [129]:
def apply_changes(word: str):
    changes = [insert, delete_letter, substitute, transpose, change_capital_letters, lambda x: x]
    for _ in range(min(len(word), MAX_CHANGES_PER_STRING)):
        changing_func = random.choice(changes, p=EDIT_PROBABILITIES)
        try:
            word = changing_func(word)
        except Exception as e:
            print(e)
    return word

In [130]:
for _ in range(100):
    word = "Ó, náhlý déšť již zvířil prach a čilá laň teď běží s houfcem gazel Ualdewara k exkluzívním úkrytům! Ty bychom neměli zmeškat"

    apply_changes(word)

## Data Augmentation on dataset

In [131]:
news_dataset = load_dataset("data/czech_news_dataset_v2").remove_columns(
    ["headline", "brief", "url", "authors", "category", "comments_num", "server", "category_unclean", 'day_of_week', "date", "authors_cum_gender", "authors_gender", "keywords"]
)
news_dataset = news_dataset.rename_column("content", "target")
news_dataset

DatasetDict({
    train: Dataset({
        features: ['target'],
        num_rows: 1641471
    })
})

In [132]:
sentence_end_symbols = ["?", "!", "."]
sentence_end = "</s>"

def split_to_sentences(batch):
    sequences = batch["target"]
    sentences = []
    
    for sequence in sequences:
        for symbol in sentence_end_symbols:
            sequence = sequence.replace(symbol, f"{symbol}{sentence_end}")
            sequence = sequence.replace(f"{sentence_end} ", sentence_end)
    
        sentences += sequence.split(sentence_end)
    sentences = [sentence.strip() for sentence in sentences if len(sentence.strip()) > 1]
    batch["target"] = sentences
    return batch

def add_errors(sequence_dict):
    error_sentence = apply_changes(sequence_dict["target"])
    sequence_dict["source"] = error_sentence
    return sequence_dict

In [133]:
tiny_dataset = news_dataset["train"].select(list(range(0, 100_000)))

segmented_tiny = tiny_dataset.map(split_to_sentences, batched=True)
tiny_with_errors = segmented_tiny.map(add_errors)

Map: 100%|██████████| 100000/100000 [00:07<00:00, 12765.74 examples/s]
Map:  60%|█████▉    | 1267055/2118584 [25:23<15:46, 899.36 examples/s] 

bad escape (end of pattern) at position 0
bad escape (end of pattern) at position 0


Map:  60%|█████▉    | 1267353/2118584 [25:24<16:28, 861.17 examples/s]

bad escape (end of pattern) at position 0


Map:  62%|██████▏   | 1315713/2118584 [26:17<14:55, 896.54 examples/s] 

bad escape (end of pattern) at position 0


Map:  82%|████████▏ | 1733282/2118584 [34:23<08:11, 783.83 examples/s] 

bad escape (end of pattern) at position 0


Map: 100%|██████████| 2118584/2118584 [41:54<00:00, 842.51 examples/s] 


In [134]:
train_testvalid = tiny_with_errors.train_test_split(test_size=0.3)

test_valid = train_testvalid['test'].train_test_split(test_size=0.5)

train_test_valid_dataset = DatasetDict({
    'train': train_testvalid['train'],
    'test': test_valid['test'],
    'dev': test_valid['train']})

In [136]:
train_test_valid_dataset.save_to_disk("data/czech_news_errors")

Saving the dataset (1/1 shards): 100%|██████████| 1483008/1483008 [00:32<00:00, 45744.25 examples/s] 
Saving the dataset (1/1 shards): 100%|██████████| 317788/317788 [00:02<00:00, 141371.51 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 317788/317788 [00:02<00:00, 157964.08 examples/s]
