<a href="https://colab.research.google.com/github/pa-shk/embeddings_mixup/blob/main/embed-mixup-model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets -q
!pip install torchmetrics -q
!pip install wandb -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.2/840.2 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m196.4/196.4 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m254.1/254.1 kB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import random
import os
import torch
import transformers
import numpy as np

from typing import Dict, List, Union


def seed_everything(seed: int) -> None:
    """
    Seeds every source of randomness used in this notebook.

    The same `seed` is applied to Python's `random`, the `PYTHONHASHSEED`
    environment variable, NumPy, PyTorch (CPU and CUDA) and `transformers`;
    cuDNN is additionally switched to deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Seed each library's generator with the same value.
    for seed_fn in (random.seed,
                    np.random.seed,
                    torch.manual_seed,
                    torch.cuda.manual_seed):
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    transformers.set_seed(seed)


# Seed all random number generators once, before any data or model code runs.
seed_everything(42)

# Dataset preparation

In [None]:
# Checkpoint name passed to `from_pretrained` for both the tokenizer and the model.
MODEL_NAME = 'bert-base-cased'
# Every text is padded or truncated to this many tokens by the tokenizer.
MAX_LEN = 70
# Number of target classes: sizes the classification head and the one-hot MixUp labels.
NUM_CLASSES = 2
# Batch size shared by the training and validation dataloaders.
BATCH_SIZE = 256

In [None]:
from datasets import load_dataset, formatting
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

In [None]:
from datasets import load_dataset

# Load the "rotten_tomatoes" dataset; the bare last line displays its splits.
dataset = load_dataset("rotten_tomatoes")
dataset

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading data:   0%|          | 0.00/699k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/90.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/92.2k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8530 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1066 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1066 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 8530
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1066
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1066
    })
})

In [None]:
# Tokenizer matching the MODEL_NAME checkpoint; used by `tokenize_function` below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize_function(examples: Union[formatting.formatting.LazyRow, formatting.formatting.LazyBatch]
                      ) -> Dict[str, List[Union[int, List[int]]]]:
    """
    Tokenizes the `text` field of a dataset row or batch.

    Every text is truncated or padded to exactly `MAX_LEN` tokens with the
    module-level `tokenizer`.
    """
    tokenizer_kwargs = dict(padding='max_length',
                            truncation=True,
                            max_length=MAX_LEN)
    texts = examples['text']
    return tokenizer(texts, **tokenizer_kwargs)

tokenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

In [None]:
# Tokenize every split, drop the raw text and rename `label` to `labels`.
tokenized_datasets = (
    dataset
    .map(tokenize_function, batched=True)
    .remove_columns(['text'])
    .rename_column('label', 'labels')
)
# `set_format` is called for its effect on the object itself (its result is not assigned).
tokenized_datasets.set_format('torch')
tokenized_datasets

Map:   0%|          | 0/8530 [00:00<?, ? examples/s]

Map:   0%|          | 0/1066 [00:00<?, ? examples/s]

Map:   0%|          | 0/1066 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 8530
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1066
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1066
    })
})

In [None]:
# Only the training split is shuffled; the validation loader keeps the dataset order.
train_dataloader = DataLoader(
    tokenized_datasets['train'],
    batch_size=BATCH_SIZE,
    shuffle=True,
)
eval_dataloader = DataLoader(
    tokenized_datasets['validation'],
    batch_size=BATCH_SIZE,
)

# Model fine-tuning

In [None]:
from transformers import AutoModelForSequenceClassification, get_scheduler
from tqdm.auto import tqdm
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
import wandb
from huggingface_hub import notebook_login

## Useful functions

In [None]:
def process_batch(embedder_model: torch.nn.modules.sparse.Embedding,
                  batch: Dict[str, torch.Tensor],
                  gamma: float,
                  device: str='cpu') -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Applies MixUp to a batch at the embedding level.

    Every example is blended with its mirror example, i.e. the one at the
    same position once the batch is reversed along dimension 0:
    `gamma * example + (1 - gamma) * mirror`. The one-hot encoded labels
    are blended with the same weights.

    Args:
        embedder_model: layer that maps `input_ids` to embeddings.
        batch: must contain `input_ids`, `token_type_ids`, `attention_mask`
            and `labels` (class indices).
        gamma: weight of the original example in the blend.
        device: device on which the one-hot label tensor is allocated.

    Returns:
        A dict with `texts` (keyword arguments for the model: `inputs_embeds`,
        `token_type_ids`, `attention_mask`) and `labels` (blended labels with
        one row per example and `NUM_CLASSES` columns).
    """
    def mirror(tensor):
        # Reverse the batch dimension so the first example is paired with the last one, etc.
        return torch.flip(tensor, dims=[0])

    def mix(tensor):
        return gamma * tensor + (1 - gamma) * mirror(tensor)

    n_examples = batch['labels'].shape[0]

    mixed_embeddings = mix(embedder_model(batch['input_ids']))

    segment_ids = batch['token_type_ids']

    # Union of both masks: a position is kept if it is 1 in either of the two blended examples.
    mask = batch['attention_mask']
    merged_mask = (mask == 1) + (mirror(mask) == 1)

    # One-hot encode the class indices, then blend them exactly like the embeddings.
    one_hot = torch.zeros(n_examples, NUM_CLASSES, device=device)
    one_hot[torch.arange(n_examples), batch['labels']] = 1
    mixed_labels = mix(one_hot)

    model_inputs = {'inputs_embeds': mixed_embeddings,
                    'token_type_ids': segment_ids,
                    'attention_mask': merged_mask}
    return {'texts': model_inputs, 'labels': mixed_labels}


