In [25]:
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModelForMaskedLM
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling
from transformers import EarlyStoppingCallback, IntervalStrategy
import math
from datasets import load_dataset

In [2]:
model_checkpoint = "bert-base-cased"
tokenizer_checkpoint = "sgugger/bert-like-tokenizer"

tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
config = AutoConfig.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_config(config)

In [24]:
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [3]:
# Load dataset

In [4]:
def tokenize_function(examples):
    return tokenizer(examples["text"])

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [5]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

In [6]:
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])
block_size = 128
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
)

Reusing dataset wikitext (/home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-8894a488dc32f625.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-43e5df8b2a38c636.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-02c103cff1700ff5.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-a234d08c184f4488.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-f2036a58c7e304a1.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-8892c521f6edacbd.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-35b9c677b89109d7.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-44cf7ecb587cd855.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-1b48c836ba41cc67.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-f7c381fe1c8594d2.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-3247d37fb988de49.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-9e758fa6a2261f0d.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-b7cef5533bc3f867.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-b03d599b6e08b907.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-247f09cc7d32c0d8.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-907559507c8fe830.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-15246fc09d60eee0.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-ef6d427b40025b80.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-0b9d236b7640c2a6.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-ce422694cd626170.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-d6c3e24e9aab2732.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-43bf8fdaa77f4377.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-79f13282b0a07438.arrow


 

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-8f9f0b367ae465be.arrow


In [19]:
lm_datasets

DatasetDict({
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 2301
    })
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 18761
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 2009
    })
})

In [21]:
NGPU = 8
GRADACCUM = 2
EPOCHS=3
TRAIN_SIZE = len(lm_datasets["train"])
EVAL_SIZE = len(lm_datasets["validation"])
batch_size = 32*NGPU*GRADACCUM
total_steps = TRAIN_SIZE/batch_size * EPOCHS
print('Train size:', TRAIN_SIZE,', Eval size:',EVAL_SIZE, ', Batch:', batch_size,', Steps:',total_steps)

Train size: 18761 , Eval size: 2009 , Batch: 512 , Steps: 109.927734375


In [22]:
training_args = TrainingArguments(
    f"{model_checkpoint}-wikitext2",
    #evaluation_strategy = "epoch",
    evaluation_strategy = IntervalStrategy.STEPS,
    num_train_epochs=EPOCHS, #default 3
    per_device_train_batch_size=32, #default 8
    per_device_eval_batch_size=16, #default 8
    gradient_accumulation_steps=GRADACCUM, #default 1
    learning_rate=2e-5,
    weight_decay=0.01,
    push_to_hub=False,
    logging_steps=20,
    load_best_model_at_end=True
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
    callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
)

using `logging_steps` to initialize `eval_steps` to 20
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [23]:
trainer.train()

***** Running training *****
  Num examples = 18761
  Num Epochs = 3
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 512
  Gradient Accumulation steps = 2
  Total optimization steps = 111


Step,Training Loss,Validation Loss
20,2.8482,1.95966
40,2.4263,1.568018
60,2.114,1.318767
80,1.9182,1.173118
100,1.8119,1.103778


***** Running Evaluation *****
  Num examples = 2009
  Batch size = 128
***** Running Evaluation *****
  Num examples = 2009
  Batch size = 128
***** Running Evaluation *****
  Num examples = 2009
  Batch size = 128
***** Running Evaluation *****
  Num examples = 2009
  Batch size = 128
***** Running Evaluation *****
  Num examples = 2009
  Batch size = 128


Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=111, training_loss=2.178767865842527, metrics={'train_runtime': 220.257, 'train_samples_per_second': 255.533, 'train_steps_per_second': 0.504, 'total_flos': 3703423157830656.0, 'train_loss': 2.178767865842527, 'epoch': 3.0})

In [26]:
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

***** Running Evaluation *****
  Num examples = 2009
  Batch size = 128


Perplexity: 2.98


In [27]:
eval_results

{'eval_loss': 1.0930731296539307,
 'eval_runtime': 6.6933,
 'eval_samples_per_second': 300.149,
 'eval_steps_per_second': 2.39,
 'epoch': 3.0}