-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_cdawg_matches_dawg.py
98 lines (82 loc) · 3.03 KB
/
test_cdawg_matches_dawg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from argparse import ArgumentParser
from tqdm import tqdm
from transformers import GPT2TokenizerFast
import numpy as np
from rusty_dawg import Dawg, Cdawg
def parse_args():
    """Read the command-line options of the DAWG/CDAWG comparison script.

    Options:
        --train: path of the training text file.
        --valid: path of the validation text file.
        --n_valid: optional integer; when given, the script keeps only that
            many leading validation tokens.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = ArgumentParser()
    # Both corpus paths are plain string options that differ only in their
    # flag and default, so they are registered from a small table.
    corpus_options = (
        ("--train", "data/wikitext-2-raw/wiki.train.raw"),
        ("--valid", "data/wikitext-2-raw/wiki.valid.raw"),
    )
    for flag, default_path in corpus_options:
        cli.add_argument(flag, type=str, default=default_path)
    cli.add_argument("--n_valid", type=int, default=None)
    options = cli.parse_args()
    return options
def get_tokens(tokenizer, path):
    """Tokenize a text file line by line into one flat list of token ids.

    Args:
        tokenizer: Callable that maps a line of text to a mapping holding an
            ``"input_ids"`` entry (the script passes a ``GPT2TokenizerFast``).
        path: Path of the text file to read.

    Returns:
        A list with the ``"input_ids"`` of every line, in file order,
        followed by one trailing ``Cdawg.EOS``.
    """
    all_tokens = []
    # Use a context manager so the file handle is closed deterministically
    # (the bare ``open(path)`` passed to tqdm was never closed), and pin the
    # encoding so decoding does not depend on the platform's locale.
    with open(path, encoding="utf-8") as handle:
        for line in tqdm(handle, desc=f"Open {path}..."):
            all_tokens.extend(tokenizer(line)["input_ids"])
    # A single EOS terminates the whole token stream; it is appended once,
    # after all lines have been tokenized.
    all_tokens.append(Cdawg.EOS)
    return all_tokens
def get_count(tokens, ngram):
    """Count the occurrences of ``ngram`` in ``tokens`` by brute force.

    Every window of ``len(ngram)`` consecutive tokens is compared against
    ``ngram``; overlapping occurrences are all counted. The scan costs
    O(len(tokens) * len(ngram)) comparisons, which is fine for the
    spot-checks this script performs.

    Args:
        tokens: Sequence of token ids to search (sliced, so it must support
            slicing; slices are compared to ``ngram`` with ``==``).
        ngram: Sequence of token ids to look for, of the same sequence type
            as ``tokens`` so that the ``==`` comparison can succeed.

    Returns:
        The number of start positions at which ``ngram`` occurs. An ngram
        longer than ``tokens`` yields 0; an empty ngram matches at every
        position, including the end (``len(tokens) + 1``), mirroring
        ``str.count("")``.
    """
    width = len(ngram)
    # The last valid start index is len(tokens) - width, hence the "+ 1":
    # without it the window that ends on the final token is never checked.
    last_start_exclusive = len(tokens) - width + 1
    return sum(
        1
        for start in range(last_start_exclusive)
        if tokens[start:start + width] == ngram
    )
if __name__ == "__main__":
    # Script flow: tokenize both splits, build a DAWG and a CDAWG over the
    # training tokens, feed the validation tokens through both automata in
    # lockstep, then print and plot the matched lengths and counts so any
    # disagreement between the two structures is visible.
    cli_args = parse_args()
    gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    train_tokens = get_tokens(gpt2_tokenizer, cli_args.train)
    valid_tokens = get_tokens(gpt2_tokenizer, cli_args.valid)
    if cli_args.n_valid is not None:
        # Optionally restrict the comparison to a prefix of the validation set.
        valid_tokens = valid_tokens[:cli_args.n_valid]

    print("Building DAWG...")
    dawg = Dawg()
    dawg.build(train_tokens)

    cdawg = Cdawg(train_tokens)
    print("Building CDAWG...")
    cdawg.build()
    print("Filling CDAWG counts...")
    cdawg.fill_counts()

    # Running state of each automaton. The DAWG API threads the matched
    # length through explicitly; the CDAWG state object reports its own.
    dawg_state = dawg.get_initial()
    dawg_length = 0
    cdawg_state = cdawg.get_initial()

    dlengths, clengths = [], []
    dcounts, ccounts = [], []
    for token in valid_tokens:
        dawg_state, dawg_length = dawg.transition_and_count(dawg_state, token, dawg_length)
        cdawg_state = cdawg.transition_and_count(cdawg_state, token)
        dlengths.append(dawg_length)
        clengths.append(cdawg_state.get_length())
        dcounts.append(dawg.get_count(dawg_state))
        ccounts.append(cdawg.get_suffix_count(cdawg_state))

    # Positions at which the two structures report different values.
    length_diff = np.array(clengths) != np.array(dlengths)
    print("Mismatched length indices:", np.nonzero(length_diff)[0])
    count_diff = np.array(ccounts) != np.array(dcounts)
    print("Mismatched count indices:", np.nonzero(count_diff)[0])

    print("CDAWG counts:", ccounts[:20])
    print("DAWG counts:", dcounts[:20])

    # Re-derive the first 20 matched suffixes from the CDAWG lengths and
    # count them directly in the training tokens as a ground-truth check.
    # `end` is the exclusive end index of the suffix ending at each token.
    ngrams = [
        valid_tokens[end - matched: end]
        for end, matched in enumerate(clengths[:20], start=1)
    ]
    actual_counts = [get_count(train_tokens, ngram) for ngram in ngrams]
    print("Actual counts:", actual_counts)

    source = cdawg.get_source()
    print("count(source) =", cdawg.get_count(source), "#(tokens) =", len(train_tokens))

    import matplotlib.pyplot as plt

    # One figure per entry: (y-axis label, curves to draw as (values, label)).
    # Only the CDAWG counts are drawn in the second figure; add
    # (dcounts, "DAWG") to its curves to overlay the DAWG counts.
    figure_specs = (
        ("suffix context length", ((clengths, "CDAWG"), (dlengths, "DAWG"))),
        ("suffix context frequency", ((ccounts, "CDAWG"),)),
    )
    for y_label, curves in figure_specs:
        plt.figure()
        for values, curve_label in curves:
            plt.plot(values, label=curve_label)
        plt.xlabel("validation token index")
        plt.ylabel(y_label)
        plt.tight_layout()
        plt.legend()
        plt.show()