def compute_metrics(model: torch.nn.Module,
                    dataloader: torch.utils.data.dataloader.DataLoader,
                    mode: str,
                    device: str='cpu')-> Dict[str, float]:
    """
    Computes accuracy, precision and recall of `model` over `dataloader`.

    The model is evaluated in eval mode without gradient tracking. Its
    previous train/eval mode is restored before returning, so calling this
    function in the middle of training does not disable dropout for the
    following epochs.

    Args:
        model: classifier whose output exposes `.logits`.
        dataloader: yields dict batches that are passed to the model as
            keyword arguments and contain a `labels` entry.
        mode: prefix for the metric names, e.g. `'train'` or `'eval'`.
        device: device the batches and metric objects are moved to.

    Returns:
        A dict mapping `'{mode}_accuracy'`, `'{mode}_precision'` and
        `'{mode}_recall'` to plain Python floats.
    """
    metrics = {'accuracy': BinaryAccuracy().to(device),
               'precision': BinaryPrecision().to(device),
               'recall': BinaryRecall().to(device)}

    # Remember the current mode so evaluation does not leave the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                predictions = torch.argmax(model(**batch).logits, dim=-1)
                # `update` only accumulates state; the per-batch value is not needed here.
                for metric in metrics.values():
                    metric.update(preds=predictions, target=batch['labels'])
    finally:
        model.train(was_training)

    # `.item()` converts the 0-dim result tensors to floats, matching the return annotation.
    return {f'{mode}_{metric_name}': metric.compute().item()
            for metric_name, metric in metrics.items()}


