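### This notebook is based on the [Hugging Face course](https://huggingface.co/course/en/chapter3/4) (chapter 3, section 4).
* It fine-tunes `bert-base-uncased` on the GLUE MRPC sentence-pair task using a hand-written PyTorch training loop.
* An external RTX 3060 Ti GPU was used for fine-tuning.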

In [1]:
import random

import numpy as np
import torch

# Reproducibility: seed every random number generator this run touches
# (Python, NumPy, PyTorch CPU, PyTorch CUDA) with the same value, in that order.
random_seed = 8570
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_rng(random_seed)

# Ask cuDNN to pick deterministic kernels (repeatability over raw speed).
torch.backends.cudnn.deterministic = True

import datasets, evaluate
# Import the submodule explicitly: `tqdm.auto.tqdm` is used below, and a bare
# `import tqdm` only exposes it if some other library happened to import it first.
import tqdm.auto
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          DataCollatorWithPadding, get_scheduler)

# ---- configuration ----
checkpoint    = "bert-base-uncased"   # pretrained model that is fine-tuned
num_epochs    = 3
batch_size    = 8
learning_rate = 5e-5

# ---- data ----
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
raw_datasets = datasets.load_dataset("glue", "mrpc")


def tk_func(examples):
    """Tokenize a batch of MRPC sentence pairs.

    Only truncation is applied here; padding is left to the data collator so
    each batch is padded to its own longest sequence.
    """
    return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True)


tokenized_datasets = raw_datasets.map(tk_func, batched=True)
# Drop the raw-text / index columns the model cannot consume, and rename the target
# column to "labels" so the model's forward pass returns `outputs.loss`.
tokenized_datasets = tokenized_datasets.remove_columns(["sentence1", "sentence2", "idx"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")

# Dynamic padding: pad each batch only to its longest sequence instead of padding
# the whole dataset to one fixed length.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
train_dataloader = torch.utils.data.DataLoader(
    tokenized_datasets["train"], shuffle=True, batch_size=batch_size, collate_fn=data_collator
)
eval_dataloader = torch.utils.data.DataLoader(
    tokenized_datasets["validation"], batch_size=batch_size, collate_fn=data_collator
)

# ---- model, optimizer, schedule ----
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# MRPC is a binary sentence-pair classification task.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
model.to(device)

# `transformers.AdamW` is deprecated (and gone from newer releases); use PyTorch's.
# eps / weight_decay are pinned to the old transformers.AdamW defaults (1e-6 / 0.0)
# because torch.optim.AdamW's defaults differ (1e-8 / 1e-2).
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)

# Linear decay from `learning_rate` to 0 over every training batch, no warmup;
# the loop is expected to call lr_scheduler.step() once per batch.
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

# Fine-tune for `num_epochs` and evaluate on the validation split after each epoch.
# accuracy / F1 / precision / recall are bundled into one metric object; compute()
# is called once per epoch, after that epoch's validation batches have been added.
metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
epoch_metrics = []  # validation scores per epoch, kept for inspection after the loop

# The context manager closes the progress bar even if training is interrupted,
# so re-running the cell does not leave a stale bar behind.
with tqdm.auto.tqdm(total=num_training_steps) as progress_bar:
    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)  # batch contains "labels", so outputs.loss is set
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()  # the schedule is defined per batch, not per epoch
            progress_bar.update(1)

        # ---- validation pass ----
        model.eval()
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                logits = model(**batch).logits
            predictions = torch.argmax(logits, dim=-1)
            metrics.add_batch(predictions=predictions, references=batch["labels"])

        scores = metrics.compute()
        epoch_metrics.append(scores)
        print(f"epoch {epoch + 1}/{num_epochs}: {scores}")

Found cached dataset glue (/home/a/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/a/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e79368f4ad892c49.arrow
Loading cached processed dataset at /home/a/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-1a159c2082038660.arrow
Loading cached processed dataset at /home/a/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-dbe1e3677e81ec9a.arrow
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing

  0%|          | 0/1377 [00:00<?, ?it/s]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'accuracy': 0.7598039215686274, 'f1': 0.8327645051194539, 'precision': 0.7947882736156352, 'recall': 0.8745519713261649}
{'accuracy': 0.8259803921568627, 'f1': 0.8743362831858408, 'precision': 0.8636363636363636, 'recall': 0.8853046594982079}
{'accuracy': 0.8431372549019608, 'f1': 0.8900343642611684, 'precision': 0.8547854785478548, 'recall': 0.9283154121863799}
