-
Notifications
You must be signed in to change notification settings - Fork 814
/
example.py
99 lines (88 loc) · 3.58 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import json
from functools import reduce
class Example(object):
    """Defines a single training or test example.

    Stores each column of the example as an attribute. Instances are built
    through the ``from*`` alternate constructors below: each one looks up the
    raw value of a column, passes it through ``field.preprocess`` and stores
    the result on the instance with ``setattr`` under the requested name.
    """

    @staticmethod
    def _require_key(obj, key):
        """Return ``obj[key]``, raising ValueError if *obj* has no such key.

        Only dicts are treated as keyed containers: ``key in obj`` on a str or
        list would do a substring/membership test, and the following
        ``obj[key]`` would then fail with an unhelpful TypeError.
        """
        if not isinstance(obj, dict) or key not in obj:
            raise ValueError("Specified key {} was not found in "
                             "the input data".format(key))
        return obj[key]

    @classmethod
    def _follow_key(cls, obj, key):
        """Take one step along a dotted key path (used as a ``reduce`` step).

        If *obj* is a list, *key* is looked up in each element and a new list
        of the results is returned; otherwise *key* is looked up in *obj*.
        """
        if isinstance(obj, list):
            return [cls._require_key(item, key) for item in obj]
        return cls._require_key(obj, key)

    @classmethod
    def fromJSON(cls, data, fields):
        """Create an Example from a JSON-encoded string.

        Args:
            data: JSON text, parsed with ``json.loads``.
            fields: dict mapping a key path to a ``(name, field)`` tuple, a
                list of such tuples, or None to ignore that key. A path like
                ``'foo.bar'`` addresses ``obj['foo']['bar']``; when a list is
                met along the way, the next key is looked up in each of its
                elements and the value becomes a list.

        Returns:
            The new Example; each ``(name, field)`` pair sets attribute
            *name* to ``field.preprocess(value)``.

        Raises:
            ValueError: if a key on a path is missing, or the path tries to
                descend into a value that is neither a JSON object nor a list
                of JSON objects.
        """
        ex = cls()
        obj = json.loads(data)
        for key, vals in fields.items():
            if vals is None:
                continue
            if not isinstance(vals, list):
                vals = [vals]
            # The split path is the same for every (name, field) of this key.
            path = key.split('.')
            for name, field in vals:
                value = reduce(cls._follow_key, path, obj)
                setattr(ex, name, field.preprocess(value))
        return ex

    @classmethod
    def fromdict(cls, data, fields):
        """Create an Example from a dict of raw column values.

        Args:
            data: dict holding the raw values.
            fields: dict mapping each key of *data* to a ``(name, field)``
                tuple, a list of such tuples, or None to ignore it. Every key
                must exist in *data*, including keys mapped to None.

        Returns:
            The new Example with one preprocessed attribute per pair.

        Raises:
            ValueError: if a key of *fields* is missing from *data*.
        """
        ex = cls()
        for key, vals in fields.items():
            if key not in data:
                raise ValueError("Specified key {} was not found in "
                                 "the input data".format(key))
            if vals is None:
                continue
            if not isinstance(vals, list):
                vals = [vals]
            for name, field in vals:
                setattr(ex, name, field.preprocess(data[key]))
        return ex

    @classmethod
    def fromCSV(cls, data, fields, field_to_index=None):
        """Create an Example from one already-split CSV/TSV row.

        Args:
            data: sequence of column values for the row.
            fields: if *field_to_index* is None, a sequence of
                ``(name, field)`` pairs matched to columns by position (as in
                ``fromlist``); otherwise a dict as accepted by ``fromdict``.
            field_to_index: optional dict mapping each key of *fields* to the
                index of the column in *data* that holds its value.

        Raises:
            TypeError: if *field_to_index* is given but *fields* is not a dict.
        """
        if field_to_index is None:
            return cls.fromlist(data, fields)
        # Explicit check rather than `assert`, which is stripped under -O.
        if not isinstance(fields, dict):
            raise TypeError("fields must be a dict when field_to_index is "
                            "given, got {}".format(type(fields).__name__))
        data_dict = {f: data[idx] for f, idx in field_to_index.items()}
        return cls.fromdict(data_dict, fields)

    @classmethod
    def fromlist(cls, data, fields):
        """Create an Example from a sequence of raw column values.

        Args:
            data: sequence of raw values.
            fields: sequence of ``(name, field)`` pairs matched to *data* by
                position (``zip`` stops at the shorter of the two). A None
                field skips its column. *name* may be a tuple of names paired
                with a tuple of fields, in which case the same value is
                preprocessed by each of them; None entries there are skipped.

        Trailing newline characters are stripped from string values before
        preprocessing.
        """
        ex = cls()
        for (name, field), val in zip(fields, data):
            if field is None:
                continue
            if isinstance(val, str):
                val = val.rstrip('\n')
            if isinstance(name, tuple):
                # One column feeding several attributes/fields.
                for n, f in zip(name, field):
                    if f is not None:
                        setattr(ex, n, f.preprocess(val))
            else:
                setattr(ex, name, field.preprocess(val))
        return ex

    @classmethod
    def fromtree(cls, data, fields, subtrees=False):
        """Create Example(s) from a bracketed tree string using NLTK.

        Args:
            data: tree string parsed with ``nltk.tree.Tree.fromstring``.
            fields: ``(name, field)`` pairs for ``fromlist``; the first is
                paired with the space-joined leaves, the second with the
                tree label.
            subtrees: if True, return a list with one Example per tree
                yielded by ``tree.subtrees()``; otherwise a single Example.

        Raises:
            ImportError: if NLTK is not installed.
        """
        try:
            from nltk.tree import Tree
        except ImportError as err:
            # Carry the hint on the exception instead of printing to stdout.
            raise ImportError("Please install NLTK. "
                              "See the docs at http://nltk.org for more "
                              "information.", name=err.name) from err
        tree = Tree.fromstring(data)

        def to_example(t):
            return cls.fromlist([' '.join(t.leaves()), t.label()], fields)

        if subtrees:
            return [to_example(t) for t in tree.subtrees()]
        return to_example(tree)