def train(model: torch.nn.Module,
          num_epochs: int,
          optimizer: torch.optim.Optimizer,
          num_training_steps: int,
          lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
          criterion: torch.nn.modules.loss._Loss,
          grad_norm: float=1.0,
          device: str='cpu') -> None:
    """
    Executes the training loop with validation after every epoch.

    Each batch of the notebook-level `train_dataloader` is mixed with
    `process_batch` before the forward pass. After every epoch, metrics on
    `train_dataloader` and on the notebook-level `eval_dataloader` are
    computed with `compute_metrics` and logged to wandb.

    Args:
        model: classifier to fine-tune; must expose `model.bert.get_input_embeddings()`.
        num_epochs: number of passes over `train_dataloader`.
        optimizer: optimizer updating the parameters of `model`.
        num_training_steps: length of the progress bar.
        lr_scheduler: scheduler stepped once after every optimizer step.
        criterion: loss called as `criterion(logits, mixed_labels)`.
        grad_norm: maximum gradient norm enforced before every optimizer step.
        device: device the batches are moved to.
    """
    progress_bar = tqdm(range(num_training_steps))
    embedder = model.bert.get_input_embeddings()

    for epoch in range(num_epochs):
        # Evaluation at the end of the previous epoch may have left the model in
        # eval mode, so training mode is (re-)enabled at the start of every epoch.
        model.train()

        for batch in train_dataloader:

            # MixUp weight; clamped to [0, 1] so the blended labels stay valid weights
            # even if the normal draw lands outside that range.
            gamma = float(np.clip(np.random.normal(loc=0.1, scale=0.015), 0.0, 1.0))

            batch = {k: v.to(device) for k, v in batch.items()}
            batch = process_batch(embedder_model=embedder,
                                  batch=batch,
                                  gamma=gamma,
                                  device=device)

            outputs = model(**batch['texts'])

            loss = criterion(outputs['logits'], batch['labels'])
            loss.backward()
            # Clip between backward() and step(): clipping after the step has no
            # effect on the update that was just applied.
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)

        train_metrics = compute_metrics(model, train_dataloader, mode='train', device=device)
        wandb.log(train_metrics)
        eval_metrics = compute_metrics(model, eval_dataloader, mode='eval', device=device)
        wandb.log(eval_metrics)

    progress_bar.close()

## Model configuration

In [None]:
# Training hyperparameters.
num_epochs = 10
lr = 5e-5
num_warmup_steps = 100
# The scheduler is stepped once per batch, so the schedule spans every batch of every epoch.
num_training_steps = num_epochs * len(train_dataloader)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pretrained checkpoint with a classification head of NUM_CLASSES outputs.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_CLASSES)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_scheduler(
    name='linear',
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

criterion = torch.nn.CrossEntropyLoss()

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Fine-tuning

In [None]:
# Start a Weights & Biases run in this project; `train` logs its per-epoch metrics via `wandb.log`.
wandb.init(project='bert-base-cased-mixup')

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
# Move the model to the selected device before training; `train` moves each batch there as well.
model = model.to(device)


# Fine-tune with the objects built in the configuration cell (grad_norm keeps its default of 1.0).
train(model=model,
      num_epochs=num_epochs,
      optimizer=optimizer,
      num_training_steps=num_training_steps,
      lr_scheduler=lr_scheduler,
      criterion=criterion,
      device=device)

  0%|          | 0/340 [00:00<?, ?it/s]

In [None]:
# Log in to the Hugging Face Hub interactively; the uploads in the following cells presumably rely on it.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# Upload the fine-tuned model to the Hub repository.
model.push_to_hub('pa-shk/bert-base-cased-embed-mixup')

model.safetensors:   0%|          | 0.00/433M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/pa-shk/bert-base-cased-embed-mixup/commit/955fdb38a86e8238f863f8aad3b394b016bbb3de', commit_message='Upload BertForSequenceClassification', commit_description='', oid='955fdb38a86e8238f863f8aad3b394b016bbb3de', pr_url=None, pr_revision=None, pr_num=None)

In [None]:
# Upload the tokenizer to the same Hub repository as the model.
tokenizer.push_to_hub('pa-shk/bert-base-cased-embed-mixup')

CommitInfo(commit_url='https://huggingface.co/pa-shk/bert-base-cased-embed-mixup/commit/49ac3558996dedb152769c788eeb57f47cc5caea', commit_message='Upload tokenizer', commit_description='', oid='49ac3558996dedb152769c788eeb57f47cc5caea', pr_url=None, pr_revision=None, pr_num=None)