-
Notifications
You must be signed in to change notification settings - Fork 1
/
build_kftt_dataset.py
109 lines (87 loc) · 3.45 KB
/
build_kftt_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import itertools
import json
from collections import Counter, OrderedDict
from pathlib import Path
from shutil import unpack_archive
from urllib.request import urlretrieve
from libs.text_encoder import get_tokenized_text_list
def get_datasets(train_file_path: Path, val_file_path: Path, test_file_path: Path, lang: str) -> dict:
    """Tokenize the train/val/test files and collect train-set word frequencies.

    Args:
        train_file_path: Path to the training text file.
        val_file_path: Path to the validation text file.
        test_file_path: Path to the test text file.
        lang: Tokenization mode forwarded to ``get_tokenized_text_list``
            (e.g. ``"spaced"`` for pre-tokenized, space-separated text).

    Returns:
        Dict with keys ``"train_texts"``, ``"val_texts"``, ``"test_texts"``
        (each a list of token lists) and ``"word_freqs"``, an ``OrderedDict``
        mapping train-set token -> count, most frequent first.
    """
    train_tokenized_texts = get_tokenized_text_list(train_file_path, lang)
    val_tokenized_texts = get_tokenized_text_list(val_file_path, lang)
    test_tokenized_texts = get_tokenized_text_list(test_file_path, lang)
    # Counter consumes the chained iterator directly; no intermediate list needed.
    counter = Counter(itertools.chain.from_iterable(train_tokenized_texts))
    # most_common() == items sorted by count descending (stable for ties),
    # replacing the explicit sorted(..., key=..., reverse=True).
    return {
        "train_texts": train_tokenized_texts,
        "val_texts": val_tokenized_texts,
        "test_texts": test_tokenized_texts,
        "word_freqs": OrderedDict(counter.most_common()),
    }
def write_parameter_settings(base_path: Path) -> None:
    """Write default model/training hyperparameters to ``<base_path>/settings.json``.

    Args:
        base_path: Existing directory in which ``settings.json`` is created
            (overwritten if already present).
    """
    settings = {
        # Model hyperparameters (presumably a Transformer encoder/decoder,
        # given n_enc_blocks/n_dec_blocks/head_num -- confirm against the model code).
        "params": {
            "n_dim": 128,
            "hidden_dim": 256,
            "n_enc_blocks": 2,
            "n_dec_blocks": 2,
            "head_num": 8,
            "dropout_rate": 0.1,
        },
        "training": {
            "batch_size": 128,
            "num_epoch": 20,
        },
        # Minimum train-set frequency for a token to enter each vocabulary.
        "min_freq": {
            "source": 3,
            "target": 3,
        },
    }
    # Explicit UTF-8: json.dump with ensure_ascii=False can emit non-ASCII,
    # which would fail under a non-UTF-8 platform-default encoding.
    with (base_path / "settings.json").open("w", encoding="utf-8") as f:
        json.dump(settings, f, indent=2, ensure_ascii=False)
def _export_split_files(base_path: Path, results: dict, prefix: str) -> None:
    """Write the train/val/test text files and the word-frequency JSON for one side.

    Args:
        base_path: Output directory.
        results: Dict returned by ``get_datasets``.
        prefix: File-name prefix, ``"src"`` or ``"tgt"``.
    """
    for key in ["train_texts", "val_texts", "test_texts"]:
        # One space-joined sentence per line. Explicit UTF-8 so the Japanese
        # text round-trips regardless of the platform's default locale encoding.
        with (base_path / f"{prefix}_{key}.txt").open("w", encoding="utf-8") as f:
            for tokenized_text in results[key]:
                f.write(" ".join(tokenized_text) + "\n")
    with (base_path / f"{prefix}_word_freqs.json").open("w", encoding="utf-8") as f:
        json.dump(results["word_freqs"], f, indent=2, ensure_ascii=False)


def main():
    """Download and unpack the KFTT corpus, then export tokenized
    source (English) / target (Japanese) splits plus word frequencies
    and default parameter settings under ``kftt_dataset/``."""
    # Dataset URL and archive file name
    url = "http://www.phontron.com/kftt/download/kftt-data-1.0.tar.gz"
    file_name = "kftt-data-1.0.tar.gz"
    base_path = Path(__file__).resolve().parent / "kftt_dataset"
    base_path.mkdir(exist_ok=True, parents=True)
    archive_path = base_path / file_name
    # Download the archive only if it is not already present
    if not archive_path.exists():
        print(f"download {file_name}")
        urlretrieve(url, archive_path)
    # Unpack the gzipped tar archive
    print(f"expand {file_name}")
    unpack_archive(archive_path, base_path, format="gztar")
    # Use the pre-tokenized ("tok") data shipped with KFTT
    dataset_path = base_path / "kftt-data-1.0" / "data" / "tok"
    print("create source files...")
    results = get_datasets(
        dataset_path / "kyoto-train.cln.en",
        dataset_path / "kyoto-dev.en",
        dataset_path / "kyoto-test.en",
        lang="spaced",
    )
    _export_split_files(base_path, results, "src")
    print("create target files...")
    results = get_datasets(
        dataset_path / "kyoto-train.cln.ja",
        dataset_path / "kyoto-dev.ja",
        dataset_path / "kyoto-test.ja",
        lang="spaced",
    )
    _export_split_files(base_path, results, "tgt")
    print("write parameter settings")
    write_parameter_settings(base_path)
    print("done.")


if __name__ == "__main__":
    main()