In [1]:
#https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb

In [2]:
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import Trainer, TrainingArguments
from transformers import AutoModelForSequenceClassification
from transformers import AdamW
from transformers import get_scheduler

from torch.optim.lr_scheduler import OneCycleLR

import torch
from torch.utils.data import DataLoader

import numpy as np
from tqdm.auto import tqdm

import datasets

from accelerate import Accelerator
from accelerate import notebook_launcher
from accelerate.utils import set_seed

In [3]:
# Model id, tokenizer and dynamic-padding collator shared at module level:
# tokenize_function reads `tokenizer`, get_dataloaders reads `data_collator`.
checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [4]:
def tokenize_function(examples):
    """Tokenize MRPC sentence pairs with the module-level ``tokenizer``.

    Used with ``Dataset.map(..., batched=True)``, so ``examples`` is a batch
    mapping with ``'sentence1'`` and ``'sentence2'`` columns. Pairs that are
    too long for the model are truncated; padding is left to the collator.
    """
    first_sentences = examples['sentence1']
    second_sentences = examples['sentence2']
    return tokenizer(first_sentences, second_sentences, truncation=True)

def get_dataloaders(batch_size: int = 64):
    """Load GLUE/MRPC, tokenize it, and wrap train/validation in DataLoaders.

    Args:
        batch_size: Batch size used for both the training and validation loaders.

    Returns:
        ``(train_dataloader, eval_dataloader, tokenized_datasets)``. The tokenized
        ``DatasetDict`` is returned too so callers can read split sizes.
    """
    tokenized = (
        load_dataset("glue", "mrpc")
        .map(tokenize_function, batched=True)
        # Raw text and the index are not model inputs; the model expects 'labels'.
        .remove_columns(['sentence1', 'sentence2', 'idx'])
        .rename_column('label', 'labels')
    )
    # Called for its in-place effect (as before), so it stays outside the chain.
    tokenized.set_format('torch')

    shared_loader_kwargs = {"batch_size": batch_size, "collate_fn": data_collator}
    train_loader = DataLoader(tokenized['train'], shuffle=True, **shared_loader_kwargs)
    eval_loader = DataLoader(tokenized['validation'], **shared_loader_kwargs)

    return train_loader, eval_loader, tokenized

In [5]:
def training_loop(seed: int = 42, batch_size: int = 64, num_epochs: int = 10,
                  learning_rate: float = 5e-5, checkpoint: str = "bert-base-cased"):
    """Fine-tune a sequence classifier on GLUE/MRPC with Hugging Face Accelerate.

    Meant to be handed to ``notebook_launcher``, so every launched process runs
    this function; metrics are printed through ``accelerator.print`` so they
    appear once.

    Args:
        seed: Seed passed to ``accelerate.utils.set_seed``.
        batch_size: Batch size forwarded to ``get_dataloaders``.
        num_epochs: Number of passes over the training split.
        learning_rate: Initial AdamW learning rate, decayed linearly to 0.
        checkpoint: Hugging Face model id to fine-tune (2-label head).

    After each epoch, prints the validation accuracy / F1 computed on the
    predictions gathered from all processes.
    """
    set_seed(seed)

    accelerator = Accelerator()

    train_dataloader, eval_dataloader, tokenized_datasets = get_dataloaders(batch_size)
    # Loop-invariant: true validation size, used to trim gathered results.
    num_eval_examples = len(tokenized_datasets["validation"])

    model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

    # transformers.AdamW is deprecated in favour of torch.optim.AdamW. eps and
    # weight_decay are pinned to transformers.AdamW's defaults (1e-6, 0.0)
    # because torch's own defaults (1e-8, 0.01) would silently change training.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0
    )

    # Sized from the training loader before accelerator.prepare (which may
    # shard it across processes), as in the upstream Accelerate example.
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    train_dataloader, eval_dataloader, model, optimizer, lr_scheduler = accelerator.prepare(
        train_dataloader, eval_dataloader, model, optimizer, lr_scheduler
    )

    # NOTE(review): datasets.load_metric was removed in datasets 3.x; the
    # replacement is evaluate.load("glue", "mrpc") if the `evaluate` package
    # is available in this environment.
    metric = load_metric("glue", "mrpc")

    for epoch in range(num_epochs):
        model.train()
        for batch in train_dataloader:
            outputs = model(**batch)
            accelerator.backward(outputs.loss)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        model.eval()
        all_predictions = []
        all_labels = []
        with torch.no_grad():
            for batch in eval_dataloader:
                predictions = model(**batch).logits.argmax(dim=-1)
                # gather() concatenates per-process tensors so every rank sees
                # the full set of predictions and labels.
                all_predictions.append(accelerator.gather(predictions))
                all_labels.append(accelerator.gather(batch["labels"]))

        # The distributed loader may repeat samples to even out the last batch;
        # trim back to the real validation size before scoring.
        all_predictions = torch.cat(all_predictions)[:num_eval_examples]
        all_labels = torch.cat(all_labels)[:num_eval_examples]

        eval_metric = metric.compute(predictions=all_predictions, references=all_labels)

        # accelerator.print only prints on the main process.
        accelerator.print(f"epoch {epoch}:", eval_metric)

In [6]:
# Positional arguments forwarded to training_loop in every launched process:
# (seed, batch_size).
args = (42, 64)
notebook_launcher(training_loop, args=args, num_processes=2)

Launching a training on 2 GPUs.




  0%|          | 0/3 [00:00<?, ?it/s]



  0%|          | 0/3 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

epoch 0: {'accuracy': 0.7156862745098039, 'f1': 0.81875}
epoch 1: {'accuracy': 0.8333333333333334, 'f1': 0.8839590443686006}
epoch 2: {'accuracy': 0.8357843137254902, 'f1': 0.8784029038112522}
epoch 3: {'accuracy': 0.8529411764705882, 'f1': 0.896551724137931}
epoch 4: {'accuracy': 0.8504901960784313, 'f1': 0.897133220910624}
epoch 5: {'accuracy': 0.8553921568627451, 'f1': 0.899488926746167}
epoch 6: {'accuracy': 0.8553921568627451, 'f1': 0.8963093145869947}
epoch 7: {'accuracy': 0.8406862745098039, 'f1': 0.8929159802306424}
epoch 8: {'accuracy': 0.8308823529411765, 'f1': 0.8851913477537438}
epoch 9: {'accuracy': 0.8455882352941176, 'f1': 0.8884955752212389}


In [7]:
# Re-create the MRPC dataloaders in the notebook's own process (outside the
# launcher), presumably for interactive inspection.
train_dataloader, eval_dataloader, tokenized_datasets = get_dataloaders(64)

Reusing dataset glue (/home/jupyter/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ece218634f2dd18d.arrow
Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-cbdd35d6ef53475e.arrow
Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-2489abcaef97d751.arrow
