Rework TF trainer (huggingface#6038)
* Fully rework training/prediction loops

* fix method name

* Fix variable name

* Fix property name

* Fix scope

* Fix method name

* Fix tuple index

* Fix tuple index

* Fix indentation

* Fix variable name

* fix eval before log

* Add drop remainder for test dataset

* Fix step number + fix logging datetime

* fix eval loss value

* use global step instead of step + fix logging at step 0

* Fix logging datetime

* Fix global_step usage

* Fix breaking loop + logging datetime

* Fix step in prediction loop

* Fix step breaking

* Fix train/test loops

* Force TF at least 2.2 for the trainer

* Use assert_cardinality to facilitate the dataset size computation (see the sketch after this list)

* Log steps per epoch

* Make tfds compliant with TPU

* Make tfds compliant with TPU

* Use TF dataset enumerate instead of the Python one

* revert previous commit

* Fix data_dir

* Apply style

* rebase on master

* Address Sylvain's comments

* Address Sylvain's and Lysandre comments

* Trigger CI

* Remove unused import
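Several of the commits above ("Use assert_cardinality to facilitate the dataset size computation", "Fix step number", "Log steps per epoch") circle the same issue: a tf.data.Dataset built from a Python generator, as the example datasets below are, reports UNKNOWN_CARDINALITY, so the trainer cannot derive steps per epoch from it. A minimal sketch of the fix, with a made-up size of 10 (illustrative, not code from this commit):

import tensorflow as tf

# A generator-backed dataset has no statically known size.
ds = tf.data.Dataset.from_generator(lambda: iter(range(10)), output_types=tf.int32)
print(tf.data.experimental.cardinality(ds).numpy())  # -2, i.e. UNKNOWN_CARDINALITY

# Asserting the size (known from len(self.features) in the example datasets)
# makes it queryable from the dataset alone.
ds = ds.apply(tf.data.experimental.assert_cardinality(10))
print(tf.data.experimental.cardinality(ds).numpy())  # 10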
jplu committed Jul 29, 2020
1 parent 3f94170 commit 54f9fbe
Showing 9 changed files with 248 additions and 215 deletions.
2 changes: 1 addition & 1 deletion examples/README.md
@@ -1,7 +1,7 @@
 # Examples
 
 Version 2.9 of 🤗 Transformers introduces a new [`Trainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class for PyTorch, and its equivalent [`TFTrainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py) for TF 2.
-Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.1+.
+Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.2+.
 
 Here is the list of all our examples:
 - **grouped by task** (all official examples work for multiple models)
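The TF floor moves from 2.1+ to 2.2+ because the reworked trainer depends on APIs that first shipped in TensorFlow 2.2, notably tf.data.experimental.assert_cardinality. A hedged sketch of a matching runtime guard (illustrative; not necessarily the check the trainer itself performs):

import tensorflow as tf
from packaging import version

# assert_cardinality and related tf.data features appeared in TF 2.2.
if version.parse(tf.__version__) < version.parse("2.2.0"):
    raise RuntimeError("TFTrainer requires TensorFlow 2.2 or newer.")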
2 changes: 2 additions & 0 deletions examples/multiple-choice/utils_multiple_choice.py
@@ -204,6 +204,8 @@ def gen():
         )
 
     def get_dataset(self):
+        self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
+
         return self.dataset
 
     def __len__(self):
11 changes: 9 additions & 2 deletions examples/question-answering/run_tf_squad.py
@@ -21,6 +21,8 @@
 from dataclasses import dataclass, field
 from typing import Optional
 
+import tensorflow as tf
+
 from transformers import (
     AutoConfig,
     AutoTokenizer,
@@ -68,6 +70,7 @@ class DataTrainingArguments:
     data_dir: Optional[str] = field(
         default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
     )
+    use_tfds: Optional[bool] = field(default=True, metadata={"help": "If TFDS should be used or not."})
     max_seq_length: int = field(
         default=128,
         metadata={
@@ -170,7 +173,7 @@ def main():
     )
 
     # Get datasets
-    if not data_args.data_dir:
+    if data_args.use_tfds:
         if data_args.version_2_with_negative:
             logger.warn("tensorflow_datasets does not handle version 2 of SQuAD. Switch to version 1 automatically")
 
@@ -179,7 +182,7 @@ def main():
         except ImportError:
             raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
 
-        tfds_examples = tfds.load("squad")
+        tfds_examples = tfds.load("squad", data_dir=data_args.data_dir)
         train_examples = (
             SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
             if training_args.do_train
@@ -209,6 +212,8 @@ def main():
         else None
     )
 
+    train_dataset = train_dataset.apply(tf.data.experimental.assert_cardinality(len(train_examples)))
+
     eval_dataset = (
         squad_convert_examples_to_features(
             examples=eval_examples,
@@ -223,6 +228,8 @@ def main():
         else None
     )
 
+    eval_dataset = eval_dataset.apply(tf.data.experimental.assert_cardinality(len(eval_examples)))
+
     # Initialize our Trainer
     trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)
 
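With cardinality asserted on both datasets, the step bookkeeping fixed in the commit log above ("Fix step number", "Log steps per epoch") reduces to arithmetic on the dataset itself. A hedged sketch of that computation, not the trainer's actual implementation:

import math

import tensorflow as tf


def steps_per_epoch(dataset: tf.data.Dataset, batch_size: int, drop_remainder: bool = False) -> int:
    # Only works once the dataset's cardinality has been asserted, as above.
    num_examples = int(tf.data.experimental.cardinality(dataset).numpy())
    if num_examples < 0:
        raise ValueError("Unknown cardinality; apply assert_cardinality first.")
    # drop_remainder=True (used for the test dataset per the commit log)
    # discards the final partial batch.
    return num_examples // batch_size if drop_remainder else math.ceil(num_examples / batch_size)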
27 changes: 22 additions & 5 deletions examples/text-classification/run_tf_glue.py
@@ -9,6 +9,7 @@
 from typing import Dict, Optional
 
 import numpy as np
+import tensorflow as tf
 import tensorflow_datasets as tfds
 
 from transformers import (
@@ -35,7 +36,11 @@ class Split(Enum):
 
 
 def get_tfds(
-    task_name: str, tokenizer: PreTrainedTokenizer, max_seq_length: Optional[int] = None, mode: Split = Split.train
+    task_name: str,
+    tokenizer: PreTrainedTokenizer,
+    max_seq_length: Optional[int] = None,
+    mode: Split = Split.train,
+    data_dir: str = None,
 ):
     if task_name == "mnli-mm" and mode == Split.dev:
         tfds_name = "mnli_mismatched"
@@ -50,9 +55,11 @@ def get_tfds(
     else:
         tfds_name = task_name
 
-    ds = tfds.load("glue/" + tfds_name, split=mode.value)
+    ds, info = tfds.load("glue/" + tfds_name, split=mode.value, with_info=True, data_dir=data_dir)
+    ds = glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
+    ds = ds.apply(tf.data.experimental.assert_cardinality(info.splits[mode.value].num_examples))
 
-    return glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
+    return ds
 
 
 logger = logging.getLogger(__name__)
@@ -69,6 +76,7 @@ class GlueDataTrainingArguments:
     """
 
     task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
+    data_dir: Optional[str] = field(default=None, metadata={"help": "The input/output data dir for TFDS."})
     max_seq_length: int = field(
         default=128,
         metadata={
@@ -171,13 +179,22 @@ def main():
 
     # Get datasets
     train_dataset = (
-        get_tfds(task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length)
+        get_tfds(
+            task_name=data_args.task_name,
+            tokenizer=tokenizer,
+            max_seq_length=data_args.max_seq_length,
+            data_dir=data_args.data_dir,
+        )
         if training_args.do_train
         else None
     )
     eval_dataset = (
         get_tfds(
-            task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length, mode=Split.dev
+            task_name=data_args.task_name,
+            tokenizer=tokenizer,
+            max_seq_length=data_args.max_seq_length,
+            mode=Split.dev,
+            data_dir=data_args.data_dir,
         )
         if training_args.do_eval
         else None
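The with_info=True flag makes tfds.load return a (dataset, DatasetInfo) pair, and the split sizes recorded in that metadata are exactly what get_tfds feeds to assert_cardinality. A standalone illustration, using glue/mrpc purely as an example:

import tensorflow_datasets as tfds

ds, info = tfds.load("glue/mrpc", split="validation", with_info=True)
# DatasetInfo records the split sizes without iterating over the data.
print(info.splits["validation"].num_examples)  # 408 examples in MRPC's validation split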
6 changes: 0 additions & 6 deletions examples/token-classification/run_tf_ner.py
@@ -17,7 +17,6 @@
 
 import logging
 import os
-import warnings
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple
 
@@ -185,11 +184,6 @@ def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[L
 
     for i in range(batch_size):
         for j in range(seq_len):
-            if label_ids[i, j] == -1:
-                label_ids[i, j] = -100
-                warnings.warn(
-                    "Using `-1` to mask the loss for the token is depreciated. Please use `-100` instead."
-                )
             if label_ids[i, j] != -100:
                 out_label_list[i].append(label_map[label_ids[i][j]])
                 preds_list[i].append(label_map[preds[i][j]])
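The deleted block rewrote legacy -1 padding labels to -100 on the fly; with pad_token_label_id now -100 at the source (next file), align_predictions only has to skip ignored positions. A toy illustration of the convention, with an invented label map:

import numpy as np

label_map = {0: "O", 1: "B-PER", 2: "I-PER"}   # invented for illustration
label_ids = np.array([[1, 2, 0, -100, -100]])  # -100 marks positions ignored by the loss
preds = np.array([[1, 2, 0, 0, 0]])

aligned = [(label_map[l], label_map[p]) for l, p in zip(label_ids[0], preds[0]) if l != -100]
print(aligned)  # [('B-PER', 'B-PER'), ('I-PER', 'I-PER'), ('O', 'O')]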
4 changes: 3 additions & 1 deletion examples/token-classification/utils_ner.py
@@ -146,7 +146,7 @@ class TFNerDataset:
     """
 
     features: List[InputFeatures]
-    pad_token_label_id: int = -1
+    pad_token_label_id: int = -100
     # Use cross entropy ignore_index as padding label id so that only
     # real label ids contribute to the loss later.
 
@@ -221,6 +221,8 @@ def gen():
         )
 
     def get_dataset(self):
+        self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
+
         return self.dataset
 
     def __len__(self):
7 changes: 1 addition & 6 deletions src/transformers/modeling_tf_utils.py
@@ -17,7 +17,6 @@
 import functools
 import logging
 import os
-import warnings
 from typing import Dict, List, Optional, Union
 
 import h5py
@@ -174,11 +173,7 @@ def compute_loss(self, labels, logits):
         )
         # make sure only labels that are not equal to -100
         # are taken into account as loss
-        if tf.math.reduce_any(labels == -1).numpy() is True:
-            warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
-            active_loss = tf.reshape(labels, (-1,)) != -1
-        else:
-            active_loss = tf.reshape(labels, (-1,)) != -100
+        active_loss = tf.reshape(labels, (-1,)) != -100
         reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
         labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
 
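After this simplification, -100 is the single ignore value, mirroring PyTorch's cross-entropy ignore_index. A minimal standalone sketch of the masking pattern compute_loss relies on, with invented shapes and values:

import tensorflow as tf

labels = tf.constant([[1, 2, -100], [0, -100, -100]])  # (batch, seq_len); -100 = padding
logits = tf.random.uniform((2, 3, 5))                  # (batch, seq_len, num_labels)

active_loss = tf.reshape(labels, (-1,)) != -100
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 5)), active_loss)
reduced_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
print(loss_fn(reduced_labels, reduced_logits))  # padded positions never contribute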
(Diffs for the remaining two changed files are not shown.)
