In [None]:
from datasets import load_dataset

# Load the "iwslt2017-zh-en" configuration of the "iwslt2017" dataset;
# downloaded files are cached under ./cache.
dataset_name = "iwslt2017"
dataset = load_dataset(
    path=dataset_name,
    name="iwslt2017-zh-en",
    cache_dir="./cache",
)

In [None]:
import opencc
# OpenCC converter built from the 's2t.json' profile. Later cells call
# converter.convert(...) on the Chinese side when the traditional variant
# is requested.
converter = opencc.OpenCC('s2t.json')

In [None]:
# Display the loaded dataset object to inspect what was loaded.
dataset

In [None]:
from collections import defaultdict
# Build the character <-> stroke lookup tables from ./vocab/zh2letter.txt.
# Each non-blank line is expected to hold exactly two whitespace-separated
# fields: <chinese_char> <stroke string>.
with open("./vocab/zh2letter.txt", 'r', encoding="utf-8") as f:
    conversions = f.read().splitlines()

dic = defaultdict(str)          # character -> stroke string
stroke2word = defaultdict(str)  # stroke string -> character (last entry wins on duplicates)
for line in conversions:
    if not line.strip():
        # Tolerate blank lines (e.g. a trailing empty line) instead of
        # crashing on the two-value unpack below.
        continue
    # A line with any other number of fields still raises ValueError, so a
    # corrupt table is reported rather than silently skipped.
    chinese_char, strokes = line.split()
    dic[chinese_char] = strokes
    stroke2word[strokes] = chinese_char

Strokify

In [None]:
from functools import partial

def is_chinese(uchar):
    """Tell whether a unicode character is a Chinese character.

    Returns True when ``uchar`` compares within the inclusive range
    U+4E00..U+9FA5, and False otherwise.
    """
    return u'\u4e00' <= uchar <= u'\u9fa5'

def zh2letter(dictionary, line):
    """Replace every Chinese character in ``line`` with its stroke string.

    Args:
        dictionary: mapping from a single Chinese character to its stroke
            string. Characters missing from the mapping are replaced by an
            empty string (i.e. dropped).
        line: the input sentence.

    Returns:
        The converted sentence with each replaced character surrounded by
        spaces, all whitespace runs collapsed to a single space, leading and
        trailing whitespace removed, and a single trailing newline appended.
    """
    # Single left-to-right pass: each character is looked up exactly once,
    # so the result does not depend on set iteration order and the cost is
    # linear in len(line) rather than one full str.replace per distinct
    # character.
    pieces = []
    for char in line:
        if is_chinese(char):
            pieces.append(' ' + dictionary.get(char, '') + ' ')
        else:
            pieces.append(char)
    # split()/join normalises the padding added above (and any original
    # whitespace) to single spaces.
    return ' '.join(''.join(pieces).split()) + '\n'

In [None]:
# Source-side variant selector. Later cells use TYPES[TYPE] as the source
# language code in output file names and NAMES[TYPE] as the output
# sub-directory name.
TYPES = ["zh", "tz"]
NAMES = ["simp", "trad"]
TYPE = 0 # 0 for simplified, 1 for traditional

In [None]:
split = "test"

# Fetch the translation column once and derive both sides from it.
translation_pairs = dataset[split]["translation"]

# Source side: the raw "zh" text when TYPE == 0, otherwise the text passed
# through converter.convert().
src_text = [
    pair["zh"] if TYPE == 0 else converter.convert(pair["zh"])
    for pair in translation_pairs
]
# Target side: the English text, unchanged.
trg_text = [pair["en"] for pair in translation_pairs]

In [None]:
from tqdm import tqdm

# Language codes used when naming the output files.
src, trg = TYPES[TYPE], "en"

# `func` converts one sentence with zh2letter using the `dic` table. The map
# below is lazy and can only be consumed once.
# NOTE(review): the name `iter` shadows the builtin; kept because the next
# cell reads it.
func = partial(zh2letter, dic)
iter = map(func, src_text)

In [None]:
import os

path = f"./data/NIST/{NAMES[TYPE]}/all"
# open(..., 'w') does not create missing parent directories, so make sure
# the output directory exists first.
os.makedirs(path, exist_ok=True)

# Source side: one converted sentence per line (func already appends '\n').
# The map is rebuilt here rather than consuming the one-shot `iter` from the
# previous cell, so re-running this cell does not write an empty file.
with open(f"{path}/{split}.{src}-{trg}.{src}", 'w', encoding="utf-8") as f:
    for k in tqdm(map(func, src_text)):
        f.write(k)

# Target side: one English sentence per line.
with open(f"{path}/{split}.{src}-{trg}.{trg}", 'w', encoding="utf-8") as f:
    for k in tqdm(trg_text):
        f.write(f"{k}\n")

Split by average token length

In [None]:
# Parallel lists holding the English and the unconverted "zh" side of every
# test-split translation pair, in dataset order.
test_pairs = dataset["test"]["translation"]
en_split = [pair["en"] for pair in test_pairs]
zh_split = [pair["zh"] for pair in test_pairs]

