from __future__ import absolute_import
from typing import Any, Dict, List, Optional

from overrides import overrides
import torch
from torch.nn.modules.linear import Linear
from allennlp.common.checks import check_dimensions_match
from allennlp.data import Vocabulary
from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder
from allennlp.modules import ConditionalRandomField, FeedForward
from allennlp.modules.conditional_random_field import allowed_transitions
from allennlp.models.model import Model
from allennlp.nn import InitializerApplicator, RegularizerApplicator
import allennlp.nn.util as util
from allennlp.training.metrics import SpanBasedF1Measure
@Model.register(u"crf_tagger")
class CrfTagger(Model):
u"""
The ``CrfTagger`` encodes a sequence of text with a ``Seq2SeqEncoder``,
then uses a Conditional Random Field model to predict a tag for each token in the sequence.
Parameters
----------
vocab : ``Vocabulary``, required
A Vocabulary, required in order to compute sizes for input/output projections.
text_field_embedder : ``TextFieldEmbedder``, required
Used to embed the tokens ``TextField`` we get as input to the model.
encoder : ``Seq2SeqEncoder``
The encoder that we will use in between embedding tokens and predicting output tags.
label_namespace : ``str``, optional (default=``labels``)
This is needed to compute the SpanBasedF1Measure metric.
Unless you did something unusual, the default value should be what you want.
    feedforward : ``FeedForward``, optional (default = ``None``)
        An optional feedforward layer to apply after the encoder.
    dropout : ``float``, optional (default = ``None``)
        If provided, dropout with this probability is applied to the embedded
        tokens and to the encoder output.
verbose_metrics : ``bool``, optional (default = False)
If true, metrics will be returned per label class in addition
to the overall statistics.
constraint_type : ``str``, optional (default=``None``)
If provided, the CRF will be constrained at decoding time
to produce valid labels based on the specified type (e.g. "BIO", or "BIOUL").
include_start_end_transitions : ``bool``, optional (default=``True``)
Whether to include start and end transition parameters in the CRF.
initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
Used to initialize the model parameters.
regularizer : ``RegularizerApplicator``, optional (default=``None``)
If provided, will be used to calculate the regularization penalty during training.
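    Examples
    --------
    A minimal construction sketch; the embedder and encoder shown here are
    illustrative choices, not requirements of this class::

        from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
        from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
        from allennlp.modules.token_embedders import Embedding

        embedder = BasicTextFieldEmbedder(
                {"tokens": Embedding(vocab.get_vocab_size("tokens"), 50)})
        encoder = PytorchSeq2SeqWrapper(
                torch.nn.LSTM(50, 100, batch_first=True, bidirectional=True))
        model = CrfTagger(vocab, embedder, encoder, constraint_type="BIO")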
"""
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 label_namespace: str = u"labels",
                 constraint_type: str = None,
                 feedforward: FeedForward = None,
                 include_start_end_transitions: bool = True,
                 dropout: float = None,
                 verbose_metrics: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
super(CrfTagger, self).__init__(vocab, regularizer)
self.label_namespace = label_namespace
self.text_field_embedder = text_field_embedder
self.num_tags = self.vocab.get_vocab_size(label_namespace)
self.encoder = encoder
self._verbose_metrics = verbose_metrics
if dropout:
self.dropout = torch.nn.Dropout(dropout)
else:
self.dropout = None
self._feedforward = feedforward
if feedforward is not None:
output_dim = feedforward.get_output_dim()
else:
output_dim = self.encoder.get_output_dim()
self.tag_projection_layer = TimeDistributed(Linear(output_dim,
self.num_tags))
if constraint_type is not None:
labels = self.vocab.get_index_to_token_vocabulary(label_namespace)
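            # ``allowed_transitions`` returns the (from_tag_id, to_tag_id)
            # pairs that are legal under the given encoding, e.g. BIO forbids
            # an I-X tag that does not follow B-X or I-X.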
constraints = allowed_transitions(constraint_type, labels)
else:
constraints = None
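        # The CRF layer learns a tag-transition matrix; any constraints are
        # enforced only during Viterbi decoding, not during training.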
self.crf = ConditionalRandomField(
self.num_tags, constraints,
include_start_end_transitions=include_start_end_transitions
)
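        # Span-level precision/recall/F1 (computed over whole decoded spans,
        # not per-token accuracy), using the same label encoding as the CRF.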
self.span_metric = SpanBasedF1Measure(vocab,
tag_namespace=label_namespace,
label_encoding=constraint_type or u"BIO")
check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
u"text field embedding dim", u"encoder input dim")
if feedforward is not None:
check_dimensions_match(encoder.get_output_dim(), feedforward.get_input_dim(),
u"encoder output dim", u"feedforward input dim")
initializer(self)
    @overrides
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
# pylint: disable=arguments-differ
u"""
Parameters
----------
tokens : ``Dict[str, torch.LongTensor]``, required
The output of ``TextField.as_array()``, which should typically be passed directly to a
``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
for the ``TokenIndexers`` when you created the ``TextField`` representing your
sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
which knows how to combine different word representations into a single vector per
token in your input.
tags : ``torch.LongTensor``, optional (default = ``None``)
A torch tensor representing the sequence of integer gold class labels of shape
``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional (default = ``None``)
            Metadata containing the original words of each sentence to be
            tagged, under a ``"words"`` key.
Returns
-------
An output dictionary consisting of:
logits : ``torch.FloatTensor``
The logits that are the output of the ``tag_projection_layer``
mask : ``torch.LongTensor``
The text field mask for the input tokens
tags : ``List[List[int]]``
The predicted tags using the Viterbi algorithm.
loss : ``torch.FloatTensor``, optional
A scalar loss to be optimised. Only computed if gold label ``tags`` are provided.
"""
embedded_text_input = self.text_field_embedder(tokens)
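        # ``get_text_field_mask`` returns a (batch_size, num_tokens) tensor
        # that is 1 for real tokens and 0 for padding.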
mask = util.get_text_field_mask(tokens)
if self.dropout:
embedded_text_input = self.dropout(embedded_text_input)
encoded_text = self.encoder(embedded_text_input, mask)
if self.dropout:
encoded_text = self.dropout(encoded_text)
if self._feedforward is not None:
encoded_text = self._feedforward(encoded_text)
logits = self.tag_projection_layer(encoded_text)
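        # Constrained Viterbi decoding: one (tag_sequence, score) pair per
        # instance in the batch.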
best_paths = self.crf.viterbi_tags(logits, mask)
# Just get the tags and ignore the score.
predicted_tags = [x for x, y in best_paths]
output = {u"logits": logits, u"mask": mask, u"tags": predicted_tags}
if tags is not None:
# Add negative log-likelihood as loss
log_likelihood = self.crf(logits, tags, mask)
output[u"loss"] = -log_likelihood
# Represent viterbi tags as "class probabilities" that we can
# feed into the `span_metric`
class_probabilities = logits * 0.
for i, instance_tags in enumerate(predicted_tags):
for j, tag_id in enumerate(instance_tags):
class_probabilities[i, j, tag_id] = 1
self.span_metric(class_probabilities, tags, mask)
if metadata is not None:
output[u"words"] = [x[u"words"] for x in metadata]
return output
    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
u"""
Converts the tag ids to the actual tags.
``output_dict["tags"]`` is a list of lists of tag_ids,
so we use an ugly nested list comprehension.
"""
output_dict[u"tags"] = [
[self.vocab.get_token_from_index(tag, namespace=self.label_namespace)
for tag in instance_tags]
for instance_tags in output_dict[u"tags"]
]
return output_dict
    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
metric_dict = self.span_metric.get_metric(reset=reset)
if self._verbose_metrics:
return metric_dict
else:
            return {x: y for x, y in metric_dict.items() if u"overall" in x}
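

if __name__ == u"__main__":
    # Self-contained smoke test (an illustrative addition, not part of the
    # model itself): build a tiny vocabulary and an arbitrary embedder/encoder
    # pair, then run one forward pass and decode the predicted tags.
    from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
    from allennlp.modules.token_embedders import Embedding

    vocab = Vocabulary()
    for tag in [u"O", u"B-PER", u"I-PER"]:
        vocab.add_token_to_namespace(tag, namespace=u"labels")
    for word in [u"john", u"lives", u"here"]:
        vocab.add_token_to_namespace(word, namespace=u"tokens")

    embedder = BasicTextFieldEmbedder(
            {u"tokens": Embedding(vocab.get_vocab_size(u"tokens"), 8)})
    encoder = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(8, 4, batch_first=True, bidirectional=True))
    model = CrfTagger(vocab, embedder, encoder, constraint_type=u"BIO")

    # Token indices 0 and 1 are reserved for padding and unknown words, so the
    # three words above sit at indices 2-4.
    tokens = {u"tokens": torch.LongTensor([[2, 3, 4], [4, 3, 2]])}
    tags = torch.LongTensor([[1, 0, 0], [0, 0, 1]])  # B-PER O O / O O B-PER
    output = model(tokens, tags=tags)
    print(model.decode(output)[u"tags"])
    print(output[u"loss"].item())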