# How Trainer works

In [28]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from torch.utils.data import DataLoader

In [45]:
## Data Preprocessing ##
data = load_dataset('glue', 'mrpc')

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

tokenized_data = data.map(lambda x: tokenizer(x['sentence1'], x['sentence2'], truncation=True))

data_collator = DataCollatorWithPadding(tokenizer) 

### DataLoader ###
# {k:v for k,v in tokenized_data.items() if k not in ['sentence1','sentence2','idx']}

tokenized_data = tokenized_data.remove_columns(['sentence1','sentence2','idx'])
tokenized_data.set_format('torch')

train_loader = DataLoader(tokenized_data['train'],shuffle=True, batch_size=8,collate_fn=data_collator)

eval_loader = DataLoader(tokenized_data['validation'], shuffle=True, batch_size=8, collate_fn=data_collator)

for batch in train_loader:
    break
{k:v.shape for k,v in batch.items()}

{'input_ids': torch.Size([8, 73]),
 'token_type_ids': torch.Size([8, 73]),
 'attention_mask': torch.Size([8, 73]),
 'labels': torch.Size([8])}

In [60]:
## Model ##
from transformers import AutoModelForSequenceClassification, get_scheduler
from torch.optim import AdamW
import torch


classifier = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')

test_output = classifier(**batch)

print('loss:\n', test_output.loss, '\nlogits:\n', test_output.logits.shape)
# probability of classes
# torch.nn.functional.softmax(test_output.logits)

optimizer = AdamW(classifier.parameters(), lr=1e-5, weight_decay=0.01) #bitesandbyter for memory-efficient optimization, lower lr 

## lr scheduler = linear decay from 5e-5-0. training step = # epochs*training batches (len(dataloader))
num_epochs = 3
num_training_steps = num_epochs * len(train_loader)

lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps)
print(num_training_steps)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


loss:
 tensor(1.0306, grad_fn=<NllLossBackward0>) 
logits:
 torch.Size([8, 2])
1377


In [68]:
## Training Loop (undistributed) ##
import torch
from tqdm.auto import tqdm

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

classifier.to(device)

progress = tqdm(range(num_training_steps))

classifier.train()

for epoch in range(num_epochs):
    for batch in train_loader:
        batch = {k:v.to(device) for k,v in batch.items()}
        outputs = classifier(**batch)
        loss = outputs.loss
        loss.backward() # backpropagation (computes gradients w.r.t. weights)

        optimizer.step() # update model patameters (weights)
        lr_scheduler.step() # adjust learning rate
        optimizer.zero_grad() # reset all gradients
        progress.update(1) # advance progress bar by 1

cuda


  0%|          | 0/1377 [00:00<?, ?it/s]

In [93]:
## training loop (distributed) ##

from accelerate import Accelerator
from torch.optim import AdamW
from transformers import AutoTokenizer , AutoModelForSequenceClassification, DataCollatorWithPadding, get_scheduler
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
from tqdm.auto import tqdm

In [90]:
## GET DATA

data = load_dataset('glue', 'mrpc')

token = AutoTokenizer.from_pretrained('bert-base-uncased')

tok_data = data.map(lambda x: token(x['sentence1'], x['sentence2'], truncation=True))

input_data = tok_data.remove_columns(['sentence1', 'sentence2', 'idx'])

collate = DataCollatorWithPadding(tokenizer=token)

train_dataloader = DataLoader(input_data['train'], shuffle=True, collate_fn=collate)
eval_dataloader = DataLoader(input_data['validation'], shuffle=True, collate_fn=collate)

classifier_model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
opt = AdamW(classifier_model.parameters(), lr=1e-5)

## Distributed trainng
accelerator = Accelerator() # initialize the proper distributed setup

acc_train_dl, acc_eval_dl, acc_cls_model, acc_opt = accelerator.prepare(train_dataloader, eval_dataloader, classifier_model, opt)

for acc_batch in acc_train_dl:
    break
{k:v for k,v in acc_batch.items()}

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{'input_ids': tensor([[ 101, 2009, 6749, 2008, 9584, 2022, 2445, 2000, 2437, 1996, 3105, 1997,
          9915, 2472, 1037, 2476, 2695, 2612, 1997, 1037, 2576, 6098, 1012,  102,
          2009, 6083, 2008, 3519, 5136, 2437, 1996, 9915, 2472, 1037, 2658, 2597,
          1010, 2738, 2084, 1037, 2576, 6098, 1012,  102]], device='cuda:0'),
 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
        device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
        device='cuda:0'),
 'labels': tensor([1], device='cuda:0')}

In [95]:
## Training loop

epochs = 3
training_step = epochs * len(acc_train_dl)
lr_scheduler = get_scheduler('linear', optimizer=acc_opt, num_warmup_steps=0, num_training_steps=training_step)
progress_bar = tqdm(range(training_step))

## Trainng

acc_cls_model.train()

for epoch in range(epochs):
    for batch in acc_train_dl:
        output = acc_cls_model(**batch)

        loss = output.loss
        accelerator.backward(loss)

        acc_opt.step()
        lr_scheduler.step()
        acc_opt.zero_grad()
        progress_bar.update(1)



  0%|          | 0/11004 [00:00<?, ?it/s]