In [63]:
# @title Imports
from transformers import AutoTokenizer, AutoModel ,DataCollatorWithPadding
from datasets import load_dataset
from torchsummary import summary
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn as nn
import torch
import time

In [83]:
# @title Global variables
checkpoint="google-bert/bert-base-uncased"
device="cuda" if torch.cuda.is_available() else "cpu"
epochs=1

In [65]:
# @title Loading Tokenizer and model
base_model=AutoModel.from_pretrained(checkpoint)
tokenizer=AutoTokenizer.from_pretrained(checkpoint)

In [66]:
tokenizer("This model is trained on Google Collab")

{'input_ids': [101, 2023, 2944, 2003, 4738, 2006, 8224, 8902, 20470, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [67]:
base_model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [68]:
def tokenize_function(examples):
  # 512 is the max length for bert model
  return tokenizer(examples['text'],truncation=True)

In [69]:
# freezing the base model
for param in base_model.parameters():
  param.requires_grad=False

In [70]:
#@title Classification
dataset_name="stanfordnlp/imdb"
dataset=load_dataset(dataset_name)

In [71]:
tokenized_dataset=dataset.map(tokenize_function,batched=True)

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [72]:
train_dataset=tokenized_dataset['train'].remove_columns(['text']).rename_column("label", "labels")
test_dataset=tokenized_dataset['test'].remove_columns(['text']).rename_column("label", "labels")

In [73]:
#@title Train,Test loader
train_loader=DataLoader(train_dataset,collate_fn=DataCollatorWithPadding(tokenizer=tokenizer,padding='longest'),batch_size=8,shuffle=True)
test_loader=DataLoader(test_dataset,collate_fn=DataCollatorWithPadding(tokenizer=tokenizer,padding='longest'),batch_size=8,shuffle=True)

In [74]:
for batch in train_loader:
    break
{k: v.shape for k, v in batch.items()}

{'labels': torch.Size([8]),
 'input_ids': torch.Size([8, 402]),
 'token_type_ids': torch.Size([8, 402]),
 'attention_mask': torch.Size([8, 402])}

In [75]:
batch['input_ids'].shape

torch.Size([8, 402])

In [76]:
del batch['labels']

In [77]:
_,outputs=base_model(**batch,return_dict=False)

In [78]:
outputs.shape

torch.Size([8, 768])

In [79]:
# @title Classification Head
class Classification(nn.Module):
  def __init__(self,base_model,drp_rate=0.3):
    super().__init__()
    self.base_model=base_model
    self.classification_head=nn.Sequential(
        nn.Linear(768,512),
        nn.ReLU(),
        nn.Dropout(drp_rate),
        nn.Linear(512,2))

  def forward(self,batch):
    _,outputs=self.base_model(**batch,return_dict=False)
    return self.classification_head(outputs)

In [80]:
# @title Loss fn and optimizer
loss_fn=nn.CrossEntropyLoss()
model=Classification(base_model).to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=1e-5)

In [81]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

394754

In [84]:
# @title training
for epoch in range(epochs):
  tot_loss=0
  tot_correct=0
  tot_samples=0
  start_time=time.time()
  for batch in tqdm(train_loader):
    batch={k:v.to(device) for k,v in batch.items()}
    labels=batch['labels']
    del batch['labels']
    outputs=model(batch)
    loss=loss_fn(outputs,labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    tot_loss+=loss.item()
    tot_samples+=labels.shape[0]
    tot_correct+=(outputs.argmax(dim=-1)==labels).sum().item()
  print(f"Epoch no {epoch+1}, total_loss= {loss} ,  accuracy = {tot_correct/tot_samples}")

100%|██████████| 3125/3125 [11:50<00:00,  4.40it/s]

Epoch no 1, total_loss= 0.5692952871322632 ,  accuracy = 0.7342





In [None]:
# @title testing
tot_loss=0
tot_correct=0
tot_samples=0
start_time=time.time()
for batch in tqdm(test_loader):
    batch={k:v.to(device) for k,v in batch.items()}
    labels=batch['labels']
    del batch['labels']
    with torch.no_grad():
      outputs=model(batch)
    loss=loss_fn(outputs,labels)
    tot_loss+=loss.item()
    tot_samples=labels.shape[0]
    tot_correct+=(outputs.argmax(dim=-1)==labels).sum().item()
print(f"Epoch no {epoch+1}, total_loss= {loss} ,  accuracy = {tot_correct/tot_samples}")