# Cross Validation

In [1]:
import random
import os

In [2]:
def read_file(fname):
    """Read a text file and return its lines with surrounding whitespace removed.

    The file is decoded explicitly as UTF-8 so that non-ASCII corpora are
    read the same way regardless of the platform's default encoding.

    Args:
        fname: Path of the file to read.

    Returns:
        A list with one stripped string per line of the file, in file order.
        Blank lines are kept as empty strings.
    """
    with open(fname, encoding="utf-8") as f:
        return [line.strip() for line in f]

In [3]:
def get_data(tgt_lang, domain):
    """Load the five parallel training files for one language/domain pair.

    Args:
        tgt_lang: Target language code used in the directory and file names.
        domain: Domain name used in the file names and as the key prefix.

    Returns:
        A ``(corpus, size)`` tuple. ``corpus`` maps ``domain + suffix`` to the
        list of lines read from the matching file; ``size`` is the number of
        lines in the English source file.
    """
    # Every file shares this stem; the suffix doubles as the tail of the dict key.
    stem = f"../data/train/en-{tgt_lang}/formality-control.train.{domain}.en-{tgt_lang}"
    suffixes = (
        ".en",
        ".formal." + tgt_lang,
        ".informal." + tgt_lang,
        ".formal.annotated." + tgt_lang,
        ".informal.annotated." + tgt_lang,
    )
    corpus = {domain + suffix: read_file(stem + suffix) for suffix in suffixes}
    return corpus, len(corpus[domain + ".en"])

In [4]:
def write_lines(out_file, indices, data):
    """Write the selected entries of ``data`` to ``out_file``, one per line.

    The file is encoded explicitly as UTF-8 so that non-ASCII lines can be
    written regardless of the platform's default encoding.

    Args:
        out_file: Path of the file to create or overwrite.
        indices: Iterable of positions into ``data``; output follows this order.
        data: Indexable collection of strings.
    """
    with open(out_file, "w", encoding="utf-8") as f:
        f.writelines(data[i] + "\n" for i in indices)

In [5]:
def write(out_dir, indices, data):
    """Write one file per (split, data key) combination.

    Args:
        out_dir: Prefix prepended verbatim to each file name (no separator is
            inserted, so it should already end with one).
        indices: Mapping from split name to the line indices in that split.
        data: Mapping from file-name suffix to the list of lines to pick from.

    Each output file is named ``out_dir + split + "." + key``.
    """
    for split_name, split_indices in indices.items():
        for key, lines in data.items():
            target = out_dir + split_name + "." + key
            write_lines(target, split_indices, lines)

In [6]:
def get_split(x, test_size=50, max_splits=4, seed=None):
    """Build shuffled cross-validation folds over the indices ``0 .. x-1``.

    The indices are shuffled once, then cut into consecutive chunks of
    ``test_size``. Fold ``k`` uses the k-th chunk as its dev set and all
    remaining indices as its train set.

    Args:
        x: Number of examples to split.
        test_size: Number of indices in each dev chunk. The last chunk may be
            shorter when ``x`` is not a multiple of ``test_size``.
        max_splits: Maximum number of folds to build (non-negative), or
            ``None`` to build one fold per chunk.
        seed: Optional seed for a private random generator, making the folds
            reproducible. When ``None`` the module-level ``random`` state is
            used.

    Returns:
        A dict mapping the fold number (starting at 0) to
        ``{"dev": [...], "train": [...]}`` lists of indices.

    Raises:
        ValueError: If ``test_size`` is not positive.
    """
    if test_size <= 0:
        raise ValueError("test_size must be a positive integer")

    rng = random if seed is None else random.Random(seed)
    population = range(x)
    indices = rng.sample(population, len(population))

    # Start offset of every dev chunk, truncated to the requested fold count.
    starts = range(0, x, test_size)
    if max_splits is not None:
        starts = starts[:max_splits]

    split = {}
    for fold, start in enumerate(starts):
        stop = start + test_size
        split[fold] = {
            "dev": indices[start:stop],
            "train": indices[:start] + indices[stop:],
        }
    return split

