-
Notifications
You must be signed in to change notification settings - Fork 287
/
wrapper.py
537 lines (448 loc) · 25.7 KB
/
wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains code for wrapping a transformer language model and
provides convenience methods for training and inference.
"""
import json
import jsonpickle
import os
from typing import List, Dict, Optional
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import RandomSampler, DataLoader, SequentialSampler
from tqdm import trange, tqdm
from transformers import InputExample, AdamW, get_linear_schedule_with_warmup, PreTrainedTokenizer, BertForMaskedLM, \
RobertaForMaskedLM, XLMRobertaForMaskedLM, XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer, \
XLNetLMHeadModel, BertConfig, BertForSequenceClassification, BertTokenizer, RobertaConfig, \
RobertaForSequenceClassification, RobertaTokenizer, XLMRobertaConfig, XLMRobertaForSequenceClassification, \
XLMRobertaTokenizer, AlbertForSequenceClassification, AlbertForMaskedLM, AlbertTokenizer, AlbertConfig, \
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers import __version__ as transformers_version
import log
from pet import preprocessor
from pet.tasks import TASK_HELPERS
from pet.utils import InputFeatures, DictDataset, distillation_loss
logger = log.get_logger('root')

# Name of the file (inside a saved model directory) that holds the jsonpickle-encoded WrapperConfig.
CONFIG_NAME = 'wrapper_config.json'

# Identifiers of the supported wrapper types; they are used as keys in the lookup tables below.
SEQUENCE_CLASSIFIER_WRAPPER = "sequence_classifier"
MLM_WRAPPER = "mlm"
PLM_WRAPPER = "plm"

WRAPPER_TYPES = [SEQUENCE_CLASSIFIER_WRAPPER, MLM_WRAPPER, PLM_WRAPPER]

# Wrapper type -> preprocessor class. The wrapper instantiates it as
# ``cls(wrapper, task_name, pattern_id, verbalizer_file)`` and uses it to build input features.
PREPROCESSORS = {
    SEQUENCE_CLASSIFIER_WRAPPER: preprocessor.SequenceClassifierPreprocessor,
    MLM_WRAPPER: preprocessor.MLMPreprocessor,
    PLM_WRAPPER: preprocessor.PLMPreprocessor,
}

# Model type -> classes to load. 'config' and 'tokenizer' are always present; every other key maps a wrapper type
# to the model class used for it. A missing wrapper key means that combination is unsupported and indexing it raises
# a KeyError when a wrapper is created (e.g. 'xlnet' has no MLM entry, 'gpt2' has only the MLM entry).
MODEL_CLASSES = {
    'bert': {
        'config': BertConfig,
        'tokenizer': BertTokenizer,
        SEQUENCE_CLASSIFIER_WRAPPER: BertForSequenceClassification,
        MLM_WRAPPER: BertForMaskedLM
    },
    'roberta': {
        'config': RobertaConfig,
        'tokenizer': RobertaTokenizer,
        SEQUENCE_CLASSIFIER_WRAPPER: RobertaForSequenceClassification,
        MLM_WRAPPER: RobertaForMaskedLM
    },
    'xlm-roberta': {
        'config': XLMRobertaConfig,
        'tokenizer': XLMRobertaTokenizer,
        SEQUENCE_CLASSIFIER_WRAPPER: XLMRobertaForSequenceClassification,
        MLM_WRAPPER: XLMRobertaForMaskedLM
    },
    'xlnet': {
        'config': XLNetConfig,
        'tokenizer': XLNetTokenizer,
        SEQUENCE_CLASSIFIER_WRAPPER: XLNetForSequenceClassification,
        PLM_WRAPPER: XLNetLMHeadModel
    },
    'albert': {
        'config': AlbertConfig,
        'tokenizer': AlbertTokenizer,
        SEQUENCE_CLASSIFIER_WRAPPER: AlbertForSequenceClassification,
        MLM_WRAPPER: AlbertForMaskedLM
    },
    'gpt2': {
        'config': GPT2Config,
        'tokenizer': GPT2Tokenizer,
        MLM_WRAPPER: GPT2LMHeadModel
    },
}

# Wrapper type -> function that returns the matching bound evaluation-step method of a given wrapper instance.
EVALUATION_STEP_FUNCTIONS = {
    MLM_WRAPPER: lambda wrapper: wrapper.mlm_eval_step,
    PLM_WRAPPER: lambda wrapper: wrapper.plm_eval_step,
    SEQUENCE_CLASSIFIER_WRAPPER: lambda wrapper: wrapper.sequence_classifier_eval_step,
}

# Wrapper type -> function that returns the matching bound training-step method of a given wrapper instance.
TRAIN_STEP_FUNCTIONS = {
    MLM_WRAPPER: lambda wrapper: wrapper.mlm_train_step,
    PLM_WRAPPER: lambda wrapper: wrapper.plm_train_step,
    SEQUENCE_CLASSIFIER_WRAPPER: lambda wrapper: wrapper.sequence_classifier_train_step,
}
class WrapperConfig:
    """Settings that describe how a :class:`TransformerModelWrapper` is built.

    This is a plain attribute holder; a wrapper persists it with ``jsonpickle`` when it is saved.
    """

    def __init__(self, model_type: str, model_name_or_path: str, wrapper_type: str, task_name: str, max_seq_length: int,
                 label_list: List[str], pattern_id: int = 0, verbalizer_file: str = None, cache_dir: str = None):
        """
        Store the given settings.

        :param model_type: the model family key (e.g., 'bert', 'roberta', 'albert')
        :param model_name_or_path: a pretrained model name (e.g., 'roberta-large') or a local path to one
        :param wrapper_type: how the model is wrapped: 'mlm', 'plm' or 'sequence_classifier'
        :param task_name: name of the task being solved
        :param max_seq_length: upper bound on the number of tokens per sequence
        :param label_list: all labels the task can produce
        :param pattern_id: which pattern to apply
        :param verbalizer_file: optional path to a verbalizer file
        :param cache_dir: optional directory for cached downloads
        """
        # Attributes are assigned left to right in the same order as the parameters, so the
        # instance __dict__ (and therefore the serialized form) keeps its original key order.
        self.model_type, self.model_name_or_path = model_type, model_name_or_path
        self.wrapper_type, self.task_name = wrapper_type, task_name
        self.max_seq_length, self.label_list = max_seq_length, label_list
        self.pattern_id, self.verbalizer_file, self.cache_dir = pattern_id, verbalizer_file, cache_dir
class TransformerModelWrapper:
    """A wrapper around a Transformer-based language model.

    Bundles the model, its tokenizer, the PET preprocessor for the configured wrapper type and an optional
    task helper, and provides training, evaluation, saving and loading on top of them.
    """

    def __init__(self, config: WrapperConfig):
        """Create a new wrapper from the given config.

        :param config: the wrapper configuration; ``model_type`` and ``wrapper_type`` select the classes to load
               from ``MODEL_CLASSES`` and the preprocessor from ``PREPROCESSORS``
        """
        self.config = config
        config_class = MODEL_CLASSES[self.config.model_type]['config']
        tokenizer_class = MODEL_CLASSES[self.config.model_type]['tokenizer']
        model_class = MODEL_CLASSES[self.config.model_type][self.config.wrapper_type]
        cache_dir = config.cache_dir if config.cache_dir else None

        model_config = config_class.from_pretrained(
            config.model_name_or_path, num_labels=len(config.label_list), finetuning_task=config.task_name,
            cache_dir=cache_dir, use_cache=False)

        self.tokenizer = tokenizer_class.from_pretrained(
            config.model_name_or_path,
            cache_dir=cache_dir)  # type: PreTrainedTokenizer
        self._setup_gpt2_special_tokens()

        self.model = model_class.from_pretrained(config.model_name_or_path, config=model_config, cache_dir=cache_dir)

        self.preprocessor = PREPROCESSORS[self.config.wrapper_type](self, self.config.task_name, self.config.pattern_id,
                                                                    self.config.verbalizer_file)
        self.task_helper = TASK_HELPERS[self.config.task_name](self) if self.config.task_name in TASK_HELPERS else None

    def _setup_gpt2_special_tokens(self) -> None:
        """For GPT-2, use the tokenizer's end-of-sequence token as both padding and mask token."""
        if self.config.model_type == 'gpt2':
            self.tokenizer.pad_token, self.tokenizer.mask_token = self.tokenizer.eos_token, self.tokenizer.eos_token

    @classmethod
    def from_pretrained(cls, path: str) -> 'TransformerModelWrapper':
        """Load a wrapper previously written by :meth:`save` from the given directory.

        :param path: directory containing the model, tokenizer and ``CONFIG_NAME`` file
        :return: the restored wrapper (an instance of ``cls``, so subclasses round-trip correctly)
        """
        wrapper = cls.__new__(cls)  # bypass __init__, which would download/load from config.model_name_or_path
        wrapper.config = wrapper._load_config(path)
        tokenizer_class = MODEL_CLASSES[wrapper.config.model_type]['tokenizer']
        model_class = MODEL_CLASSES[wrapper.config.model_type][wrapper.config.wrapper_type]
        wrapper.model = model_class.from_pretrained(path)
        wrapper.tokenizer = tokenizer_class.from_pretrained(path)
        # keep the GPT-2 token setup identical to __init__
        wrapper._setup_gpt2_special_tokens()
        wrapper.preprocessor = PREPROCESSORS[wrapper.config.wrapper_type](
            wrapper, wrapper.config.task_name, wrapper.config.pattern_id, wrapper.config.verbalizer_file)
        wrapper.task_helper = TASK_HELPERS[wrapper.config.task_name](wrapper) \
            if wrapper.config.task_name in TASK_HELPERS else None
        return wrapper

    def save(self, path: str) -> None:
        """Save the model, tokenizer and wrapper config to ``path``, unwrapping ``DataParallel`` if present."""
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
        model_to_save.save_pretrained(path)
        self.tokenizer.save_pretrained(path)
        self._save_config(path)

    def _save_config(self, path: str) -> None:
        """Write the jsonpickle-encoded config to ``path/CONFIG_NAME``."""
        with open(os.path.join(path, CONFIG_NAME), 'w', encoding='utf-8') as f:
            f.write(jsonpickle.encode(self.config))

    @staticmethod
    def _load_config(path: str) -> WrapperConfig:
        """Read the config written by :meth:`_save_config`.

        SECURITY NOTE: ``jsonpickle.decode`` can instantiate arbitrary Python objects; only load wrapper
        directories that come from a trusted source.
        """
        with open(os.path.join(path, CONFIG_NAME), 'r', encoding='utf-8') as f:
            return jsonpickle.decode(f.read())

    def _maybe_wrap_data_parallel(self, n_gpu: int) -> None:
        """Wrap ``self.model`` in ``torch.nn.DataParallel`` when ``n_gpu > 1``, but never more than once.

        ``train`` and ``eval`` both request multi-GPU wrapping; without the isinstance check, calling them in
        sequence would nest ``DataParallel`` wrappers, while :meth:`save` only unwraps a single ``.module``.
        """
        if n_gpu > 1 and not isinstance(self.model, torch.nn.DataParallel):
            self.model = torch.nn.DataParallel(self.model)

    def train(self, task_train_data: List[InputExample], device, per_gpu_train_batch_size: int = 8, n_gpu: int = 1,
              num_train_epochs: int = 3, gradient_accumulation_steps: int = 1, weight_decay: float = 0.0,
              learning_rate: float = 5e-5, adam_epsilon: float = 1e-8, warmup_steps=0, max_grad_norm: float = 1,
              logging_steps: int = 50, per_gpu_unlabeled_batch_size: int = 8, unlabeled_data: List[InputExample] = None,
              lm_training: bool = False, use_logits: bool = False, alpha: float = 0.8, temperature: float = 1,
              max_steps=-1, **_):
        """
        Train the underlying language model.

        :param task_train_data: the training examples to use
        :param device: the training device (cpu/gpu)
        :param per_gpu_train_batch_size: the number of training examples per batch and gpu
        :param n_gpu: the number of gpus to use
        :param num_train_epochs: the number of epochs to train
        :param gradient_accumulation_steps: the number of gradient accumulation steps before performing an update
        :param weight_decay: the weight decay to use
        :param learning_rate: the learning rate to use
        :param adam_epsilon: epsilon parameter for the Adam optimizer
        :param warmup_steps: the number of warmup steps
        :param max_grad_norm: the maximum norm for the gradient
        :param logging_steps: the number of steps after which logging information is printed
        :param per_gpu_unlabeled_batch_size: the number of unlabeled examples per batch and gpu
        :param unlabeled_data: the unlabeled examples to use
        :param lm_training: whether to perform auxiliary language modeling (only for MLMs)
        :param use_logits: whether to use the example's logits instead of their labels to compute the loss
        :param alpha: the alpha parameter for auxiliary language modeling
        :param temperature: the temperature for knowledge distillation
        :param max_steps: the maximum number of optimizer steps, overrides ``num_train_epochs`` when positive
        :return: a tuple consisting of the total number of steps and the average training loss
                 (``-1`` as the loss if no optimizer step was taken)
        :raises ValueError: if ``lm_training`` or ``use_logits`` is set but ``unlabeled_data`` is None
        """
        train_batch_size = per_gpu_train_batch_size * max(1, n_gpu)
        train_dataset = self._generate_dataset(task_train_data)
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_batch_size)

        unlabeled_dataloader, unlabeled_iter = None, None

        if lm_training or use_logits:
            # we need unlabeled data both for auxiliary language modeling and for knowledge distillation
            if unlabeled_data is None:
                raise ValueError("unlabeled_data is required when lm_training or use_logits is set")
            unlabeled_batch_size = per_gpu_unlabeled_batch_size * max(1, n_gpu)
            unlabeled_dataset = self._generate_dataset(unlabeled_data, labelled=False)
            unlabeled_sampler = RandomSampler(unlabeled_dataset)
            unlabeled_dataloader = DataLoader(unlabeled_dataset, sampler=unlabeled_sampler,
                                              batch_size=unlabeled_batch_size)
            unlabeled_iter = iter(unlabeled_dataloader)

        if use_logits:
            # knowledge distillation trains on the (soft-labelled) unlabeled examples instead of the task data
            train_dataloader = unlabeled_dataloader

        if max_steps > 0:
            t_total = max_steps
            # enough epochs to reach max_steps; the loop below stops as soon as it is reached
            num_train_epochs = max_steps // (max(1, len(train_dataloader) // gradient_accumulation_steps)) + 1
        else:
            t_total = len(train_dataloader) // gradient_accumulation_steps * num_train_epochs

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]

        optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                    num_training_steps=t_total)
        # get_last_lr() is the non-deprecated accessor; older PyTorch versions only provide get_lr()
        current_lrs = getattr(scheduler, 'get_last_lr', None) or scheduler.get_lr

        # multi-gpu training
        self._maybe_wrap_data_parallel(n_gpu)

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        self.model.zero_grad()

        train_iterator = trange(int(num_train_epochs), desc="Epoch")

        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                unlabeled_batch = None

                batch = {k: t.to(device) for k, t in batch.items()}

                if lm_training:
                    # draw the next unlabeled batch, restarting the unlabeled data when it is exhausted
                    while unlabeled_batch is None:
                        try:
                            unlabeled_batch = next(unlabeled_iter)
                        except StopIteration:
                            logger.info("Resetting unlabeled dataset")
                            unlabeled_iter = iter(unlabeled_dataloader)

                    lm_input_ids = unlabeled_batch['input_ids']
                    unlabeled_batch['input_ids'], unlabeled_batch['mlm_labels'] = self._mask_tokens(lm_input_ids)
                    unlabeled_batch = {k: t.to(device) for k, t in unlabeled_batch.items()}

                train_step_inputs = {
                    'unlabeled_batch': unlabeled_batch, 'lm_training': lm_training, 'alpha': alpha,
                    'use_logits': use_logits, 'temperature': temperature
                }
                # a task helper may implement its own training step; None means "use the default one"
                loss = self.task_helper.train_step(batch, **train_step_inputs) if self.task_helper else None

                if loss is None:
                    loss = TRAIN_STEP_FUNCTIONS[self.config.wrapper_type](self)(batch, **train_step_inputs)

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu parallel training
                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                if (step + 1) % gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    self.model.zero_grad()
                    global_step += 1

                    if logging_steps > 0 and global_step % logging_steps == 0:
                        logs = {}
                        loss_scalar = (tr_loss - logging_loss) / logging_steps
                        learning_rate_scalar = current_lrs()[0]
                        logs['learning_rate'] = learning_rate_scalar
                        logs['loss'] = loss_scalar
                        logging_loss = tr_loss

                        print(json.dumps({**logs, **{'step': global_step}}))

                # stop as soon as max_steps optimizer updates have been performed (not max_steps + 1)
                if 0 < max_steps <= global_step:
                    epoch_iterator.close()
                    break
            if 0 < max_steps <= global_step:
                train_iterator.close()
                break

        return global_step, (tr_loss / global_step if global_step > 0 else -1)

    @staticmethod
    def _concat_or_none(parts: List[np.ndarray]) -> Optional[np.ndarray]:
        """Concatenate per-batch arrays along axis 0, or return None if there are none."""
        return np.concatenate(parts, axis=0) if parts else None

    def eval(self, eval_data: List[InputExample], device, per_gpu_eval_batch_size: int = 8, n_gpu: int = 1,
             priming: bool = False, decoding_strategy: str = 'default') -> Dict:
        """
        Evaluate the underlying language model.

        :param eval_data: the evaluation examples to use
        :param device: the evaluation device (cpu/gpu)
        :param per_gpu_eval_batch_size: the number of evaluation examples per batch and gpu
        :param n_gpu: the number of gpus to use
        :param priming: whether to use priming
        :param decoding_strategy: the decoding strategy for PET with multiple masks ('default', 'ltr' or 'parallel')
        :return: a dictionary of numpy arrays containing the indices, logits, labels, and (optional) question_ids for
                 each evaluation example; every value is None if ``eval_data`` is empty, and ``question_ids`` is
                 None if the batches carry no ``question_idx``
        """
        eval_dataset = self._generate_dataset(eval_data, priming=priming)
        eval_batch_size = per_gpu_eval_batch_size * max(1, n_gpu)
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=eval_batch_size)

        self._maybe_wrap_data_parallel(n_gpu)

        # collect per-batch arrays and concatenate once at the end (repeated np.append is quadratic)
        all_preds, all_labels, all_indices, all_question_ids = [], [], [], []

        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            self.model.eval()

            batch = {k: t.to(device) for k, t in batch.items()}
            labels = batch['labels']
            indices = batch['idx']
            with torch.no_grad():

                # some tasks require special evaluation
                logits = self.task_helper.eval_step(batch,
                                                    decoding_strategy=decoding_strategy) if self.task_helper else None

                if logits is None:
                    logits = EVALUATION_STEP_FUNCTIONS[self.config.wrapper_type](self)(batch)

            all_preds.append(logits.detach().cpu().numpy())
            all_labels.append(labels.detach().cpu().numpy())
            all_indices.append(indices.detach().cpu().numpy())
            if 'question_idx' in batch:
                all_question_ids.append(batch['question_idx'].detach().cpu().numpy())

        return {
            'indices': self._concat_or_none(all_indices),
            'logits': self._concat_or_none(all_preds),
            'labels': self._concat_or_none(all_labels),
            'question_ids': self._concat_or_none(all_question_ids)
        }

    def _generate_dataset(self, data: List[InputExample], labelled: bool = True, priming: bool = False):
        """Convert examples into a ``DictDataset`` of tensors.

        The dataset always contains input_ids, attention_mask, token_type_ids, labels, mlm_labels, logits and idx;
        PLM wrappers add perm_mask and target_mapping, and a task helper may add further entries.
        """
        features = self._convert_examples_to_features(data, labelled=labelled, priming=priming)
        feature_dict = {
            'input_ids': torch.tensor([f.input_ids for f in features], dtype=torch.long),
            'attention_mask': torch.tensor([f.attention_mask for f in features], dtype=torch.long),
            'token_type_ids': torch.tensor([f.token_type_ids for f in features], dtype=torch.long),
            'labels': torch.tensor([f.label for f in features], dtype=torch.long),
            'mlm_labels': torch.tensor([f.mlm_labels for f in features], dtype=torch.long),
            'logits': torch.tensor([f.logits for f in features], dtype=torch.float),
            'idx': torch.tensor([f.idx for f in features], dtype=torch.long)
        }
        if self.config.wrapper_type == PLM_WRAPPER:
            feature_dict['perm_mask'] = torch.tensor([f.perm_mask for f in features], dtype=torch.float)
            feature_dict['target_mapping'] = torch.tensor([f.target_mapping for f in features], dtype=torch.float)

        if self.task_helper:
            self.task_helper.add_features_to_dict(features, feature_dict)

        return DictDataset(**feature_dict)

    def _convert_examples_to_features(self, examples: List[InputExample], labelled: bool = True,
                                      priming: bool = False) -> List[InputFeatures]:
        """Run the preprocessor (and task helper, if any) on each example; log progress and the first 5 results."""
        features = []
        for (ex_index, example) in enumerate(examples):
            if ex_index % 10000 == 0:
                logger.info("Writing example {}".format(ex_index))
            input_features = self.preprocessor.get_input_features(example, labelled=labelled, priming=priming)
            if self.task_helper:
                self.task_helper.add_special_input_features(example, input_features)
            features.append(input_features)
            if ex_index < 5:
                logger.info(f'--- Example {ex_index} ---')
                logger.info(input_features.pretty_print(self.tokenizer))
        return features

    @staticmethod
    def _lm_ignore_index(version: str) -> int:
        """Return the label value that marks positions to ignore in the LM loss for a transformers ``version``.

        transformers < 2.4.0 expects -1, later versions expect -100. Only the leading digits of each of the first
        three version components are used (missing components count as 0), so pre-release/dev versions such as
        ``'4.0.0rc1'`` or ``'4.30.0.dev0'`` parse instead of raising ``ValueError``.
        """
        import re
        parts = []
        for component in version.split('.')[:3]:
            match = re.match(r'\d+', component)
            parts.append(int(match.group()) if match else 0)
        parts += [0] * (3 - len(parts))
        return -100 if parts >= [2, 4, 0] else -1

    def _mask_tokens(self, input_ids):
        """Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        About 15% of the non-special, non-padding positions are selected. ``input_ids`` is modified in place
        and returned together with labels that hold the original id at selected positions and the ignore index
        everywhere else.
        """
        labels = input_ids.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability 0.15)
        probability_matrix = torch.full(labels.shape, 0.15)
        special_tokens_mask = [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in
                               labels.tolist()]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        # get_special_tokens_mask(..., already_has_special_tokens=True) does not necessarily cover padding,
        # so exclude padding positions explicitly; they carry no content to predict
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is not None:
            probability_matrix.masked_fill_(labels.eq(pad_token_id), value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # if a version of transformers < 2.4.0 is used, -1 is the expected value for indices to ignore
        labels[~masked_indices] = self._lm_ignore_index(transformers_version)  # only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        input_ids[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        # (half of the remaining 20% of selected positions)
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        input_ids[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return input_ids, labels

    def generate_default_inputs(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Generate the default inputs required by almost every language model.

        Only 'bert' and 'xlnet' models additionally receive ``token_type_ids``.
        """
        inputs = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}
        if self.config.model_type in ['bert', 'xlnet']:
            inputs['token_type_ids'] = batch['token_type_ids']
        return inputs

    def mlm_train_step(self, labeled_batch: Dict[str, torch.Tensor],
                       unlabeled_batch: Optional[Dict[str, torch.Tensor]] = None, lm_training: bool = False,
                       alpha: float = 0, **_) -> torch.Tensor:
        """Perform a MLM training step.

        The PET loss is cross-entropy over the label logits derived from the MLM predictions at the mask
        positions. With ``lm_training``, it is mixed with the model's own LM loss on ``unlabeled_batch`` as
        ``alpha * pet_loss + (1 - alpha) * lm_loss``.
        """
        inputs = self.generate_default_inputs(labeled_batch)
        mlm_labels, labels = labeled_batch['mlm_labels'], labeled_batch['labels']

        outputs = self.model(**inputs)
        prediction_scores = self.preprocessor.pvp.convert_mlm_logits_to_cls_logits(mlm_labels, outputs[0])
        loss = nn.CrossEntropyLoss()(prediction_scores.view(-1, len(self.config.label_list)), labels.view(-1))

        if lm_training:
            lm_inputs = self.generate_default_inputs(unlabeled_batch)
            # NOTE(review): newer transformers releases name this argument 'labels'; confirm against the
            # installed version before relying on auxiliary LM training
            lm_inputs['masked_lm_labels'] = unlabeled_batch['mlm_labels']
            lm_loss = self.model(**lm_inputs)[0]
            loss = alpha * loss + (1 - alpha) * lm_loss
        return loss

    def plm_train_step(self, labeled_batch: Dict[str, torch.Tensor], lm_training: bool = False, **_):
        """Perform a PLM training step.

        :raises NotImplementedError: if ``lm_training`` is requested (checked before the forward pass)
        """
        if lm_training:
            raise NotImplementedError("Language model training is currently not implemented for PLMs")

        inputs = self.generate_default_inputs(labeled_batch)
        inputs['perm_mask'], inputs['target_mapping'] = labeled_batch['perm_mask'], labeled_batch['target_mapping']
        labels = labeled_batch['labels']
        outputs = self.model(**inputs)
        prediction_scores = self.preprocessor.pvp.convert_plm_logits_to_cls_logits(outputs[0])
        return nn.CrossEntropyLoss()(prediction_scores.view(-1, len(self.config.label_list)), labels.view(-1))

    def sequence_classifier_train_step(self, batch: Dict[str, torch.Tensor], use_logits: bool = False,
                                       temperature: float = 1, **_) -> torch.Tensor:
        """Perform a sequence classifier training step.

        With ``use_logits``, the loss is ``distillation_loss`` against ``batch['logits']``; otherwise the model
        receives the labels and its first output is returned as the loss.
        """
        inputs = self.generate_default_inputs(batch)
        if not use_logits:
            inputs['labels'] = batch['labels']

        outputs = self.model(**inputs)

        if use_logits:
            logits_predicted, logits_target = outputs[0], batch['logits']
            return distillation_loss(logits_predicted, logits_target, temperature)
        return outputs[0]

    def mlm_eval_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Perform a MLM evaluation step and return the label logits derived from the mask positions."""
        inputs = self.generate_default_inputs(batch)
        outputs = self.model(**inputs)
        return self.preprocessor.pvp.convert_mlm_logits_to_cls_logits(batch['mlm_labels'], outputs[0])

    def plm_eval_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Perform a PLM evaluation step and return the label logits."""
        inputs = self.generate_default_inputs(batch)
        inputs['perm_mask'], inputs['target_mapping'] = batch['perm_mask'], batch['target_mapping']
        outputs = self.model(**inputs)
        return self.preprocessor.pvp.convert_plm_logits_to_cls_logits(outputs[0])

    def sequence_classifier_eval_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Perform a sequence classifier evaluation step and return the model's first output (the logits)."""
        inputs = self.generate_default_inputs(batch)
        return self.model(**inputs)[0]