In [1]:
%%capture
!pip install datasets transformers==4.28.0 evaluate

In [2]:
%%capture
!rm -rf phd_sentence_semantic_models
!git clone https://github.com/vrublevskiyvitaliy/phd_sentence_semantic_models.git

In [3]:
import torch
import evaluate

from transformers import AutoTokenizer, BertConfig, DataCollatorWithPadding, get_scheduler
from transformers.utils import PaddingStrategy
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import TruncationStrategy

from datasets import load_dataset, load_metric
from functools import partial

from tqdm.auto import tqdm

from torch.optim import AdamW

# Mine from repo
from phd_sentence_semantic_models.models.bert_model_classic import BertForSequenceClassificationClassic
from phd_sentence_semantic_models.models.bert_tokeniser_with_pos_tags import preprocess_dataset_with_pos_tags
from phd_sentence_semantic_models.utils.seed import init_seed
from phd_sentence_semantic_models.utils.train_eval_cycle import train_eval

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [4]:
# GLOBALS

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MAX_LEN = 512
BATCH_SIZE = 32
TRUNCATION = TruncationStrategy.LONGEST_FIRST
PADDING=PaddingStrategy.MAX_LENGTH
SEED = 42
LR = 2e-5
NUM_TRAIN_EPOCHS = 10

model_name = "bert-base-cased"

In [5]:
# INITIALISATION
init_seed(SEED)

In [6]:
model_tokenizer = AutoTokenizer.from_pretrained(model_name)

config = BertConfig.from_pretrained(model_name)
model = BertForSequenceClassificationClassic.from_pretrained(model_name,config=config)

model.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassificationClassic: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassificationClassic from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassificationClassic from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassificationClassic were not initialized from the model checkpoint at bert-ba

classifier_dropout = 0.1


BertForSequenceClassificationClassic(
  (bert): BertModelClassic(
    (embeddings): BertEmbeddingsClassic(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNo

In [7]:
dataset_train = load_dataset("glue", 'mrpc', split="train")
dataset_eval = load_dataset("glue", 'mrpc', split="validation")

preprocess_dataset_with_pos_tags_full = partial(
    preprocess_dataset_with_pos_tags,
    tokenizer=model_tokenizer,
    truncation=TRUNCATION,
    max_length=MAX_LEN,
    padding=PADDING,
  )

collator = DataCollatorWithPadding(model_tokenizer)

def prepare_dataloader(dataset, collator):
  dataset = dataset.map(preprocess_dataset_with_pos_tags_full, batched=False)
  dataset = dataset.remove_columns(["sentence1", "sentence2", "idx", "pos_tag_ids"])
  dataset = dataset.rename_column("label", "labels")
  dataset.set_format("torch")
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collator)
  return dataloader


train_dataloader = prepare_dataloader(dataset_train, collator)
eval_dataloader = prepare_dataloader(dataset_eval, collator)

Map:   0%|          | 0/408 [00:00<?, ? examples/s]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [8]:
optimizer = AdamW(model.parameters(), lr=LR)

num_training_steps = NUM_TRAIN_EPOCHS * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

In [9]:
train_eval(
    model = model,
    optimizer = optimizer,
    lr_scheduler = lr_scheduler,
    train_dataloader = train_dataloader,
    eval_dataloader = eval_dataloader,
    num_train_epochs = NUM_TRAIN_EPOCHS,
    num_training_steps = num_training_steps,
    device = device,
)

  0%|          | 0/1150 [00:00<?, ?it/s]

Epoch 0




Accuracy 0.7916666666666666
F1 0.8434622467771639
Epoch 1




Accuracy 0.8578431372549019
F1 0.8937728937728937
Epoch 2




Accuracy 0.8602941176470589
F1 0.9025641025641027
Epoch 3




Accuracy 0.8529411764705882
F1 0.900990099009901
Epoch 4




Accuracy 0.8578431372549019
F1 0.8996539792387542
Epoch 5




Accuracy 0.8627450980392157
F1 0.9037800687285222
Epoch 6




Accuracy 0.8676470588235294
F1 0.9078498293515359
Epoch 7




Accuracy 0.8602941176470589
F1 0.9038785834738617
Epoch 8




Accuracy 0.8651960784313726
F1 0.9063032367972743
Epoch 9




Accuracy 0.8676470588235294
F1 0.9078498293515359


In [None]:
# Epoch 0
# Accuracy 0.7916666666666666
# F1 0.8434622467771639

# Epoch 1
# Accuracy 0.8578431372549019
# F1 0.8937728937728937

# Epoch 2
# Accuracy 0.8602941176470589
# F1 0.9025641025641027

# Epoch 3
# Accuracy 0.8529411764705882
# F1 0.900990099009901

# Epoch 4
# Accuracy 0.8578431372549019
# F1 0.8996539792387542

# Epoch 5
# Accuracy 0.8627450980392157
# F1 0.9037800687285222

# Epoch 6
# Accuracy 0.8676470588235294
# F1 0.9078498293515359

# Epoch 7
# Accuracy 0.8602941176470589
# F1 0.9038785834738617

# Epoch 8
# Accuracy 0.8651960784313726
# F1 0.9063032367972743

# Epoch 9
# Accuracy 0.8676470588235294
# F1 0.9078498293515359