In [None]:
from tqdm import tqdm
from functools import partial


# Same selector values as the earlier configuration cell, redefined here.
TYPES = ["zh", "tz"]
NAMES = ["simp", "trad"]
TYPE = 0 # 0 for simplified, 1 for traditional

src, trg = TYPES[TYPE], "en"

# Lazily convert every sentence in zh_split with zh2letter.
func = partial(zh2letter, dic)
iter = map(func, zh_split)

# Materialise the converted sentences (with a progress bar) so they can be
# indexed and re-read by later cells.
strokes = list(tqdm(iter))

In [None]:
import numpy as np

# Mean token length of every converted sentence, where tokens are the pieces
# obtained by splitting on a single space.
# NOTE(review): each entry of `strokes` ends with '\n' (appended by
# zh2letter), and split(" ") keeps it on the last token, so that token is
# counted one character longer. Left as is because the thresholds used
# further down were derived from these values.
avg_token_len = [
    np.average([len(word) for word in sent.split(" ")])
    for sent in strokes
]

# Print the 33rd and 66th percentiles, one per line.
for pct in (33, 66):
    print(np.percentile(avg_token_len, pct))

In [None]:
import numpy as np

# Bucket boundaries: the 33rd and 66th percentiles of avg_token_len, computed
# from the data instead of being pasted in as literals (previously
# 6.531808510638299 and 7.136363636363637), so they stay correct when the
# input data or TYPE changes.
short_max, medium_max = np.percentile(avg_token_len, [33, 66])

# bucket name -> list of source / target sentences, index-aligned with
# zh_split / en_split.
src_avg_strokes = defaultdict(list)
trg_avg_strokes = defaultdict(list)
for i, avg_len in enumerate(avg_token_len):
    if avg_len <= short_max:
        bucket = "short"
    elif avg_len <= medium_max:
        bucket = "medium"
    else:
        bucket = "long"
    src_avg_strokes[bucket].append(zh_split[i])
    trg_avg_strokes[bucket].append(en_split[i])

In [None]:
from tqdm import tqdm

# Language codes used when naming the output files.
src, trg = TYPES[TYPE], "en"

# Lazy, single-use conversion of src_text with zh2letter.
# NOTE(review): the name `iter` shadows the builtin, and the only consumer of
# this map in the next cell is commented out.
func = partial(zh2letter, dic)
iter = map(func, src_text)

In [None]:
# Output directory for the split files. Uses the same working-directory
# relative layout as the other output paths in this notebook instead of a
# machine-specific absolute path, so the cell works on any checkout.
if split=="test":
    path = f"./data/NIST/{NAMES[TYPE]}/test/sent"
else:
    path = f"./data/NIST/{NAMES[TYPE]}"
# NOTE(review): the two writes below were already disabled; kept for reference.
# with open(f"{path}/{split}.{src}-{trg}.{src}", 'w', encoding="utf-8") as f:
#     for k in tqdm(iter): f.write(k)

# with open(f"{path}/{split}.{src}-{trg}.{trg}", 'w', encoding="utf-8") as f:
#     for k in tqdm(trg_text): f.write(f"{k}\n")

Split by sentence length

In [None]:
import numpy as np
from collections import defaultdict

# English side and OpenCC-converted Chinese side of the test split.
test_pairs = dataset["test"]["translation"]
en_split = [pair["en"] for pair in test_pairs]
zh_split = [converter.convert(pair["zh"]) for pair in test_pairs]

# Bucket boundaries: 33rd / 66th percentile of the sentence length in
# characters, computed from the data instead of hard-coded (previously the
# literals 18 and 33).
lens = [len(zh) for zh in zh_split]
len_short_max, len_medium_max = np.percentile(lens, [33, 66])
for threshold in (len_short_max, len_medium_max):
    print(threshold)

# Split by length.
# NOTE(review): the thresholds come from the converted text while the
# bucketing below measures the raw pair["zh"], as before.
sentence_by_length = defaultdict(list)
for pair in test_pairs:
    n_chars = len(pair["zh"])
    if n_chars <= len_short_max:
        label = "short"
    elif n_chars <= len_medium_max:
        label = "medium"
    else:
        label = "long"
    sentence_by_length[label].append(pair)

# Report the bucket sizes. The loop variable is not called `type`, so the
# builtin stays usable in later cells.
for label, sents in sentence_by_length.items():
    print(label, len(sents))

In [None]:
split = "test"

# bucket name -> list of source ("zh") / target ("en") sentences.
# The loop variable is named `bucket` rather than `type` so the builtin
# `type` is not rebound in the notebook namespace.
src_text = defaultdict(list)
trg_text = defaultdict(list)
for bucket, pairs in sentence_by_length.items():
    src_text[bucket] = [pair["zh"] for pair in pairs]
    trg_text[bucket] = [pair["en"] for pair in pairs]

