import json
import random
import torch
import numpy as np
from tree import Tree, head_to_tree, tree_to_adj

class DataLoader(object):
    def __init__(self, filename, batch_size, args, dicts):
        self.batch_size = batch_size
        self.args = args
        self.dicts = dicts

        with open(filename) as infile:
            data = json.load(infile)
        # preprocess data
        data = self.preprocess(data, dicts, args)
        # keep the gold labels (last field of each example) so gold() can return them
        self.labels = [d[-1] for d in data]
        # chunk into batches
        data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)]
        self.data = data
        print("{} batches created for {}".format(len(data), filename))
    def preprocess(self, data, dicts, args):
        processed = []
        for d in data:
            for aspect in d['aspects']:
                # word tokens
                tok = list(d['token'])
                if args.lower:
                    tok = [t.lower() for t in tok]
                asp = list(aspect['term'])   # aspect term tokens
                label = aspect['polarity']   # sentiment label
                pos = list(d['pos'])         # POS tags
                head = list(d['head'])       # dependency heads
                dep = list(d['deprel'])      # dependency relations
                length = len(tok)            # real (unpadded) length
                # relative position of each token w.r.t. the aspect span
                post = [i - aspect['from'] for i in range(aspect['from'])] \
                     + [0 for _ in range(aspect['from'], aspect['to'])] \
                     + [i - aspect['to'] + 1 for i in range(aspect['to'], length)]
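                # e.g. length = 6, aspect span [2, 4) -> post = [-2, -1, 0, 0, 1, 2]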
                # mask marking the aspect tokens
                if len(asp) == 0:
                    mask = [1 for _ in range(length)]  # no explicit aspect term (e.g. rest16)
                else:
                    mask = [0 for _ in range(aspect['from'])] \
                         + [1 for _ in range(aspect['from'], aspect['to'])] \
                         + [0 for _ in range(aspect['to'], length)]
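                # e.g. length = 6, aspect span [2, 4) -> mask = [0, 0, 1, 1, 0, 0]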
                # map tokens/tags to vocabulary ids
                tok = map_to_ids(tok, dicts['token'])
                asp = map_to_ids(asp, dicts['token'])
                label = dicts['polarity'][label]
                pos = map_to_ids(pos, dicts['pos'])
                dep = map_to_ids(dep, dicts['dep'])
                head = [int(x) for x in head]
                assert any(x == 0 for x in head)  # head 0 marks the root of the dependency tree
                post = map_to_ids(post, dicts['post'])
                assert len(tok) == length \
                    and len(pos) == length \
                    and len(head) == length \
                    and len(post) == length \
                    and len(mask) == length
                processed += [(tok, asp, pos, head, dep, post, mask, length, label)]
        return processed
    def gold(self):
        """ Return the gold labels for all examples, in preprocessing order. """
        return self.labels
    def __len__(self):
        return len(self.data)

    def __getitem__(self, key):
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        batch = list(zip(*batch))
        # sort all fields by descending length for easy RNN operations
        lens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, lens)

        # convert to tensors
        tok = get_long_tensor(batch[0], batch_size).cuda()
        asp = get_long_tensor(batch[1], batch_size).cuda()
        pos = get_long_tensor(batch[2], batch_size).cuda()
        head = get_long_tensor(batch[3], batch_size).cuda()
        dep = get_long_tensor(batch[4], batch_size).cuda()
        post = get_long_tensor(batch[5], batch_size).cuda()
        mask = get_float_tensor(batch[6], batch_size).cuda()
        length = torch.LongTensor(batch[7]).cuda()
        label = torch.LongTensor(batch[8]).cuda()
        def inputs_to_tree_reps(maxlen, head, words, l):
            # build one dependency tree per example, then convert each tree
            # to a maxlen x maxlen adjacency matrix
            trees = [head_to_tree(head[i], words[i], l[i]) for i in range(l.size(0))]
            adj = [tree_to_adj(maxlen, tree, directed=self.args.direct, self_loop=self.args.loop).reshape(1, maxlen, maxlen)
                   for tree in trees]
            adj = np.concatenate(adj, axis=0)  # (batch_size, maxlen, maxlen)
            return adj
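        # adj[b] is the adjacency matrix of example b's dependency tree;
        # whether edges are directed and whether the diagonal (self-loops)
        # is set is controlled by args.direct and args.loop (see tree.py)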
        maxlen = int(max(length))  # length is a tensor; make maxlen a plain int
        adj = torch.tensor(inputs_to_tree_reps(maxlen, head, tok, length)).cuda()

        return (tok, asp, pos, head, dep, post, mask, length, adj, label)
    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)
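
# Usage sketch (illustrative only): the contents of `dicts` and the attributes
# on `args` below are assumptions inferred from their use above, not defined here.
#
#   dicts = {'token': token_vocab, 'pos': pos_vocab, 'dep': dep_vocab,
#            'post': post_vocab, 'polarity': polarity_vocab}
#   loader = DataLoader('train.json', batch_size=32, args=args, dicts=dicts)
#   for tok, asp, pos, head, dep, post, mask, length, adj, label in loader:
#       ...  # one CUDA-resident batch per iteration
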
def map_to_ids(tokens, vocab):
    ids = [vocab[t] if t in vocab else 1 for t in tokens]  # OOV tokens map to id 1 ([UNK])
    return ids

def get_long_tensor(tokens_list, batch_size):
    """ Convert a list of token-id lists to a padded LongTensor (pad id 0). """
    token_len = max(len(x) for x in tokens_list)
    tokens = torch.LongTensor(batch_size, token_len).fill_(0)
    for i, s in enumerate(tokens_list):
        tokens[i, :len(s)] = torch.LongTensor(s)
    return tokens

def get_float_tensor(tokens_list, batch_size):
    """ Convert a list of value lists to a padded FloatTensor (pad 0). """
    token_len = max(len(x) for x in tokens_list)
    tokens = torch.FloatTensor(batch_size, token_len).fill_(0)
    for i, s in enumerate(tokens_list):
        tokens[i, :len(s)] = torch.FloatTensor(s)
    return tokens

def sort_all(batch, lens):
    """ Sort all fields by descending order of lens, and return the original indices. """
    unsorted_all = [lens] + [range(len(lens))] + list(batch)
    sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))]
    return sorted_all[2:], sorted_all[1]
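

if __name__ == "__main__":
    # Minimal CPU-only smoke test of the pure helpers; the vocabulary and
    # sequences below are toy values for illustration, not from any dataset.
    vocab = {'the': 2, 'food': 3, 'was': 4, 'great': 5}
    print(map_to_ids(['the', 'food', 'was', 'amazing'], vocab))  # -> [2, 3, 4, 1]
    print(get_long_tensor([[2, 3, 4], [5, 6]], 2))               # second row zero-padded
    fields, orig_idx = sort_all([[[1], [7, 8, 9], [4, 5]]], [1, 3, 2])
    print(fields, orig_idx)  # sequences longest-first; orig_idx == [1, 2, 0]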