In [46]:
!pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Looking in indexes: https://download.pytorch.org/whl/cu118


In [47]:
!pip install --no-cache-dir -q accelerate peft bitsandbytes transformers trl scipy

In [64]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    HfArgumentParser, TrainingArguments, pipeline, logging
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig

In [65]:
from datasets import load_dataset

squad_dataset = load_dataset("rajpurkar/squad_v2")
print(len(squad_dataset['train']))

130319


In [50]:
squad_dataset = squad_dataset['train'].shuffle(seed=42).sam
squad_dataset[0]

{'id': '56e0f3907aa994140058e80a',
 'title': 'Canon_law',
 'context': 'The Roman Catholic Church canon law also includes the main five rites (groups) of churches which are in full union with the Roman Catholic Church and the Supreme Pontiff:',
 'question': 'What term characterizes the intersection of the rites with the Roman Catholic Church?',
 'answers': {'text': ['full union'], 'answer_start': [104]}}

In [51]:
def transform_conversation(example):
    context = example['context']
    question = example['question']
    user_text = f"Context: {context} Question: {question}"
    answer = example['answers']['text'][0] if example['answers']['text'] else ""

    reformatted = f"<s>[INST] {user_text} [/INST] {answer} </s>"
    return {"text" : reformatted}

In [52]:
transformed_dataset = squad_dataset.map(transform_conversation)

In [53]:
transformed_dataset[0]

{'id': '56e0f3907aa994140058e80a',
 'title': 'Canon_law',
 'context': 'The Roman Catholic Church canon law also includes the main five rites (groups) of churches which are in full union with the Roman Catholic Church and the Supreme Pontiff:',
 'question': 'What term characterizes the intersection of the rites with the Roman Catholic Church?',
 'answers': {'text': ['full union'], 'answer_start': [104]},
 'text': '<s>[INST] Context: The Roman Catholic Church canon law also includes the main five rites (groups) of churches which are in full union with the Roman Catholic Church and the Supreme Pontiff: Question: What term characterizes the intersection of the rites with the Roman Catholic Church? [/INST] full union </s>'}

In [54]:
transformed_dataset[1]['text']

"<s>[INST] Context: Alexandria was the most important trade center in the whole empire during Athanasius's boyhood. Intellectually, morally, and politically—it epitomized the ethnically diverse Graeco-Roman world, even more than Rome or Constantinople, Antioch or Marseilles. Its famous catechetical school, while sacrificing none of its famous passion for orthodoxy since the days of Pantaenus, Clement of Alexandria, Origen of Alexandria, Dionysius and Theognostus, had begun to take on an almost secular character in the comprehensiveness of its interests, and had counted influential pagans among its serious auditors. Question: What was Alexandria known for? [/INST] important trade center </s>"

In [55]:
model_name = "NousResearch/Llama-2-7b-chat-hf"
dataset_name = "rajpurkar/squad_v2"
finetune_model = "Llama-2-7b-chat-finetune"


# Output folder
output_dir = "./results"

# No of epochs
num_train_epochs = 1

# No change params
use_4bit, bnb_4bit_compute_dtype, bnb_4bit_quant_type, use_nested_quant = True, "float16", "nf4", False # To quantization
lora_r, lora_alpha, lora_dropout = 64, 16, 0.1
fp16, bf16 =  False, False
per_device_train_batch_size, per_device_eval_batch_size = 1, 1
gradient_accumulation_steps, gradient_checkpointing, max_grad_norm = 1, True, 0.3
learning_rate, weight_decay, optim = 2e-4, 0.001, "paged_adamw_32bit"
lr_scheduler_type, max_steps, warmup_ratio = "cosine", -1, 0.03
group_by_length, save_steps, logging_steps = True, 0, 25
max_seq_length, packing, device_map = 1024, False, {"": 0}


In [56]:
# Load tokenizer and model with QLoRA configuration
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

In [57]:
# Load base model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map
)
model.config.use_cache = False
model.config.pretraining_tp = 1

Loading checkpoint shards: 100%|██████████| 2/2 [00:13<00:00,  6.96s/it]


In [58]:
# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.model_max_length = max_seq_length
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

In [59]:
# Load LoRA configuration
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)