In [None]:
# Test-split files go into the "test/sent" sub-directory of the variant's
# folder; any other split goes directly into the variant's folder.
path = (
    f"./data/NIST/{NAMES[TYPE]}/test/sent"
    if split == "test"
    else f"./data/NIST/{NAMES[TYPE]}"
)
path

In [None]:
import os

# open(..., 'w') does not create missing parent directories.
os.makedirs(path, exist_ok=True)

# Write one source file and one target file per length bucket.
for word in sentence_by_length.keys():
    # Source side: converted sentences (func already appends '\n'). The map
    # is consumed inline so the builtin `iter` is not rebound.
    with open(f"{path}/{split}-{word}.{src}-{trg}.{src}", 'w', encoding="utf-8") as f:
        for k in tqdm(map(func, src_text[word])):
            f.write(k)

    # Target side: one English sentence per line.
    with open(f"{path}/{split}-{word}.{src}-{trg}.{trg}", 'w', encoding="utf-8") as f:
        for k in tqdm(trg_text[word]):
            f.write(f"{k}\n")

Finer granularity of average token length

In [None]:
import numpy as np

# percentile rank -> value of avg_token_len at that rank, for ranks
# 5, 10, ..., 95, plus a lower sentinel entry 0 -> 0. Keys are inserted in
# ascending order.
percentiles = defaultdict(int)
percentiles[0] = 0
percentiles.update(
    (rank, np.percentile(avg_token_len, rank)) for rank in range(5, 100, 5)
)
percentiles

In [None]:
# Split by length, in 5-percentile-wide buckets.
from bisect import bisect_left
from collections import defaultdict

import numpy as np

test_pairs = dataset["test"]["translation"]
zh_lengths = [len(pair["zh"]) for pair in test_pairs]

# Bucket labels (percentile ranks) and their lower edges in characters.
# The edges are percentiles of the sentence lengths themselves; comparing a
# character count against the average-token-length percentiles (as before)
# mixes two different scales.
length_ranks = [0] + list(range(5, 100, 5))
length_edges = [0] + [np.percentile(zh_lengths, rank) for rank in length_ranks[1:]]

sentence_by_length2 = defaultdict(list)
for pair, n_chars in zip(test_pairs, zh_lengths):
    # bisect_left - 1 selects the last edge strictly below n_chars; clamp at
    # 0 so an empty sentence does not wrap around to the top bucket via
    # index -1.
    idx = max(bisect_left(length_edges, n_chars) - 1, 0)
    sentence_by_length2[length_ranks[idx]].append(pair)

In [None]:
from collections import defaultdict
from bisect import bisect_left

# percentile rank -> translation pairs whose average token length lies above
# that rank's value and at or below the next rank's value.
avg_strokes = defaultdict(list)
test = dataset["test"]["translation"]

# Hoisted out of the loop: these lists do not change while bucketing.
bin_ranks = list(percentiles.keys())
bin_edges = list(percentiles.values())

for i, avg_len in enumerate(avg_token_len):
    # bisect_left - 1 selects the last edge strictly below avg_len; clamp at
    # 0 so a value at or below the first edge cannot wrap around to the top
    # bucket through index -1. `idx` is used instead of `id` to keep the
    # builtin intact.
    idx = max(bisect_left(bin_edges, avg_len) - 1, 0)
    avg_strokes[bin_ranks[idx]].append(test[i])

In [None]:
# Show how many translation pairs landed in each percentile bucket instead
# of dumping every sentence pair into the cell output.
{rank: len(pairs) for rank, pairs in sorted(avg_strokes.items())}

In [None]:
split = "test"

# percentile bucket -> list of source ("zh") / target ("en") sentences.
# The loop variable is named `bucket` rather than `type` so the builtin
# `type` is not rebound in the notebook namespace.
src_text2 = defaultdict(list)
trg_text2 = defaultdict(list)
for bucket, pairs in avg_strokes.items():
    src_text2[bucket] = [pair["zh"] for pair in pairs]
    trg_text2[bucket] = [pair["en"] for pair in pairs]

In [None]:
import os

# Working-directory-relative output location (same ./data/NIST layout as the
# other output paths in this notebook) instead of a machine-specific
# absolute path.
path = "./data/NIST/simp_original/test/sent_fine"
# open(..., 'w') does not create missing parent directories.
os.makedirs(path, exist_ok=True)

split = "test"
# Write one source file and one target file per percentile bucket.
for word in avg_strokes.keys():
    # Source side: converted sentences (func already appends '\n'). The map
    # is consumed inline so the builtin `iter` is not rebound.
    with open(f"{path}/{split}-{word}.{src}-{trg}.{src}", 'w', encoding="utf-8") as f:
        for k in tqdm(map(func, src_text2[word])):
            f.write(k)

    # Target side: one English sentence per line.
    with open(f"{path}/{split}-{word}.{src}-{trg}.{trg}", 'w', encoding="utf-8") as f:
        for k in tqdm(trg_text2[word]):
            f.write(f"{k}\n")