In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer
import json
from transformers import pipeline
import torch
from transformers import DistilBertForQuestionAnswering, TrainingArguments, Trainer, DefaultDataCollator, BitsAndBytesConfig
from peft import LoraConfig, PeftModel
import os
import wandb
from trl import SFTTrainer




In [2]:
# Loading the dataset
# The MRQA dataset is included in huggingface's datasets library, so we just have to load it
mrqa = load_dataset("mrqa", split="train[:20%]")
# Creating the train-test-validation split
mrqa = mrqa.train_test_split(test_size=0.2)
mrqa["train"] = mrqa["train"].train_test_split(test_size=0.2)
mrqa["val"] = mrqa["train"]["test"]
mrqa["train"] = mrqa["train"]["train"]

In [3]:
# Loading the tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased-distilled-squad")

In [4]:
# Tokenizing the texts with distilbert's own pretrained tokenizer and mapping the answer start and end character indices onto the tokenized text
def preprocess_function(examples):
    questions = [q.strip() for q in examples["question"]]
    # Tokenizing inputs
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt"
    )

    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["detected_answers"]
    start_positions = []
    end_positions = []

    # Mapping the answer start and end characters to the tokenized text
    for i, offset in enumerate(offset_mapping):
        answer = answers[i]
        start_char = answer["char_spans"][0]["start"][0]
        end_char = answer["char_spans"][0]["end"][0]
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [5]:
tokenized_mrqa = mrqa.map(preprocess_function, batched=True, remove_columns=mrqa["train"].column_names)
tokenized_mrqa.set_format(type="torch")

Map:   0%|          | 0/66152 [00:00<?, ? examples/s]

Map:   0%|          | 0/20673 [00:00<?, ? examples/s]

Map:   0%|          | 0/16539 [00:00<?, ? examples/s]

In [6]:
data_collator = DefaultDataCollator()

In [7]:
wandb.init(project='nagyhazi_colab', entity='bme_deepl_learning')

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mcsanad00058[0m ([33mbme_deepl_learning[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888888721, max=1.0)…

In [8]:
bnb_4bit_compute_dtype = "float16"
num_train_epochs = 3
fp16 = False
bf16 = False
per_device_train_batch_size = 4
per_device_eval_batch_size = 4
gradient_accumulation_steps = 1
gradient_checkpointing = True
max_grad_norm = 0.3
optim = "paged_adamw_32bit"
lr_scheduler_type = "constant"
max_steps = -1
warmup_ratio = 0.03
group_by_length = True
save_steps = 25
logging_steps = 5
max_seq_length = None
packing = False


In [27]:
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)


bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=getattr(torch, bnb_4bit_compute_dtype),
    bnb_4bit_use_double_quant=False,
)

# Configuring parameters of the low-rank adaptation
peft_config = LoraConfig(
    lora_alpha=6,
    lora_dropout=0.15,
    r=2,
    bias="none",
    task_type="QUESTION_ANS",
    target_modules=["q_lin", "k_lin", "v_lin", "ffn.lin1", "ffn.lin2", "attention.out_proj"])

In [24]:
model = DistilBertForQuestionAnswering.from_pretrained("distilbert/distilbert-base-uncased-distilled-squad",
                                                      quantization_config=bnb_config,
                                                      device_map={"": 0})

In [25]:
# Testing the model
question, text = "Who was the last pharaoh of ancient Egypt?", "The last pharaoh of ancient Egypt was Cleopatra."

inputs = tokenizer(question, text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

answer_start_index = torch.argmax(outputs.start_logits)
answer_end_index = torch.argmax(outputs.end_logits)

predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
tokenizer.decode(predict_answer_tokens)

'cleopatra'

In [29]:
output_dir_name = "trial_run_local"

training_args = TrainingArguments(
    output_dir=output_dir_name,
    eval_strategy="steps",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='new_dir',
    push_to_hub=False,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_mrqa["train"],
    eval_dataset=tokenized_mrqa["val"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    peft_config=peft_config
)

trainer.train()
wandb.finish()



  0%|          | 0/12405 [00:00<?, ?it/s]

{'loss': 1.1964, 'grad_norm': 2.6373815536499023, 'learning_rate': 1.919387343812979e-05, 'epoch': 0.12}


  0%|          | 0/1034 [00:00<?, ?it/s]

{'eval_loss': 1.0556285381317139, 'eval_runtime': 659.7168, 'eval_samples_per_second': 25.07, 'eval_steps_per_second': 1.567, 'epoch': 0.12}
{'loss': 1.1438, 'grad_norm': 3.6368579864501953, 'learning_rate': 1.8387746876259574e-05, 'epoch': 0.24}


  0%|          | 0/1034 [00:00<?, ?it/s]

{'eval_loss': 1.010238766670227, 'eval_runtime': 659.2553, 'eval_samples_per_second': 25.087, 'eval_steps_per_second': 1.568, 'epoch': 0.24}


KeyboardInterrupt: 

In [16]:
model.modules

<bound method Module.modules of DistilBertForQuestionAnswering(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear4bit(in_features=768, out_features=768, bias=True)
            (k_lin): Linear4bit(in_features=768, out_features=768, bias=True)
            (v_lin): Linear4bit(in_features=768, out_features=768, bias=True)
            (out_lin): Linear4bit(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
          

In [22]:
len([module for module in model.modules()])

95