In [60]:
import os
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

In [61]:
# Set training parameters

training_arguments = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    max_length=max_seq_length,
    dataset_text_field="text",
    packing=packing,
)


In [62]:
# Set supervised fine-tuning parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=transformed_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
    args=training_arguments,
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [63]:
trainer.train()

Step,Training Loss
25,1.2834
50,1.4801
75,1.3167
100,1.4389
125,1.1503
150,1.3475
175,1.1266
200,1.223
225,1.0132
250,1.1463


KeyboardInterrupt: 

In [66]:
# Save trained model
trainer.model.save_pretrained(finetune_model)

In [67]:
# Reload and merge
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)
model = PeftModel.from_pretrained(base_model, finetune_model)
model = model.merge_and_unload()

# Reload tokenizer to save it
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.81s/it]


In [None]:
# save_path = "./llama-2-7b-chat-base.pt"

In [None]:
# base_model.save_pretrained(save_path)

: 

In [68]:
model.save_pretrained("models/finetune_model.pt")
tokenizer.save_pretrained("models/tokenizer/")

('models/tokenizer/tokenizer_config.json',
 'models/tokenizer/special_tokens_map.json',
 'models/tokenizer/tokenizer.json')

In [None]:
# Free VRAM
import gc
# del model, pipe, trainer
del trainer

gc.collect()

6397

In [None]:
tokenizer = AutoTokenizer.from_pretrained("models/tokenizer/")
model = AutoModelForCausalLM.from_pretrained(
    "models/finetune_model.pt",
    torch_dtype="auto",
    device_map="auto"
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.16s/it]
We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.


In [None]:
# Test the model with a sample question
val_dataset = load_dataset(dataset_name, split='validation')
val_dataset = val_dataset.select(range(100))
val_transformed_dataset = val_dataset.map(transform_conversation)
val_transformed_dataset = val_transformed_dataset.shuffle(seed=42)
val_transformed_dataset[2]

Map: 100%|██████████| 100/100 [00:00<00:00, 4640.79 examples/s]


{'id': '5ad3e96b604f3c001a3ff68a',
 'title': 'Normans',
 'context': 'Some Normans joined Turkish forces to aid in the destruction of the Armenians vassal-states of Sassoun and Taron in far eastern Anatolia. Later, many took up service with the Armenian state further south in Cilicia and the Taurus Mountains. A Norman named Oursel led a force of "Franks" into the upper Euphrates valley in northern Syria. From 1073 to 1074, 8,000 of the 20,000 troops of the Armenian general Philaretus Brachamius were Normans—formerly of Oursel—led by Raimbaud. They even lent their ethnicity to the name of their castle: Afranji, meaning "Franks." The known trade between Amalfi and Antioch and between Bari and Tarsus may be related to the presence of Italo-Normans in those cities while Amalfi and Bari were under Norman rule in Italy.',
 'question': 'Who did the Turks take up service with?',
 'answers': {'text': [], 'answer_start': []},
 'text': '<s>[INST] Context: Some Normans joined Turkish forces to aid 

In [None]:
# Create a pipeline for question answering
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
    torch_dtype=torch.float16
)

sample = transformed_dataset[16]

sample_context = sample['context']
sample_question = sample['question']
input_text = f"Context: {sample_context} Question: {sample_question}"

output = pipe(f"[INST] {input_text} [/INST]")
output_text = output[0]['generated_text']
print(output_text)
print(f"Expected answer: {sample['answers']['text'][0] if sample['answers']['text'] else ''}")

Device set to use cuda:0


[INST] Context: Wood unsuitable for construction in its native form may be broken down mechanically (into fibers or chips) or chemically (into cellulose) and used as a raw material for other building materials, such as engineered wood, as well as chipboard, hardboard, and medium-density fiberboard (MDF). Such wood derivatives are widely used: wood fibers are an important component of most paper, and cellulose is used as a component of some synthetic materials. Wood derivatives can also be used for kinds of flooring, for example laminate flooring. Question: What material results from chemically breaking down wood? [/INST] cellulose
Expected answer: cellulose


In [None]:
from tqdm import tqdm
pbar = tqdm(enumerate(val_transformed_dataset), total=len(val_transformed_dataset), desc="Generating answers")
answers = {}

