## Setup

In [1]:
!pip install -U bitsandbytes trl peft datasets



## Imports

In [2]:
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from trl import SFTTrainer

import os
import torch
from peft import LoraConfig

## Configuration

In [3]:
batch_size = 2
num_workers = os.cpu_count()
max_steps = -1
epochs = 3
bf16 = False
fp16 = True
gradient_accumulation_steps = 8
context_length = 1024
logging_steps = 1000
save_steps = 1000
learning_rate = 0.0002
model_name = 'facebook/opt-125m'
out_dir = 'outputs/opt_125m_squad_sft'

## Dataset Preparation

In [4]:
train_raw = load_dataset('squad', split='train')
valid_raw = load_dataset('squad', split='validation')
print(train_raw)
print(valid_raw)

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 87599
})
Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 10570
})


In [5]:
print(train_raw[0])
print('-' * 50)
print(train_raw[1])
print('-' * 50)
print('-' * 50)
print(train_raw[2])
print('-' * 50)
print('-' * 50)
print(train_raw[3])
print('-' * 50)
print('-' * 50)
print(train_raw[4])

{'id': '5733be284776f41900661182', 'title': 'University_of_Notre_Dame', 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?', 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}
--------------------------------------------------
{'i

In [6]:
def preprocess_function(example):
    text = f"### Context:\n{example['context']}\n\n### Question:\n{example['question']}\n\n### Answer:\n{example['answers']['text'][0]}"
    return text

## Model

In [7]:
if bf16:
    model = AutoModelForCausalLM.from_pretrained(model_name).to(dtype=torch.bfloat16)
else:
    model = AutoModelForCausalLM.from_pretrained(model_name)

  return self.fget.__get__(instance, owner)()


In [8]:
print(model)
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-11): 12 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), ep

## Tokenizer

In [9]:
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_fast=False
)
tokenizer.pad_token = tokenizer.eos_token

## Training

In [10]:
training_args = TrainingArguments(
    output_dir=f"{out_dir}/logs",
    evaluation_strategy='epoch',
    weight_decay=0.01,
    load_best_model_at_end=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    logging_strategy='epoch',
    save_strategy='epoch',
    save_steps=save_steps,
    save_total_limit=2,
    bf16=bf16,
    fp16=fp16,
    report_to='tensorboard',
    max_steps=max_steps,
    num_train_epochs=epochs,
    dataloader_num_workers=num_workers,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=learning_rate,
    lr_scheduler_type='constant',
)

In [11]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_raw,
    eval_dataset=valid_raw,
    max_seq_length=context_length,
    tokenizer=tokenizer,
    args=training_args,
    formatting_func=preprocess_function,
    packing=True
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [12]:
dataloader = trainer.get_train_dataloader()
for i, sample in enumerate(dataloader):
    print(tokenizer.decode(sample['input_ids'][0]))
    print('#'*50)
    if i == 5:
        break

</s></s>### Context:
In 2006, the animal rights organization People for the Ethical Treatment of Animals (PETA), criticized Beyoncé for wearing and using fur in her clothing line House of Deréon. In 2011, she appeared on the cover of French fashion magazine L'Officiel, in blackface and tribal makeup that drew criticism from the media. A statement released from a spokesperson for the magazine said that Beyoncé's look was "far from the glamorous Sasha Fierce" and that it was "a return to her African roots".

### Question:
What did PETA criticize Beyonce for in 2006?

### Answer:
for wearing and using fur</s></s>### Context:
In 2006, the animal rights organization People for the Ethical Treatment of Animals (PETA), criticized Beyoncé for wearing and using fur in her clothing line House of Deréon. In 2011, she appeared on the cover of French fashion magazine L'Officiel, in blackface and tribal makeup that drew criticism from the media. A statement released from a spokesperson for the magaz

## Train

In [13]:
history = trainer.train()

Epoch,Training Loss,Validation Loss
0,1.2657,1.31877
1,1.0573,1.351102
2,0.9064,1.40918


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


In [14]:
model.save_pretrained(f"{out_dir}/best_model")
tokenizer.save_pretrained(f"{out_dir}/best_model")

('outputs/opt_125m_squad_sft/best_model/tokenizer_config.json',
 'outputs/opt_125m_squad_sft/best_model/special_tokens_map.json',
 'outputs/opt_125m_squad_sft/best_model/vocab.json',
 'outputs/opt_125m_squad_sft/best_model/merges.txt',
 'outputs/opt_125m_squad_sft/best_model/added_tokens.json')

## Inference

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

import torch

In [2]:
model = AutoModelForCausalLM.from_pretrained(
    'outputs/opt_125m_squad_sft/best_model',
    device_map='cuda'
)
tokenizer = AutoTokenizer.from_pretrained('outputs/opt_125m_squad_sft/best_model')

In [9]:
prompt = """### Context:
George Russell and Fernando Alonso have each offered their take on the incident that preceded Russell’s dramatic crash on the penultimate lap of the Australian Grand Prix following their battle for P6, with the Spaniard subsequently hit with a 20-second penalty after the race.

Russell had been chasing the Aston Martin for several laps following his final pit stop but, after getting close towards the Turn 6/7 complex, lost control of his Mercedes and hit the barriers, with the W15 then ricocheting back onto the track and ending up on its side.

READ MORE: Alonso hit with post-race time penalty in Australia over ‘potentially dangerous’ driving before Russell crash

While Russell fortunately reported that he was unharmed in the incident, it ultimately ended what had been a tough day for the Silver Arrows following Lewis Hamilton’s earlier retirement due to a mechanical issue.

It was confirmed after the race that both Russell and Alonso had been summoned to the stewards over the incident, with the Aston Martin man hit with a 20-second penalty for what the stewards deemed was "potentially dangerous" driving. Speaking before the hearing, Alonso gave his version of events during a conversation on Sky Sports.

“Well, obviously I was focusing in front of me and not behind,” the Spaniard explained. “I had some issues for the last 15 laps, something on the battery on the deployment, so definitely I was struggling a little bit at the end of the race, but yeah, I cannot focus on the cars behind. But he’s okay apparently, I saw the car and I was very worried.”

### Question:
What according to the Spaniard was the cause of the accident?

### Answer:
"""

In [10]:
print(prompt)

### Context:
George Russell and Fernando Alonso have each offered their take on the incident that preceded Russell’s dramatic crash on the penultimate lap of the Australian Grand Prix following their battle for P6, with the Spaniard subsequently hit with a 20-second penalty after the race.

Russell had been chasing the Aston Martin for several laps following his final pit stop but, after getting close towards the Turn 6/7 complex, lost control of his Mercedes and hit the barriers, with the W15 then ricocheting back onto the track and ending up on its side.

READ MORE: Alonso hit with post-race time penalty in Australia over ‘potentially dangerous’ driving before Russell crash

While Russell fortunately reported that he was unharmed in the incident, it ultimately ended what had been a tough day for the Silver Arrows following Lewis Hamilton’s earlier retirement due to a mechanical issue.

It was confirmed after the race that both Russell and Alonso had been summoned to the stewards over

In [11]:
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=1024,
    return_full_text=False
)
result = pipe(f"{prompt}")
print(result[0]['generated_text'])

a mechanical issue
