-
Notifications
You must be signed in to change notification settings - Fork 138
/
biaffine_dependency.py
270 lines (235 loc) · 11.1 KB
/
biaffine_dependency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# -*- coding: utf-8 -*-
import os
import torch
import torch.nn as nn
from supar.models import BiaffineDependencyModel
from supar.parsers.parser import Parser
from supar.utils import Config, Dataset, Embedding
from supar.utils.common import bos, pad, unk
from supar.utils.field import Field, SubwordField
from supar.utils.fn import ispunct
from supar.utils.logging import get_logger, progress_bar
from supar.utils.metric import AttachmentMetric
from supar.utils.transform import CoNLL
# Module-level logger for this parser module (``build`` uses it to report progress).
logger = get_logger(__name__)
class BiaffineDependencyParser(Parser):
    r"""
    The implementation of Biaffine Dependency Parser.

    References:
        - Timothy Dozat and Christopher D. Manning. 2017.
          `Deep Biaffine Attention for Neural Dependency Parsing`_.

    .. _Deep Biaffine Attention for Neural Dependency Parsing:
        https://openreview.net/forum?id=Hk95PK9le
    """

    NAME = 'biaffine-dependency'
    MODEL = BiaffineDependencyModel

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # With char/BERT features the FORM column holds a (WORD, FEAT) pair of fields;
        # otherwise the extra feature comes from the CPOS (tag) column.
        if self.args.feat in ('char', 'bert'):
            self.WORD, self.FEAT = self.transform.FORM
        else:
            self.WORD, self.FEAT = self.transform.FORM, self.transform.CPOS
        self.ARC, self.REL = self.transform.HEAD, self.transform.DEPREL
        # Vocabulary indices of all punctuation tokens, used to mask punctuation out of
        # the attachment metric. The explicit integer dtype keeps this tensor comparable
        # with word indices even when no punctuation is found (``torch.tensor([])``
        # would otherwise default to a float tensor).
        self.puncts = torch.tensor([i
                                    for s, i in self.WORD.vocab.stoi.items()
                                    if ispunct(s)],
                                   dtype=torch.long).to(self.args.device)

    def train(self, train, dev, test, buckets=32, batch_size=5000,
              punct=False, tree=False, proj=False, verbose=True, **kwargs):
        r"""
        Args:
            train/dev/test (list[list] or str):
                Filenames of the train/dev/test datasets.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 32.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            punct (bool):
                If ``False``, ignores the punctuations during evaluation. Default: ``False``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``False``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``False``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding the unconsumed arguments that can be used to update the configurations for training.
        """

        # ``locals()`` must be captured before any other local is bound.
        return super().train(**Config().update(locals()))

    def evaluate(self, data, buckets=8, batch_size=5000,
                 punct=False, tree=True, proj=False, verbose=True, **kwargs):
        r"""
        Args:
            data (list[list] or str):
                The data for evaluation, both list of instances and filename are allowed.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 8.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            punct (bool):
                If ``False``, ignores the punctuations during evaluation. Default: ``False``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``True``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``False``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding the unconsumed arguments that can be used to update the configurations for evaluation.

        Returns:
            The loss scalar and evaluation results.
        """

        # ``locals()`` must be captured before any other local is bound.
        return super().evaluate(**Config().update(locals()))

    def predict(self, data, pred=None, buckets=8, batch_size=5000,
                prob=False, tree=True, proj=False, verbose=True, **kwargs):
        r"""
        Args:
            data (list[list] or str):
                The data for prediction, both a list of instances and filename are allowed.
            pred (str):
                If specified, the predicted results will be saved to the file. Default: ``None``.
            buckets (int):
                The number of buckets that sentences are assigned to. Default: 8.
            batch_size (int):
                The number of tokens in each batch. Default: 5000.
            prob (bool):
                If ``True``, outputs the probabilities. Default: ``False``.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``True``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``False``.
            verbose (bool):
                If ``True``, increases the output verbosity. Default: ``True``.
            kwargs (dict):
                A dict holding the unconsumed arguments that can be used to update the configurations for prediction.

        Returns:
            A :class:`~supar.utils.Dataset` object that stores the predicted results.
        """

        # ``locals()`` must be captured before any other local is bound.
        return super().predict(**Config().update(locals()))

    def _train(self, loader):
        r"""
        Run one optimisation pass over ``loader``, stepping the optimizer and scheduler
        once per batch and reporting the running loss/metric on the progress bar.
        """

        self.model.train()

        bar, metric = progress_bar(loader), AttachmentMetric()

        for words, feats, arcs, rels in bar:
            self.optimizer.zero_grad()

            mask = words.ne(self.WORD.pad_index)
            # ignore the first token of each sentence (the BOS position added by the fields)
            mask[:, 0] = 0
            s_arc, s_rel = self.model(words, feats)
            loss = self.model.loss(s_arc, s_rel, arcs, rels, mask)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.optimizer.step()
            self.scheduler.step()

            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask)
            # ignore all punctuation if not specified
            if not self.args.punct:
                mask &= words.unsqueeze(-1).ne(self.puncts).all(-1)
            metric(arc_preds, rel_preds, arcs, rels, mask)
            bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")

    @torch.no_grad()
    def _evaluate(self, loader):
        r"""
        Score every batch in ``loader`` without gradient tracking.

        Returns:
            A tuple of the loss averaged over batches and the accumulated
            :class:`AttachmentMetric`.
        """

        self.model.eval()

        total_loss, metric = 0, AttachmentMetric()

        for words, feats, arcs, rels in loader:
            mask = words.ne(self.WORD.pad_index)
            # ignore the first token of each sentence (the BOS position added by the fields)
            mask[:, 0] = 0
            s_arc, s_rel = self.model(words, feats)
            loss = self.model.loss(s_arc, s_rel, arcs, rels, mask)
            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
                                                     self.args.tree,
                                                     self.args.proj)
            # ignore all punctuation if not specified
            if not self.args.punct:
                mask &= words.unsqueeze(-1).ne(self.puncts).all(-1)
            total_loss += loss.item()
            metric(arc_preds, rel_preds, arcs, rels, mask)
        total_loss /= len(loader)

        return total_loss, metric

    @torch.no_grad()
    def _predict(self, loader):
        r"""
        Decode every batch in ``loader`` without gradient tracking.

        Returns:
            A dict with ``'arcs'`` (per-sentence lists of predicted head indices) and
            ``'rels'`` (per-sentence relation predictions mapped back through ``REL.vocab``),
            plus ``'probs'`` (per-sentence arc probability matrices) when ``args.prob`` is set.
        """

        self.model.eval()

        arcs, rels, probs = [], [], []
        for words, feats in progress_bar(loader):
            mask = words.ne(self.WORD.pad_index)
            # ignore the first token of each sentence (the BOS position added by the fields)
            mask[:, 0] = 0
            lens = mask.sum(1).tolist()
            s_arc, s_rel = self.model(words, feats)
            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
                                                     self.args.tree,
                                                     self.args.proj)
            # split the flat masked predictions back into one tensor per sentence
            arcs.extend(arc_preds[mask].split(lens))
            rels.extend(rel_preds[mask].split(lens))
            if self.args.prob:
                arc_probs = s_arc.softmax(-1)
                # drop the BOS row and the padding; column 0 (the BOS position) stays
                # as a candidate head
                probs.extend([prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())])
        arcs = [seq.tolist() for seq in arcs]
        rels = [self.REL.vocab[seq.tolist()] for seq in rels]
        preds = {'arcs': arcs, 'rels': rels}
        if self.args.prob:
            preds['probs'] = probs

        return preds

    @classmethod
    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
        r"""
        Build a brand-new Parser, including initialization of all data fields and model parameters.

        If ``path`` already exists and ``build`` is not requested in the configuration,
        the saved parser is loaded instead and its model is re-created.

        Args:
            path (str):
                The path of the model to be saved.
            min_freq (int):
                The minimum frequency needed to include a token in the vocabulary. Default: 2.
            fix_len (int):
                The max length of all subword pieces. The excess part of each piece will be truncated.
                Required if using CharLSTM/BERT.
                Default: 20.
            kwargs (dict):
                A dict holding the unconsumed arguments.
        """

        # ``locals()`` must be captured before any other local is bound, so that only the
        # call arguments are forwarded into the configuration.
        args = Config(**locals())
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # ``os.path.dirname`` returns '' for a bare filename and ``os.makedirs('')``
        # raises, so fall back to the current directory in that case.
        os.makedirs(os.path.dirname(path) or './', exist_ok=True)
        if os.path.exists(path) and not args.build:
            parser = cls.load(**args)
            parser.model = cls.MODEL(**parser.args)
            parser.model.load_pretrained(parser.WORD.embed).to(args.device)
            return parser

        logger.info("Building the fields")
        WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
        if args.feat == 'char':
            FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=args.fix_len)
        elif args.feat == 'bert':
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(args.bert)
            # Newer transformers releases expose ``model_max_length`` and dropped the
            # deprecated ``max_len``; prefer the new name and fall back for old versions.
            tokenizer_max_len = getattr(tokenizer, 'model_max_length', None) or tokenizer.max_len
            args.max_len = min(args.max_len or tokenizer_max_len, tokenizer_max_len)
            FEAT = SubwordField('bert',
                                pad=tokenizer.pad_token,
                                unk=tokenizer.unk_token,
                                bos=tokenizer.bos_token or tokenizer.cls_token,
                                fix_len=args.fix_len,
                                tokenize=tokenizer.tokenize)
            FEAT.vocab = tokenizer.get_vocab()
        else:
            FEAT = Field('tags', bos=bos)
        ARC = Field('arcs', bos=bos, use_vocab=False, fn=CoNLL.get_arcs)
        REL = Field('rels', bos=bos)
        if args.feat in ('char', 'bert'):
            transform = CoNLL(FORM=(WORD, FEAT), HEAD=ARC, DEPREL=REL)
        else:
            transform = CoNLL(FORM=WORD, CPOS=FEAT, HEAD=ARC, DEPREL=REL)

        train = Dataset(transform, args.train)
        WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
        FEAT.build(train)
        REL.build(train)
        args.update({
            'n_words': WORD.vocab.n_init,
            'n_feats': len(FEAT.vocab),
            'n_rels': len(REL.vocab),
            'pad_index': WORD.pad_index,
            'unk_index': WORD.unk_index,
            'bos_index': WORD.bos_index,
            'feat_pad_index': FEAT.pad_index,
        })
        model = cls.MODEL(**args)
        model.load_pretrained(WORD.embed).to(args.device)
        return cls(args, model, transform)