In [10]:
# Build the cross-validation folds for every target language and domain,
# and write each fold into its own "internal_split<N>" directory.
for tgt_lang in ("de", "es", "hi", "ja"):
    for domain in ("telephony", "topical-chat"):
        data, length = get_data(tgt_lang, domain)
        all_split = get_split(length)
        for split in all_split:
            out_dir = f"../cross_val/internal_split{split}/en-{tgt_lang}/"
            # Both domains share a directory, so it may already exist.
            os.makedirs(out_dir, exist_ok=True)
            write(out_dir, all_split[split], data)

In [None]:
def get_paired_dataset(base_dir,  split, tokenizer, formal_idx, informal_idx, skips_by_langpair_and_source=None):
    """Build one concatenated dataset from every language-pair directory.

    For each entry of ``base_dir`` the topical and telephony data are read,
    encoded, and wrapped into a formal and an informal ``FormalityDataset``
    each; the four datasets are concatenated per language pair, and the
    per-pair datasets are concatenated into the returned dataset.

    Args:
        base_dir: Directory whose entries are named ``<src_lang>-<tgt_lang>``.
        split: Split identifier forwarded to the reader functions.
        tokenizer: Tokenizer forwarded to ``encode_split``.
        formal_idx: Value passed as the third argument of the formal
            ``FormalityDataset`` instances.
        informal_idx: Value passed as the third argument of the informal
            ``FormalityDataset`` instances.
        skips_by_langpair_and_source: Optional mapping from
            ``(src_lang, tgt_lang)`` to a mapping with optional ``"topical"``
            and ``"telephony"`` entries that are forwarded to the readers.

    Returns:
        A ``torch.utils.data.ConcatDataset`` over all language pairs.

    NOTE(review): this relies on ``Path``, ``torch``, ``read_paired_topical``,
    ``read_paired_telephony``, ``encode_split`` and ``FormalityDataset``, none
    of which are defined in the visible part of this file - confirm they are
    in scope before running.
    """
    skips_by_langpair_and_source = skips_by_langpair_and_source or dict()
    datasets = []
    for dir_ in Path(base_dir).iterdir():
        # The unpacking requires exactly one "-" in every entry name.
        src_lang, tgt_lang = dir_.name.split("-")
        skipped_lines_by_langpair = skips_by_langpair_and_source.get((src_lang, tgt_lang), {})

        topical_src, topical_formal, topical_informal = read_paired_topical(src_lang, tgt_lang, split, dir_, skipped_lines_by_langpair.get("topical"))
        telephony_src, telephony_formal, telephony_informal = read_paired_telephony(src_lang, tgt_lang, split, dir_, skipped_lines_by_langpair.get("telephony"))

        topical_src_encoded, topical_tgt_formal_encoded, topical_tgt_informal_encoded = encode_split(topical_src, topical_formal, topical_informal, tokenizer)
        # The same encoded source is paired with both target variants.
        topical_formal_dataset = FormalityDataset(topical_src_encoded, topical_tgt_formal_encoded, formal_idx)
        topical_informal_dataset = FormalityDataset(topical_src_encoded, topical_tgt_informal_encoded, informal_idx)

        telephony_src_encoded, telephony_tgt_formal_encoded, telephony_tgt_informal_encoded = encode_split(telephony_src, telephony_formal, telephony_informal, tokenizer)

        telephony_formal_dataset = FormalityDataset(telephony_src_encoded, telephony_tgt_formal_encoded, formal_idx)
        telephony_informal_dataset = FormalityDataset(telephony_src_encoded, telephony_tgt_informal_encoded, informal_idx)

        dataset = torch.utils.data.ConcatDataset((topical_formal_dataset, topical_informal_dataset, telephony_formal_dataset, telephony_informal_dataset))
        datasets.append(dataset)
    return torch.utils.data.ConcatDataset(datasets)