for i, sample in pbar:
    sample_context = sample['context']
    sample_question = sample['question']
    input_text = f"Context: {sample_context} Question: {sample_question}"
    output = pipe(f"[INST] {input_text} [/INST]")
    output_text = output[0]['generated_text']
    output_text = output_text.split("[/INST]")[-1].strip()
    answers[sample['id']] = output_text

Generating answers: 100%|██████████| 100/100 [00:12<00:00,  7.93it/s]


In [None]:
import json
from collections import defaultdict

def convert_hf_squad2_to_json(hf_dataset, version="v2.0"):
    """
    Converts a Hugging Face SQuAD v2.0 dataset (or subset) to official SQuAD 2.0 JSON format.
    Args:
        hf_dataset: HuggingFace Dataset (e.g., dataset["train"] or a sampled subset)
        output_path: Path to save the JSON file
        version: SQuAD version string (default "v2.0")
    """
    def format_answers(answers):
        # Converts {'text': [...], 'answer_start': [...]} to list of dicts
        return [
            {"text": t, "answer_start": s}
            for t, s in zip(answers["text"], answers["answer_start"])
        ] if answers and "text" in answers and "answer_start" in answers else []

    data_dict = defaultdict(lambda: defaultdict(list))

    for ex in hf_dataset:
        print(ex['id'])
        title = ex.get("title", "No Title")
        context = ex["context"]
        qas_entry = {
            "id": ex["id"],
            "question": ex["question"],
            "is_impossible": ex.get("is_impossible", False),
            "answers": format_answers(ex["answers"])
        }
        if ex.get("is_impossible", False):
            plausible_answers = ex.get("plausible_answers")
            if plausible_answers:
                qas_entry["plausible_answers"] = format_answers(plausible_answers)
            else:
                qas_entry["plausible_answers"] = []
        data_dict[title][context].append(qas_entry)

    data = []
    for title, paras in data_dict.items():
        paragraphs = []
        for context, qas_list in paras.items():
            paragraphs.append({
                "context": context,
                "qas": qas_list
            })
        data.append({
            "title": title,
            "paragraphs": paragraphs
        })

    squad_json = {
        "version": version,
        "data": data
    }
    return squad_json

In [None]:
val_squad_form = convert_hf_squad2_to_json(val_dataset)

56ddde6b9a695914005b9628
56ddde6b9a695914005b9629
56ddde6b9a695914005b962a
56ddde6b9a695914005b962b
56ddde6b9a695914005b962c
5ad39d53604f3c001a3fe8d1
5ad39d53604f3c001a3fe8d2
5ad39d53604f3c001a3fe8d3
5ad39d53604f3c001a3fe8d4
56dddf4066d3e219004dad5f
56dddf4066d3e219004dad60
56dddf4066d3e219004dad61
5ad3a266604f3c001a3fea27
5ad3a266604f3c001a3fea28
5ad3a266604f3c001a3fea29
5ad3a266604f3c001a3fea2a
5ad3a266604f3c001a3fea2b
56dde0379a695914005b9636
56dde0379a695914005b9637
5ad3ab70604f3c001a3feb89
5ad3ab70604f3c001a3feb8a
56dde0ba66d3e219004dad75
56dde0ba66d3e219004dad76
56dde0ba66d3e219004dad77
5ad3ad61604f3c001a3fec0d
5ad3ad61604f3c001a3fec0e
5ad3ad61604f3c001a3fec0f
5ad3ad61604f3c001a3fec10
56dde1d966d3e219004dad8d
5ad3ae14604f3c001a3fec39
5ad3ae14604f3c001a3fec3a
56dde27d9a695914005b9651
56dde27d9a695914005b9652
5ad3af11604f3c001a3fec63
5ad3af11604f3c001a3fec64
5ad3af11604f3c001a3fec65
56dde2fa66d3e219004dad9b
5ad3c626604f3c001a3ff011
5ad3c626604f3c001a3ff012
5ad3c626604f3c001a3ff013


In [None]:
import json

with open("data.json", "w") as f:
    json.dump(val_squad_form, f)

with open("pred.json", "w") as f:
    json.dump(answers, f)