# Install the following libraries with pip. Use a fresh environment, because installing these packages often causes dependency conflicts.

pip install unsloth

pip install transformers torch bitsandbytes

# Step 1: Answer extraction from Llama 3.2 1B:

The Llama 3.2 1B model is not optimized for extractive QA: it succeeds on simple questions but struggles with the diverse question types in our dataset (e.g., dates, locations, names).

The model’s outputs are too verbose or off-topic for most questions, unlike QA-specific models (e.g., distilbert-base-uncased-distilled-squad).

Therefore, the method below tries to extract a meaningful answer from the text generated by the Llama model.

In [None]:
#Imports
from google.colab import drive, userdata
from unsloth import FastLanguageModel
import torch
import re


# Mount Google Drive (the fine-tuned model is saved to Drive later and read back from there).
# drive_path is the base folder that later cells use for plots, reports and checkpoints.
drive.mount('/content/drive', force_remount=True)
drive_path = '/content/drive/MyDrive'


# Load Llama 3.2 1B through Unsloth for efficient VRAM usage.
# The Hugging Face token is read from Colab's userdata store under the key 'HF_TOKEN'.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Llama-3.2-1B",
    token=userdata.get('HF_TOKEN'),
    max_seq_length=128,
    dtype=torch.float16,
    load_in_4bit=True
)
# Pad on the left so the prompt's final tokens sit directly before the generated ones.
tokenizer.padding_side = 'left'
# Inference only in this cell.
model.eval()



# Define the context passage and the question asked about it
context = (
    "The Normans were the people who in the 10th and 11th centuries gave their name to Normandy, "
    "a region in France. They were descended from Norse raiders and pirates from Denmark, Iceland, "
    "and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia."
)
question = "In what country is Normandy located?"
# Expected answer: France (the context calls Normandy "a region in France")

# Create the input prompt; it ends with an open "Answer: " slot for the model to complete
input_text = (
    f"Based on the following context, answer the question in one word or a short phrase:\n\n"
    f"Context: {context}\n"
    f"Question: {question}\n"
    f"Answer: "
)



# Generate an answer.
# The prompt is truncated to 128 tokens, matching max_seq_length used when loading the model.
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=128, padding=True).to(model.device)
with torch.no_grad():
    # Sampling is enabled (top-k 40, temperature 0.6) and no seed is set in this cell,
    # so the generated text may differ between runs.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=50,
        do_sample=True,
        top_k=40,
        temperature=0.6,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id
    )
# Decode the first (only) returned sequence; the extraction step below looks for "Answer:" in it.
pred_text = tokenizer.decode(outputs[0], skip_special_tokens=True)



# Extract a short answer from the decoded text and print it.
# Any failure during extraction is reported and yields an empty answer.
try:
    # Keep only what follows the last "Answer:" marker (or the whole text if there is none).
    pred_answer = pred_text.split("Answer:")[-1].strip() if "Answer:" in pred_text else pred_text.strip()
    # Drop every run of non-ASCII characters.
    pred_answer = re.sub(r'[^\x00-\x7F]+', '', pred_answer)
    # Discard blank lines and lines that echo prompt sections or start an explanation.
    lines = [
        line.strip() for line in pred_answer.split("\n")
        if line.strip()
        and not line.startswith("Context:")
        and not line.startswith("Question:")
        and not line.lower().startswith("explanation")
        and not line.lower().startswith("according")
        and not line.lower().startswith("based")
    ]
    # The first surviving line is taken as the candidate answer.
    pred_answer = lines[0] if lines else pred_answer
    # Prefer the first capitalised word, optionally followed by a second capitalised word.
    name_match = re.search(r'\b[A-Z][a-z]*(?:\s[A-Z][a-z]*)?\b', pred_answer)
    if name_match:
        pred_answer = name_match.group(0)
    else:
        # No capitalised word found: fall back to the first three words.
        words = pred_answer.split()[:3]
        pred_answer = " ".join(words)
    # Reject candidates that are purely digits, longer than three words, empty,
    # phrased as a question, or just a question word.
    if (pred_answer.isdigit() or
        len(pred_answer.split()) > 3 or
        not pred_answer or
        pred_answer.endswith("?") or
        pred_answer.lower() in ["who", "what", "when", "where", "why"]):
        pred_answer = ""
    # Cut the answer at the first full stop, then at the first comma.
    pred_answer = pred_answer.split(".")[0].strip()
    pred_answer = pred_answer.split(",")[0].strip()
except Exception as e:
    print(f"Error extracting answer: {e}")
    pred_answer = ""

print(f"Generated Answer: {pred_answer}")

Mounted at /content/drive
==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.48.3.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Raw Generated Text: Based on the following context, answer the question in one word or a short phrase:

Context: The Normans were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse raiders and pirates from Denmark, Iceland, and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia.
Question: In what country is Normandy located?
Answer:  France
Explanation:  Normandy is a region in Fra

# Step 2: Evaluation of the Pretrained Llama Model for Baseline Scores

In [None]:
#Imports
from google.colab import drive, userdata
from unsloth import FastLanguageModel
from datasets import load_dataset
from evaluate import load
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
import random
import re
import nltk
nltk.download('wordnet')
nltk.download('punkt_tab')
nltk.download('omw-1.4')


# Load the SQuAD v2 validation split: the first 50 rows are requested,
# then the selection is narrowed to the first 5 examples.
print("Loading SQuAD v2 dataset...")
validation_dataset = load_dataset("rajpurkar/squad_v2", split="validation[:50]")
validation_dataset = validation_dataset.select(range(5))

def preprocess_squad(batch):
    """Convert one raw SQuAD record into an inference prompt and its reference answer.

    Args:
        batch: A single example with ``context``, ``question`` and ``answers``
            (whose ``text`` entry lists the gold answers).

    Returns:
        A dict with ``input_text`` (prompt ending in an open ``Answer: `` slot),
        ``output_text`` (the first gold answer, or ``""`` when none is listed)
        and ``context``. All three values are ``None`` when a field is not a
        string or the record cannot be read.
    """
    rejected = {"input_text": None, "output_text": None, "context": None}
    try:
        passage, query = batch["context"], batch["question"]
        gold_answers = batch["answers"]["text"]
        # Only the first listed answer is used as the reference.
        reference = gold_answers[0] if gold_answers else ""
        if any(not isinstance(field, str) for field in (passage, query, reference)):
            return rejected
        prompt = (
            "Based on the following context, answer the question in one word or a short phrase:\n\n"
            + f"Context: {passage}\n"
            + f"Question: {query}\n"
            + "Answer: "
        )
        return {"input_text": prompt, "output_text": reference, "context": passage}
    except Exception as e:
        print(f"Error preprocessing sample: {e}")
        return rejected

# Build prompts for every validation example and drop the raw SQuAD columns.
validation_dataset = validation_dataset.map(preprocess_squad, remove_columns=["id", "title", "context", "question", "answers"])
# Remove rows that preprocess_squad rejected (it marks them with None values).
validation_dataset = validation_dataset.filter(lambda x: x["input_text"] is not None and x["output_text"] is not None)
print(f"Validation dataset size: {len(validation_dataset)}")
for i, sample in enumerate(validation_dataset):
    # Recover the question from the prompt text for display.
    question = sample['input_text'].split("Question: ")[1].split("\nAnswer: ")[0].strip()
    print(f"Validation sample {i}: question={question}, output_text={sample['output_text']}")



# OOD analysis: start from the same five validation examples as above;
# their contexts are perturbed afterwards by preprocess_ood.
ood_validation_dataset = load_dataset("rajpurkar/squad_v2", split="validation[:50]")
ood_validation_dataset = ood_validation_dataset.select(range(5))
ood_validation_dataset = ood_validation_dataset.map(preprocess_squad, remove_columns=["id", "title", "context", "question", "answers"])
ood_validation_dataset = ood_validation_dataset.filter(lambda x: x["input_text"] is not None and x["output_text"] is not None)


def preprocess_ood(batch):
    """Build a perturbed prompt by shuffling the sentences of the context.

    Args:
        batch: A preprocessed example with ``input_text`` (the original
            prompt), ``output_text`` (reference answer) and ``context``.

    Returns:
        A dict with the rebuilt ``input_text`` (shuffled context, same
        question, open ``Answer: `` slot) and the unchanged ``output_text``.
        Both values are ``None`` when ``input_text`` is missing or not a
        string, or when any step raises.
    """
    try:
        input_text = batch["input_text"]
        output_text = batch["output_text"]
        context = batch["context"]
        if not input_text or not isinstance(input_text, str):
            return {"input_text": None, "output_text": None}
        question = input_text.split("Question: ", 1)[1].split("\nAnswer: ", 1)[0].strip()
        # Detach the closing full stop before splitting. Otherwise it stays glued to
        # the last sentence, and once that sentence is shuffled into the middle the
        # text contains ".." while the new final sentence has no full stop at all.
        body = context.rstrip()
        closing = "." if body.endswith(".") else ""
        if closing:
            body = body[:-1]
        sentences = body.split(". ")
        random.shuffle(sentences)
        perturbed_context = ". ".join(sentences) + closing
        return {
            "input_text": (
                f"Based on the following context, answer the question in one word or a short phrase:\n\n"
                f"Context: {perturbed_context}\n"
                f"Question: {question}\n"
                f"Answer: "
            ),
            "output_text": output_text
        }
    except Exception as e:
        print(f"Error in OOD preprocessing: {e}")
        return {"input_text": None, "output_text": None}

# Replace each prompt with its shuffled-context version, then drop rejected rows.
ood_validation_dataset = ood_validation_dataset.map(preprocess_ood)
ood_validation_dataset = ood_validation_dataset.filter(lambda x: x["input_text"] is not None and x["output_text"] is not None)
print(f"OOD validation dataset size: {len(ood_validation_dataset)}")
for i, sample in enumerate(ood_validation_dataset):
    # Recover the question from the prompt text for display.
    question = sample['input_text'].split("Question: ")[1].split("\nAnswer: ")[0].strip()
    print(f"OOD sample {i}: question={question}, output_text={sample['output_text']}")



# Load the pretrained (not fine-tuned) Llama 3.2 1B with Unsloth for the baseline run.
# max_seq_length matches the 256-token truncation used when tokenizing prompts below.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Llama-3.2-1B",
    token=userdata.get('HF_TOKEN'),
    max_seq_length=256,
    dtype=torch.float16,
    load_in_4bit=True
)
# Pad on the left so the prompt's final tokens sit directly before the generated ones.
tokenizer.padding_side = 'left'
model.eval()

def generate_answer(input_text, context):
    """Generate a completion for a prompt and reduce it to a short answer.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        input_text: Full prompt, expected to end with an open ``Answer:`` slot.
        context: Passage the answer must come from; words of the candidate
            answer that do not occur in it (case-insensitive substring test)
            are dropped.

    Returns:
        A ``(pred_text, pred_answer)`` tuple: the full decoded sequence and the
        extracted answer (``""`` when nothing usable is found or extraction
        raises).
    """
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=256, padding=True).to(model.device)
    with torch.no_grad():
        # Greedy decoding. top_k / temperature are deliberately not passed: they only
        # apply to sampling and are ignored (with a warning) when do_sample=False.
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    pred_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    try:
        # Keep only what follows the last "Answer:" marker (or the whole text if there is none).
        pred_answer = pred_text.split("Answer:")[-1].strip() if "Answer:" in pred_text else pred_text.strip()
        # Drop every run of non-ASCII characters.
        pred_answer = re.sub(r'[^\x00-\x7F]+', '', pred_answer)
        # Discard blank lines and lines that echo prompt sections or start an explanation.
        lines = [
            line.strip() for line in pred_answer.split("\n")
            if line.strip()
            and not line.startswith("Context:")
            and not line.startswith("Question:")
            and not line.lower().startswith(("explanation", "according", "based"))
        ]
        pred_answer = lines[0] if lines else pred_answer
        # Keep at most five words, each of which must occur somewhere in the context.
        # The lower-cased context is computed once instead of once per word.
        context_lower = context.lower()
        words = [w for w in pred_answer.split() if w.lower() in context_lower][:5]
        pred_answer = " ".join(words)
        # Reject candidates that are purely digits, empty, phrased as a question,
        # or just a question word / article.
        if (pred_answer.isdigit() or
            not pred_answer or
            pred_answer.endswith("?") or
            pred_answer.lower() in ["who", "what", "when", "where", "why", "the", "a", "an"]):
            pred_answer = ""
        # Cut the answer at the first full stop, then at the first comma.
        pred_answer = pred_answer.split(".")[0].strip()
        pred_answer = pred_answer.split(",")[0].strip()
    except Exception as e:
        print(f"Error extracting answer: {e}, pred_text={pred_text[:50]}...")
        pred_answer = ""
    return pred_text, pred_answer

# Accumulators: extracted answers, raw decoded outputs and reference labels,
# kept separately for the validation set and the shuffled-context (OOD) set.
llama_val_predictions = []
llama_val_raw_predictions = []
llama_ood_predictions = []
llama_ood_raw_predictions = []
val_labels = []
ood_labels = []

print("Generating answers with Llama...")
for i, sample in enumerate(validation_dataset):
    input_text = sample["input_text"]
    context = sample["context"]
    label = sample["output_text"]
    raw_pred, pred = generate_answer(input_text, context)
    llama_val_raw_predictions.append(raw_pred)
    llama_val_predictions.append(pred)
    val_labels.append(label)

for i, sample in enumerate(ood_validation_dataset):
    input_text = sample["input_text"]
    # Use the shuffled context actually shown in the prompt, recovered from the prompt text.
    context = input_text.split("\nContext: ", 1)[1].split("\nQuestion: ", 1)[0].strip()
    label = sample["output_text"]
    raw_pred, pred = generate_answer(input_text, context)
    llama_ood_raw_predictions.append(raw_pred)
    llama_ood_predictions.append(pred)
    ood_labels.append(label)

# Metric objects shared by evaluate_predictions.
exact_match_metric = load("exact_match", trust_remote_code=True)
squad_metric = load("squad", trust_remote_code=True)
rouge_scorer_instance = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)

def compute_perplexity(logits, labels):
    """Return exp(mean next-token cross-entropy) of ``labels`` under ``logits``.

    Args:
        logits: Tensor whose last two dimensions are (sequence, vocabulary).
        labels: Token ids aligned with ``logits`` along the sequence axis;
            positions equal to -100 are ignored.

    Returns:
        The perplexity as a Python float.
    """
    vocab_size = logits.size(-1)
    # Position t predicts token t + 1: drop the last logit and the first label.
    predicted = logits[..., :-1, :].reshape(-1, vocab_size)
    targets = labels[..., 1:].reshape(-1)
    mean_nll = torch.nn.functional.cross_entropy(
        predicted, targets, ignore_index=-100, reduction="mean"
    )
    return mean_nll.exp().item()

def evaluate_predictions(predictions, labels, dataset, dataset_name, raw_predictions):
    """Score extracted answers against references, print a report and save a plot.

    Uses the module-level ``tokenizer``, ``model``, ``exact_match_metric``,
    ``squad_metric``, ``rouge_scorer_instance`` and ``drive_path``.

    Args:
        predictions: Extracted answer strings, one per sample.
        labels: Reference answer strings aligned with ``predictions``.
        dataset: Samples providing ``input_text`` (and ``context`` when
            ``dataset_name`` is ``"Validation"``).
        dataset_name: Name used in the printed report, the plot title and the
            plot file name. Any value other than ``"Validation"`` makes the
            context be recovered from the prompt text.
        raw_predictions: Unprocessed model outputs, printed for inspection.

    Returns:
        A dict with the aggregate scores (``exact_match``, ``f1``,
        ``precision``, ``recall``, ``bleu``, ``rouge1``, ``rougeL``,
        ``meteor``, ``perplexity``, ``mean_context_length``) plus the
        per-sample ``context_lengths`` and ``em_by_length`` (1 where the
        prediction equals the reference, else 0).
    """
    pred_answers = predictions
    label_answers = labels
    context_lengths = []
    perplexities = []
    bleu_scores = []
    rouge1_scores = []
    rougeL_scores = []
    meteor_scores = []

    # Load METEOR and build the BLEU smoother once. Loading the metric inside the
    # loop re-ran its setup (including the nltk downloads) for every sample.
    meteor_metric = load("meteor", trust_remote_code=True)
    bleu_smoothing = SmoothingFunction().method1

    for pred, label, sample in zip(pred_answers, label_answers, dataset):
        context = sample["context"] if dataset_name == "Validation" else sample["input_text"].split("\nContext: ", 1)[1].split("\nQuestion: ", 1)[0].strip()
        try:
            # Context length is measured in whitespace-separated words.
            words = [w for w in context.split() if w]
            context_lengths.append(len(words))
        except Exception as e:
            print(f"Error computing context length: {e}")
            context_lengths.append(0)

        # Perplexity is computed over the prompt (input_text), not over the prediction.
        inputs = tokenizer(sample["input_text"], return_tensors="pt", truncation=True, max_length=256, padding=True).to(model.device)
        with torch.no_grad():
            outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"], labels=inputs["input_ids"])
            perplexity = compute_perplexity(outputs.logits, inputs["input_ids"])
            perplexities.append(perplexity)

        if label:
            bleu_score = sentence_bleu([label.split()], pred.split() if pred else [""], weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=bleu_smoothing)
            rouge_scores = rouge_scorer_instance.score(label, pred if pred else "")
            rouge1_f = rouge_scores['rouge1'].fmeasure
            rougeL_f = rouge_scores['rougeL'].fmeasure
            meteor_score = meteor_metric.compute(predictions=[pred if pred else ""], references=[label])["meteor"]
        else:
            # No reference answer: the overlap metrics are recorded as zero.
            bleu_score = 0.0
            rouge1_f = 0.0
            rougeL_f = 0.0
            meteor_score = 0.0

        bleu_scores.append(bleu_score)
        rouge1_scores.append(rouge1_f)
        rougeL_scores.append(rougeL_f)
        meteor_scores.append(meteor_score)

    em_score = exact_match_metric.compute(predictions=pred_answers, references=label_answers)["exact_match"]
    squad_results = squad_metric.compute(predictions=[{"id": str(i), "prediction_text": pred} for i, pred in enumerate(pred_answers)],
                                        references=[{"id": str(i), "answers": {"text": [ref], "answer_start": [0]}} for i, ref in enumerate(label_answers)])
    # NOTE: in the recorded runs this F1 is printed on a 0-100 scale (e.g. 68.0000)
    # while exact match is on a 0-1 scale (e.g. 0.6000); the values are kept as returned.
    f1_score = squad_results["f1"]
    # Precision / recall treat every distinct answer string as its own class.
    precision, recall, _, _ = precision_recall_fscore_support(label_answers, pred_answers, average='weighted', zero_division=0)
    # Sorted so the (at most ten) confusion-matrix labels are the same on every run;
    # plain set iteration order is not stable for strings.
    unique_answers = sorted(set(label_answers + pred_answers))[:10]
    cm = confusion_matrix(label_answers, pred_answers, labels=unique_answers) if unique_answers else np.array([[len(pred_answers)]])
    mean_context_length = np.mean(context_lengths) if context_lengths else 0
    mean_perplexity = np.mean(perplexities) if perplexities else 0
    mean_bleu = np.mean(bleu_scores)
    mean_rouge1 = np.mean(rouge1_scores)
    mean_rougeL = np.mean(rougeL_scores)
    mean_meteor = np.mean(meteor_scores)

    print(f"\n{dataset_name} Evaluation Results (Llama):")
    print(f"- Exact Match (EM): {em_score:.4f}")
    print(f"- F1 Score: {f1_score:.4f}")
    print(f"- Precision: {precision:.4f}")
    print(f"- Recall: {recall:.4f}")
    print(f"- BLEU Score: {mean_bleu:.4f}")
    print(f"- ROUGE-1 F1: {mean_rouge1:.4f}")
    print(f"- ROUGE-L F1: {mean_rougeL:.4f}")
    print(f"- METEOR Score: {mean_meteor:.4f}")
    print(f"- Mean Perplexity: {mean_perplexity:.4f}")
    print(f"- Mean Context Length: {mean_context_length:.2f}")
    print(f"Sample Raw Predictions: {raw_predictions}")
    print(f"Sample Predictions: {pred_answers}")
    print(f"Sample Labels: {label_answers}")

    # Save the confusion matrix as a PNG under drive_path.
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=unique_answers, yticklabels=unique_answers)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({dataset_name} - Llama)")
    plt.savefig(f"{drive_path}/confusion_matrix_{dataset_name.lower().replace(' ', '_')}_llama.png")
    plt.close()

    return {
        "exact_match": em_score,
        "f1": f1_score,
        "precision": precision,
        "recall": recall,
        "bleu": mean_bleu,
        "rouge1": mean_rouge1,
        "rougeL": mean_rougeL,
        "meteor": mean_meteor,
        "perplexity": mean_perplexity,
        "mean_context_length": mean_context_length,
        "context_lengths": context_lengths,
        "em_by_length": [1 if pred == ref else 0 for pred, ref in zip(pred_answers, label_answers)]
    }

# Evaluate both sets, then write the raw result dicts and a side-by-side comparison to Drive.
print("Evaluating Llama on validation set...")
val_results_llama = evaluate_predictions(
    predictions=llama_val_predictions,
    labels=val_labels,
    dataset=validation_dataset,
    dataset_name="Validation",
    raw_predictions=llama_val_raw_predictions,
)
print("Evaluating Llama on OOD set...")
ood_results_llama = evaluate_predictions(
    predictions=llama_ood_predictions,
    labels=ood_labels,
    dataset=ood_validation_dataset,
    dataset_name="OOD",
    raw_predictions=llama_ood_raw_predictions,
)

# (report label, result key, float format) for each comparison line, in output order.
_comparison_rows = (
    ("EM", "exact_match", ".4f"),
    ("F1", "f1", ".4f"),
    ("Precision", "precision", ".4f"),
    ("Recall", "recall", ".4f"),
    ("BLEU", "bleu", ".4f"),
    ("ROUGE-1", "rouge1", ".4f"),
    ("ROUGE-L", "rougeL", ".4f"),
    ("METEOR", "meteor", ".4f"),
    ("Perplexity", "perplexity", ".4f"),
    ("Mean Context Length", "mean_context_length", ".2f"),
)

with open(f"{drive_path}/evaluation_results.txt", "w") as f:
    f.write("Llama Validation Results:\n")
    f.write(str(val_results_llama) + "\n\n")
    f.write("Llama OOD Validation Results:\n")
    f.write(str(ood_results_llama) + "\n\n")
    f.write("Comparative Analysis (Llama Validation vs. OOD):\n")
    f.writelines(
        f"- {title}: Validation: {val_results_llama[key]:{spec}} vs. OOD: {ood_results_llama[key]:{spec}}\n"
        for title, key, spec in _comparison_rows
    )

Mounted at /content/drive
Loading SQuAD v2 dataset...


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Validation dataset size: 5
Validation sample 0: question=In what country is Normandy located?, output_text=France
Validation sample 1: question=When were the Normans in Normandy?, output_text=10th and 11th centuries
Validation sample 2: question=From which countries did the Norse originate?, output_text=Denmark, Iceland and Norway
Validation sample 3: question=Who was the Norse leader?, output_text=Rollo
Validation sample 4: question=What century did the Normans first gain their separate identity?, output_text=10th century
OOD validation dataset size: 5
OOD sample 0: question=In what country is Normandy located?, output_text=France
OOD sample 1: question=When were the Normans in Normandy?, output_text=10th and 11th centuries
OOD sample 2: question=From which countries did the Norse originate?, output_text=Denmark, Iceland and Norway
OOD sample 3: question=Who was the Norse leader?, output_text=Rollo
OOD sample 4: question=What century did the Normans first gain their separate identity?

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_


Validation Evaluation Results (Llama):
- Exact Match (EM): 0.6000
- F1 Score: 68.0000
- Precision: 0.6000
- Recall: 0.6000
- BLEU Score: 0.1344
- ROUGE-1 F1: 0.6800
- ROUGE-L F1: 0.6800
- METEOR Score: 0.4092
- Mean Perplexity: 8.9766
- Mean Context Length: 113.00
Sample Raw Predictions: ['Based on the following context, answer the question in one word or a short phrase:\n\nContext: The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity 

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_


OOD Evaluation Results (Llama):
- Exact Match (EM): 0.4000
- F1 Score: 56.0000
- Precision: 0.4000
- Recall: 0.4000
- BLEU Score: 0.1149
- ROUGE-1 F1: 0.5600
- ROUGE-L F1: 0.5600
- METEOR Score: 0.3862
- Mean Perplexity: 10.5531
- Mean Context Length: 113.00
Sample Raw Predictions: ['Based on the following context, answer the question in one word or a short phrase:\n\nContext: They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.. Through generations of assimilation and mixing with the native Frankish

# Model Training

In [None]:
#Imports
from google.colab import drive, userdata
from unsloth import FastLanguageModel
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
from evaluate import load
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
import random
import re


# Training data: the first 5000 SQuAD v2 training rows.
train_dataset = load_dataset("rajpurkar/squad_v2", split="train[:5000]")
# Validation data: the first 50 rows are requested, then narrowed to the first 5.
validation_dataset = load_dataset("rajpurkar/squad_v2", split="validation[:50]")
validation_dataset = validation_dataset.select(range(5))

def preprocess_squad(batch):
    """Convert one raw SQuAD record into a training text that includes the answer.

    Args:
        batch: A single example with ``context``, ``question`` and ``answers``
            (whose ``text`` entry lists the gold answers).

    Returns:
        A dict with ``input_text`` (prompt followed by the first gold answer
        after ``Answer: ``), ``output_text`` (that answer, or ``""`` when none
        is listed) and ``context``. All three values are ``None`` when a field
        is not a string or the record cannot be read.
    """
    failure = {"input_text": None, "output_text": None, "context": None}
    try:
        passage = batch["context"]
        query = batch["question"]
        answer_texts = batch["answers"]["text"]
        # Records without a listed answer are trained on an empty answer.
        target = "" if not answer_texts else answer_texts[0]
        for value in (passage, query, target):
            if not isinstance(value, str):
                return failure
        pieces = [
            "Based on the following context, answer the question in one word or a short phrase:\n\n",
            f"Context: {passage}\n",
            f"Question: {query}\n",
            f"Answer: {target}",
        ]
        return {"input_text": "".join(pieces), "output_text": target, "context": passage}
    except Exception as e:
        print(f"Error preprocessing sample: {e}")
        return failure

# Build texts for the training and validation examples, drop the raw SQuAD columns,
# and remove rows that preprocess_squad rejected (marked with None values).
# NOTE(review): preprocess_squad in this cell appends the gold answer after "Answer: ",
# so the validation prompts built here contain their answers as well.
train_dataset = train_dataset.map(preprocess_squad, remove_columns=["id", "title", "context", "question", "answers"])
train_dataset = train_dataset.filter(lambda x: x["input_text"] is not None and x["output_text"] is not None)
validation_dataset = validation_dataset.map(preprocess_squad, remove_columns=["id", "title", "context", "question", "answers"])
validation_dataset = validation_dataset.filter(lambda x: x["input_text"] is not None and x["output_text"] is not None)


print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(validation_dataset)}")
for i, sample in enumerate(validation_dataset):
    # Recover the question from the prompt text for display.
    question = sample['input_text'].split("Question: ")[1].split("\nAnswer: ")[0].strip()
    print(f"Validation sample {i}: question={question}, output_text={sample['output_text']}")
# The OOD set starts from the same five validation examples; preprocess_ood perturbs them afterwards.
ood_validation_dataset = load_dataset("rajpurkar/squad_v2", split="validation[:50]")
ood_validation_dataset = ood_validation_dataset.select(range(5))
ood_validation_dataset = ood_validation_dataset.map(preprocess_squad, remove_columns=["id", "title", "context", "question", "answers"])
ood_validation_dataset = ood_validation_dataset.filter(lambda x: x["input_text"] is not None and x["output_text"] is not None)
def preprocess_ood(batch):
    """Build a perturbed prompt by shuffling the sentences of the context.

    Args:
        batch: A preprocessed example with ``input_text`` (the original
            prompt), ``output_text`` (reference answer) and ``context``.

    Returns:
        A dict with the rebuilt ``input_text`` (shuffled context, same
        question, open ``Answer: `` slot) and the unchanged ``output_text``.
        Both values are ``None`` when ``input_text`` is missing or not a
        string, or when any step raises.
    """
    try:
        input_text = batch["input_text"]
        output_text = batch["output_text"]
        context = batch["context"]
        if not input_text or not isinstance(input_text, str):
            return {"input_text": None, "output_text": None}
        question = input_text.split("Question: ", 1)[1].split("\nAnswer: ", 1)[0].strip()
        # Detach the closing full stop before splitting. Otherwise it stays glued to
        # the last sentence, and once that sentence is shuffled into the middle the
        # text contains ".." while the new final sentence has no full stop at all.
        body = context.rstrip()
        closing = "." if body.endswith(".") else ""
        if closing:
            body = body[:-1]
        sentences = body.split(". ")
        random.shuffle(sentences)
        perturbed_context = ". ".join(sentences) + closing
        return {
            "input_text": (
                f"Based on the following context, answer the question in one word or a short phrase:\n\n"
                f"Context: {perturbed_context}\n"
                f"Question: {question}\n"
                f"Answer: "
            ),
            "output_text": output_text
        }
    except Exception as e:
        print(f"Error in OOD preprocessing: {e}")
        return {"input_text": None, "output_text": None}

# Replace each prompt with its shuffled-context version, then drop rejected rows.
ood_validation_dataset = ood_validation_dataset.map(preprocess_ood)
ood_validation_dataset = ood_validation_dataset.filter(lambda x: x["input_text"] is not None and x["output_text"] is not None)
print(f"OOD validation dataset size: {len(ood_validation_dataset)}")
for i, sample in enumerate(ood_validation_dataset):
    # Recover the question from the prompt text for display.
    question = sample['input_text'].split("Question: ")[1].split("\nAnswer: ")[0].strip()
    print(f"OOD sample {i}: question={question}, output_text={sample['output_text']}")


# Load the base model in 4-bit through Unsloth; max_seq_length matches the
# 256-token length used by tokenize_dataset.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Llama-3.2-1B",
    token=userdata.get('HF_TOKEN'),
    max_seq_length=256,
    dtype=torch.float16,
    load_in_4bit=True
)
tokenizer.padding_side = 'left'

# Attach LoRA adapters: rank 16, alpha 32, dropout 0.1, applied to the
# q_proj and v_proj modules only; bias terms are not trained.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "v_proj"],
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    use_gradient_checkpointing=True
)
# Report how many parameters the adapters make trainable.
model.print_trainable_parameters()
def tokenize_dataset(batch):
    """Tokenize one example's ``input_text`` into fixed-length training features.

    Uses the module-level ``tokenizer``.

    Args:
        batch: A single example containing ``input_text``.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels``, each
        padded or truncated to 256 tokens. ``labels`` copies ``input_ids`` but
        holds -100 at padded positions. All three values are ``None`` if
        tokenization raises.
    """
    try:
        input_tokenized = tokenizer(batch["input_text"], truncation=True, padding="max_length", max_length=256, return_tensors="pt")
        input_ids = input_tokenized["input_ids"][0]
        attention_mask = input_tokenized["attention_mask"][0]
        # Every sequence is padded to max_length, so copying input_ids verbatim would
        # make the loss reward predicting padding. -100 is the ignore index of the
        # cross-entropy loss, so padded positions are excluded from training.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
    except Exception as e:
        print(f"Error tokenizing sample: {e}")
        return {"input_ids": None, "attention_mask": None, "labels": None}

# Tokenize each dataset, drop the text columns, and remove rows whose tokenization failed
# (tokenize_dataset marks those with None values).
train_dataset = train_dataset.map(tokenize_dataset, remove_columns=["input_text", "output_text", "context"])
train_dataset = train_dataset.filter(lambda x: x["input_ids"] is not None and x["labels"] is not None)
validation_dataset_tokenized = validation_dataset.map(tokenize_dataset, remove_columns=["input_text", "output_text", "context"])
validation_dataset_tokenized = validation_dataset_tokenized.filter(lambda x: x["input_ids"] is not None and x["labels"] is not None)
# Unlike the two calls above, "context" is not listed for removal here.
ood_validation_dataset_tokenized = ood_validation_dataset.map(tokenize_dataset, remove_columns=["input_text", "output_text"])
ood_validation_dataset_tokenized = ood_validation_dataset_tokenized.filter(lambda x: x["input_ids"] is not None and x["labels"] is not None)

# Step 7: Fine-tune model
# Checkpoints go to Drive every 500 steps, keeping at most two.
# Batch size 4 with 4 gradient-accumulation steps gives 16 examples per optimizer step.
training_args = TrainingArguments(
    output_dir=f"{drive_path}/llama-3.2-1b-finetuned-5000",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_steps=50,
    save_steps=500,
    logging_steps=50,
    max_grad_norm=1.0,
    fp16=True,
    save_strategy="steps",
    save_total_limit=2,
)

# Only a training set is passed to the Trainer; no eval_dataset is given,
# so the tokenized validation and OOD sets are not used during training.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

print("Fine-tuning model...")
# Release cached GPU memory before and after the training run.
torch.cuda.empty_cache()
trainer.train()
torch.cuda.empty_cache()

# Save the final model and the tokenizer side by side on Drive.
trainer.save_model(f"{drive_path}/llama-3.2-1b-finetuned-5000-final")
tokenizer.save_pretrained(f"{drive_path}/llama-3.2-1b-finetuned-5000-final")


Mounted at /content/drive
Loading SQuAD v2 dataset...


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Train dataset size: 5000
Validation dataset size: 5
Validation sample 0: question=In what country is Normandy located?, output_text=France
Validation sample 1: question=When were the Normans in Normandy?, output_text=10th and 11th centuries
Validation sample 2: question=From which countries did the Norse originate?, output_text=Denmark, Iceland and Norway
Validation sample 3: question=Who was the Norse leader?, output_text=Rollo
Validation sample 4: question=What century did the Normans first gain their separate identity?, output_text=10th century
OOD validation dataset size: 5
OOD sample 0: question=In what country is Normandy located?, output_text=France
OOD sample 1: question=When were the Normans in Normandy?, output_text=10th and 11th centuries
OOD sample 2: question=From which countries did the Norse originate?, output_text=Denmark, Iceland and Norway
OOD sample 3: question=Who was the Norse leader?, output_text=Rollo
OOD sample 4: question=What century did the Normans first gain

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5 [00:00<?, ? examples/s]

Fine-tuning model...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 5,000 | Num Epochs = 3 | Total steps = 936
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 4 x 1) = 16
 "-____-"     Trainable parameters = 1,703,936/1,000,000,000 (0.17% trained)


Step,Training Loss
50,4.3251
100,3.8571
150,3.4828
200,3.3309
250,3.3169
300,3.2036
350,3.2451
400,3.1779
450,3.2278
500,3.1985


('/content/drive/MyDrive/llama-3.2-1b-finetuned-5000-final/tokenizer_config.json',
 '/content/drive/MyDrive/llama-3.2-1b-finetuned-5000-final/special_tokens_map.json',
 '/content/drive/MyDrive/llama-3.2-1b-finetuned-5000-final/tokenizer.json')

In [None]:
def preprocess_squad(batch):
    """Convert one SQuAD record into a prompt / reference-answer pair.

    Reads ``batch["context"]``, ``batch["question"]`` and the first entry of
    ``batch["answers"]["text"]`` (an empty string when the list is empty).

    Returns a dict with ``input_text`` (the instruction prompt ending in
    ``"Answer: "``), ``output_text`` (the reference answer) and ``context``.
    All three values are ``None`` when any field is not a string or when the
    record cannot be read; in the latter case the error is printed.
    """
    rejected = {"input_text": None, "output_text": None, "context": None}
    try:
        passage, query = batch["context"], batch["question"]
        gold_answers = batch["answers"]["text"]
        first_answer = gold_answers[0] if gold_answers else ""
        for field in (passage, query, first_answer):
            if not isinstance(field, str):
                return rejected
        prompt = "".join([
            "Based on the following context, answer the question in one word or a short phrase:\n\n",
            f"Context: {passage}\n",
            f"Question: {query}\n",
            "Answer: ",
        ])
        return {"input_text": prompt, "output_text": first_answer, "context": passage}
    except Exception as err:
        # Best effort: a malformed record is reported and rejected, not fatal.
        print(f"Error preprocessing sample: {err}")
        return rejected

# Turn both evaluation splits into prompt/answer rows and drop the rows that
# preprocess_squad rejected. The OOD split starts from the first five records
# of the SQuAD v2 validation split.
_RAW_SQUAD_COLUMNS = ["id", "title", "context", "question", "answers"]


def _has_prompt_and_answer(row):
    """Return True when both the prompt and the reference answer are present."""
    return not (row["input_text"] is None or row["output_text"] is None)


validation_dataset = (
    validation_dataset
    .map(preprocess_squad, remove_columns=_RAW_SQUAD_COLUMNS)
    .filter(_has_prompt_and_answer)
)
ood_validation_dataset = (
    load_dataset("rajpurkar/squad_v2", split="validation[:50]")
    .select(range(5))
    .map(preprocess_squad, remove_columns=_RAW_SQUAD_COLUMNS)
    .filter(_has_prompt_and_answer)
)

def preprocess_ood(batch):
    """Build an out-of-distribution variant of a sample by shuffling its context sentences.

    Reads ``batch["input_text"]`` (a prompt ending in
    ``"\\nQuestion: ...\\nAnswer: "``), ``batch["output_text"]`` and
    ``batch["context"]``.  The context is split on ``". "``, the pieces are
    shuffled with the global ``random`` generator (so the order is only
    reproducible if the caller seeds it), and the prompt is rebuilt around
    the shuffled context with the original question.

    Returns a dict with the new ``input_text`` and the unchanged
    ``output_text``.  Both are ``None`` when ``input_text`` is empty or not a
    string, or when the sample cannot be processed (the error is printed).
    """
    try:
        input_text = batch["input_text"]
        output_text = batch["output_text"]
        context = batch["context"]
        if not input_text or not isinstance(input_text, str):
            return {"input_text": None, "output_text": None}
        # The question is the LAST "Question:" field of the prompt.  Searching
        # from the right keeps a context that happens to contain "Question: "
        # from being mistaken for it.
        question = input_text.rsplit("\nQuestion: ", 1)[1].rsplit("\nAnswer: ", 1)[0].strip()
        # Detach the final period before splitting.  Otherwise the last
        # sentence keeps its own "." and, once shuffled into the middle,
        # produces ".. " when the pieces are re-joined.
        trimmed = context.rstrip()
        ends_with_period = trimmed.endswith(".")
        if ends_with_period:
            trimmed = trimmed[:-1]
        sentences = trimmed.split(". ")
        random.shuffle(sentences)
        perturbed_context = ". ".join(sentences)
        if ends_with_period:
            perturbed_context += "."
        return {
            "input_text": (
                f"Based on the following context, answer the question in one word or a short phrase:\n\n"
                f"Context: {perturbed_context}\n"
                f"Question: {question}\n"
                f"Answer: "
            ),
            "output_text": output_text
        }
    except Exception as e:
        # Best effort: report the sample and let the caller filter it out.
        print(f"Error in OOD preprocessing: {e}")
        return {"input_text": None, "output_text": None}

# Shuffle the context sentences of every OOD prompt, then drop the rows the
# perturbation step rejected.
ood_validation_dataset = ood_validation_dataset.map(preprocess_ood).filter(
    lambda row: not (row["input_text"] is None or row["output_text"] is None)
)

# Reload the fine-tuned checkpoint saved on Drive in the previous step
# (4-bit weights, fp16 compute, 256-token sequence limit).
finetuned_checkpoint = f"{drive_path}/llama-3.2-1b-finetuned-5000-final"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=finetuned_checkpoint,
    token=userdata.get('HF_TOKEN'),
    max_seq_length=256,
    dtype=torch.float16,
    load_in_4bit=True,
)
# Pad on the left so generated tokens directly follow the prompt.
tokenizer.padding_side = 'left'
model.eval()
def _extract_short_answer(generated_text, context):
    """Reduce a raw model continuation to a short answer grounded in ``context``.

    Steps: strip non-ASCII characters; keep the first non-empty line that is
    not a restated ``Context:`` / ``Question:`` field or an
    "explanation"/"according"/"based" preamble; keep at most five words that
    occur (case-insensitively, as substrings) in ``context``; reject results
    that are empty, purely numeric, end in "?", or are a bare question word or
    article; finally cut at the first "." and the first ",".
    """
    answer = re.sub(r'[^\x00-\x7F]+', '', generated_text.strip())
    candidates = []
    for raw_line in answer.split("\n"):
        line = raw_line.strip()
        if not line or line.startswith(("Context:", "Question:")):
            continue
        if line.lower().startswith(("explanation", "according", "based")):
            continue
        candidates.append(line)
    if candidates:
        answer = candidates[0]
    # If the model restated the field label, keep only what follows it.
    if answer.startswith("Answer:"):
        answer = answer[len("Answer:"):].strip()
    # Lower-case the context once instead of once per candidate word.
    context_lower = context.lower()
    answer = " ".join([w for w in answer.split() if w.lower() in context_lower][:5])
    # NOTE(review): purely numeric answers (e.g. a bare year) are discarded
    # here, as in the original heuristic -- confirm this is intended.
    if (not answer or
            answer.isdigit() or
            answer.endswith("?") or
            answer.lower() in ("who", "what", "when", "where", "why", "the", "a", "an")):
        return ""
    answer = answer.split(".")[0].strip()
    return answer.split(",")[0].strip()


def generate_answer(input_text, context, model, tokenizer):
    """Generate a continuation for ``input_text`` and extract a short answer.

    Args:
        input_text: Prompt to complete; it is tokenized with truncation at
            256 tokens.
        context: Passage the answer must come from; only words found in it
            are kept in the extracted answer.
        model: Causal language model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        ``(pred_text, pred_answer)``: the full decoded sequence (prompt plus
        continuation) and the extracted short answer, which is ``""`` when
        nothing usable was produced or extraction failed.
    """
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=256, padding=True).to(model.device)
    with torch.no_grad():
        # Sampling is disabled, so top_k / temperature would have no effect
        # and are not passed.
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    pred_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Extract from the newly generated tokens only.  Taking the text after the
    # last "Answer:" of the whole sequence would pick up the answer to a
    # follow-up question the model invented after answering the real one.
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    try:
        pred_answer = _extract_short_answer(generated_text, context)
    except Exception as e:
        # Best effort: an unparsable continuation counts as "no answer".
        print(f"Error extracting answer: {e}, pred_text={pred_text[:50]}...")
        pred_answer = ""
    return pred_text, pred_answer

def _context_from_prompt(prompt):
    """Return the context paragraph embedded in a prompt's ``Context:`` field."""
    return prompt.split("\nContext: ", 1)[1].split("\nQuestion: ", 1)[0].strip()


def _collect_predictions(dataset, lm, tok, use_prompt_context):
    """Run generate_answer over ``dataset``; return (raw outputs, short answers).

    The context passed to generate_answer is the sample's ``context`` column,
    or the context re-read from the prompt when ``use_prompt_context`` is set.
    """
    raw_outputs, answers = [], []
    for sample in dataset:
        prompt = sample["input_text"]
        passage = _context_from_prompt(prompt) if use_prompt_context else sample["context"]
        raw_text, answer = generate_answer(prompt, passage, lm, tok)
        raw_outputs.append(raw_text)
        answers.append(answer)
    return raw_outputs, answers


# Fine-tuned model: predictions on both splits, plus the reference answers.
print("Generating fine-tuned answers...")
finetuned_val_raw_predictions, finetuned_val_predictions = _collect_predictions(
    validation_dataset, model, tokenizer, use_prompt_context=False
)
val_labels = [sample["output_text"] for sample in validation_dataset]

finetuned_ood_raw_predictions, finetuned_ood_predictions = _collect_predictions(
    ood_validation_dataset, model, tokenizer, use_prompt_context=True
)
ood_labels = [sample["output_text"] for sample in ood_validation_dataset]

# Base model: loaded with the same settings as the fine-tuned checkpoint so
# the two sets of predictions are comparable.
print("Evaluating pre-trained model...")
pretrained_model, pretrained_tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Llama-3.2-1B",
    token=userdata.get('HF_TOKEN'),
    max_seq_length=256,
    dtype=torch.float16,
    load_in_4bit=True,
)
pretrained_tokenizer.padding_side = 'left'
pretrained_model.eval()

print("Generating pre-trained answers...")
pretrained_val_raw_predictions, pretrained_val_predictions = _collect_predictions(
    validation_dataset, pretrained_model, pretrained_tokenizer, use_prompt_context=False
)
pretrained_ood_raw_predictions, pretrained_ood_predictions = _collect_predictions(
    ood_validation_dataset, pretrained_model, pretrained_tokenizer, use_prompt_context=True
)




#Computing the final evaluation metrics
# Metric objects are created once here and reused by evaluate_predictions.
# NOTE(review): `load` is presumably `evaluate.load` -- confirm against the
# notebook's import cell, which is outside this section.
# Exact-match metric; in the recorded run its score printed on a 0-1 scale.
exact_match_metric = load("exact_match", trust_remote_code=True)
# SQuAD metric; in the recorded run its F1 printed on a 0-100 scale.
squad_metric = load("squad", trust_remote_code=True)
# ROUGE-1 and ROUGE-L scorer with stemming enabled.
rouge_scorer_instance = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)

def compute_perplexity(logits, labels):
    """Return exp(mean next-token cross-entropy) of ``labels`` under ``logits``.

    The logits at position ``t`` are scored against the label at ``t + 1``,
    so the last position's logits and the first label are dropped.  Labels
    equal to ``-100`` are excluded from the mean.  The result is a Python
    float.
    """
    vocab_size = logits.size(-1)
    next_token_logits = logits[..., :-1, :].reshape(-1, vocab_size)
    next_tokens = labels[..., 1:].reshape(-1)
    mean_nll = torch.nn.functional.cross_entropy(
        next_token_logits, next_tokens, ignore_index=-100, reduction='mean'
    )
    return mean_nll.exp().item()

def evaluate_predictions(predictions, labels, dataset, dataset_name, raw_predictions, model, tokenizer):
    """Score extracted answers against reference answers and report the results.

    For every (prediction, label, sample) triple this records the word count
    of the sample's context, the perplexity ``model`` assigns to the sample's
    prompt (``input_text`` only, not the answer), and BLEU / ROUGE-1 /
    ROUGE-L / METEOR between prediction and label (all 0.0 when the label is
    empty).  Exact match, SQuAD F1, weighted precision/recall and a confusion
    matrix over at most ten distinct answers are then computed over the whole
    set.  A summary is printed and the confusion matrix is saved as a PNG
    under ``drive_path``.

    Args:
        predictions: Extracted answer strings, one per sample.
        labels: Reference answer strings, aligned with ``predictions``.
        dataset: Iterable of samples with an ``input_text`` field, and a
            ``context`` field when ``dataset_name`` contains "Validation".
        dataset_name: Label used in the printout, the plot title and the PNG
            file name.  If it contains "Validation" the context is read from
            the sample's ``context`` field; otherwise it is parsed out of the
            prompt.
        raw_predictions: Full decoded generations, printed and returned as is.
        model: Language model used for the prompt perplexity.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        A dict with the aggregate metrics, the per-sample context lengths and
        exact-match flags, and the predictions.

    NOTE(review): in the recorded run the exact-match score printed on a 0-1
    scale and the SQuAD F1 on a 0-100 scale; both are passed through
    unchanged here.
    """
    pred_answers = predictions
    label_answers = labels
    context_lengths = []
    perplexities = []
    bleu_scores = []
    rouge1_scores = []
    rougeL_scores = []
    meteor_scores = []

    # Build the METEOR metric and the BLEU smoother once per call.  Creating
    # them inside the loop repeated the metric's whole setup (including its
    # NLTK resource checks) for every sample.
    meteor_metric = load("meteor", trust_remote_code=True)
    bleu_smoothing = SmoothingFunction().method1
    use_stored_context = "Validation" in dataset_name

    for pred, label, sample in zip(pred_answers, label_answers, dataset):
        if use_stored_context:
            context = sample["context"]
        else:
            context = sample["input_text"].split("\nContext: ", 1)[1].split("\nQuestion: ", 1)[0].strip()
        try:
            words = [w for w in context.split() if w]
            context_lengths.append(len(words))
        except Exception as e:
            print(f"Error computing context length: {e}")
            context_lengths.append(0)

        inputs = tokenizer(sample["input_text"], return_tensors="pt", truncation=True, max_length=256, padding=True).to(model.device)
        with torch.no_grad():
            # No `labels=` here: compute_perplexity derives the loss from the
            # logits itself, so asking the model for a loss as well would
            # compute it twice.
            outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
            perplexity = compute_perplexity(outputs.logits, inputs["input_ids"])
            perplexities.append(perplexity)

        if label:
            bleu_score = sentence_bleu([label.split()], pred.split() if pred else [""], weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=bleu_smoothing)
            rouge_scores = rouge_scorer_instance.score(label, pred if pred else "")
            rouge1_f = rouge_scores['rouge1'].fmeasure
            rougeL_f = rouge_scores['rougeL'].fmeasure
            meteor_score = meteor_metric.compute(predictions=[pred if pred else ""], references=[label])["meteor"]
        else:
            bleu_score = 0.0
            rouge1_f = 0.0
            rougeL_f = 0.0
            meteor_score = 0.0

        bleu_scores.append(bleu_score)
        rouge1_scores.append(rouge1_f)
        rougeL_scores.append(rougeL_f)
        meteor_scores.append(meteor_score)

    em_score = exact_match_metric.compute(predictions=pred_answers, references=label_answers)["exact_match"]
    squad_results = squad_metric.compute(predictions=[{"id": str(i), "prediction_text": pred} for i, pred in enumerate(pred_answers)],
                                        references=[{"id": str(i), "answers": {"text": [ref], "answer_start": [0]}} for i, ref in enumerate(label_answers)])
    f1_score = squad_results["f1"]
    # Precision/recall treat each distinct answer string as a class label.
    precision, recall, _, _ = precision_recall_fscore_support(label_answers, pred_answers, average='weighted', zero_division=0)
    # First-seen order keeps the ten plotted answers the same on every run;
    # slicing a set would pick a hash-order-dependent subset.
    unique_answers = list(dict.fromkeys(label_answers + pred_answers))[:10]
    cm = confusion_matrix(label_answers, pred_answers, labels=unique_answers) if unique_answers else np.array([[len(pred_answers)]])
    mean_context_length = np.mean(context_lengths) if context_lengths else 0
    mean_perplexity = np.mean(perplexities) if perplexities else 0
    mean_bleu = np.mean(bleu_scores)
    mean_rouge1 = np.mean(rouge1_scores)
    mean_rougeL = np.mean(rougeL_scores)
    mean_meteor = np.mean(meteor_scores)

    print(f"\n{dataset_name} Evaluation Results:")
    print(f"- Exact Match (EM): {em_score:.4f}")
    print(f"- F1 Score: {f1_score:.4f}")
    print(f"- Precision: {precision:.4f}")
    print(f"- Recall: {recall:.4f}")
    print(f"- BLEU Score: {mean_bleu:.4f}")
    print(f"- ROUGE-1 F1: {mean_rouge1:.4f}")
    print(f"- ROUGE-L F1: {mean_rougeL:.4f}")
    print(f"- METEOR Score: {mean_meteor:.4f}")
    print(f"- Mean Perplexity: {mean_perplexity:.4f}")
    print(f"- Mean Context Length: {mean_context_length:.2f}")
    print(f"Sample Raw Predictions: {raw_predictions}")
    print(f"Sample Predictions: {pred_answers}")
    print(f"Sample Labels: {label_answers}")

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=unique_answers, yticklabels=unique_answers)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({dataset_name})")
    plt.savefig(f"{drive_path}/confusion_matrix_{dataset_name.lower().replace(' ', '_')}.png")
    plt.close()

    return {
        "exact_match": em_score,
        "f1": f1_score,
        "precision": precision,
        "recall": recall,
        "bleu": mean_bleu,
        "rouge1": mean_rouge1,
        "rougeL": mean_rougeL,
        "meteor": mean_meteor,
        "perplexity": mean_perplexity,
        "mean_context_length": mean_context_length,
        "context_lengths": context_lengths,
        "em_by_length": [1 if pred == ref else 0 for pred, ref in zip(pred_answers, label_answers)],
        "predictions": pred_answers,
        "raw_predictions": raw_predictions
    }

#Comparative Analysis
# Score both models on both splits, then print a side-by-side comparison and
# save the raw results plus the comparison to Drive.
print("Evaluating pre-trained model on validation set...")
val_results_pretrained = evaluate_predictions(
    pretrained_val_predictions, val_labels, validation_dataset,
    "Validation (Pre-trained)", pretrained_val_raw_predictions,
    pretrained_model, pretrained_tokenizer,
)
print("Evaluating pre-trained model on OOD set...")
ood_results_pretrained = evaluate_predictions(
    pretrained_ood_predictions, ood_labels, ood_validation_dataset,
    "OOD (Pre-trained)", pretrained_ood_raw_predictions,
    pretrained_model, pretrained_tokenizer,
)
print("Evaluating fine-tuned model on validation set...")
val_results_finetuned = evaluate_predictions(
    finetuned_val_predictions, val_labels, validation_dataset,
    "Validation (Fine-tuned)", finetuned_val_raw_predictions,
    model, tokenizer,
)
print("Evaluating fine-tuned model on OOD set...")
ood_results_finetuned = evaluate_predictions(
    finetuned_ood_predictions, ood_labels, ood_validation_dataset,
    "OOD (Fine-tuned)", finetuned_ood_raw_predictions,
    model, tokenizer,
)

# (display label, key in the results dict) for every metric shown to 4 decimals.
_REPORT_METRICS = (
    ("EM", "exact_match"),
    ("F1", "f1"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("BLEU", "bleu"),
    ("ROUGE-1", "rouge1"),
    ("ROUGE-L", "rougeL"),
    ("METEOR", "meteor"),
    ("Perplexity", "perplexity"),
)
_REPORT_SECTIONS = (
    ("Validation Set:", val_results_pretrained, val_results_finetuned),
    ("OOD Set:", ood_results_pretrained, ood_results_finetuned),
)


def _comparison_lines(before, after, with_difference):
    """Return the pre-trained vs. fine-tuned report lines for one split."""
    report = []
    for title, key in _REPORT_METRICS:
        line = f"- {title}: Pre-trained: {before[key]:.4f} vs. Fine-tuned: {after[key]:.4f}"
        if with_difference:
            line += f" (Difference: {(after[key] - before[key]):.4f})"
        report.append(line)
    # Context length is a property of the data, so no difference is reported.
    report.append(
        f"- Mean Context Length: Pre-trained: {before['mean_context_length']:.2f} vs. Fine-tuned: {after['mean_context_length']:.2f}"
    )
    return report


print("\nComparative Analysis (Pre-trained vs. Fine-tuned):")
for heading, before, after in _REPORT_SECTIONS:
    print(f"\n{heading}")
    for line in _comparison_lines(before, after, with_difference=True):
        print(line)

with open(f"{drive_path}/evaluation_results.txt", "w") as f:
    for title, results in (
        ("Pre-trained Validation Results:", val_results_pretrained),
        ("Pre-trained OOD Validation Results:", ood_results_pretrained),
        ("Fine-tuned Validation Results:", val_results_finetuned),
        ("Fine-tuned OOD Validation Results:", ood_results_finetuned),
    ):
        f.write(title + "\n")
        f.write(str(results) + "\n\n")
    f.write("Comparative Analysis (Pre-trained vs. Fine-tuned):\n")
    for heading, before, after in _REPORT_SECTIONS:
        f.write(f"\n{heading}\n")
        # The saved report lists the two scores without the difference column.
        for line in _comparison_lines(before, after, with_difference=False):
            f.write(line + "\n")

Mounted at /content/drive
Loading SQuAD v2 dataset...


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.48.3.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Generating fine-tuned answers...
Evaluating pre-trained model...
==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.48.3.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colore

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_


Validation (Pre-trained) Evaluation Results:
- Exact Match (EM): 0.6000
- F1 Score: 68.0000
- Precision: 0.6000
- Recall: 0.6000
- BLEU Score: 0.1344
- ROUGE-1 F1: 0.6800
- ROUGE-L F1: 0.6800
- METEOR Score: 0.4092
- Mean Perplexity: 8.9766
- Mean Context Length: 113.00
Sample Raw Predictions: ['Based on the following context, answer the question in one word or a short phrase:\n\nContext: The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic ide

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_


OOD (Pre-trained) Evaluation Results:
- Exact Match (EM): 0.4000
- F1 Score: 56.0000
- Precision: 0.4000
- Recall: 0.4000
- BLEU Score: 0.1149
- ROUGE-1 F1: 0.5600
- ROUGE-L F1: 0.5600
- METEOR Score: 0.3862
- Mean Perplexity: 10.5531
- Mean Context Length: 113.00
Sample Raw Predictions: ['Based on the following context, answer the question in one word or a short phrase:\n\nContext: They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.. Through generations of assimilation and mixing with the native Fr

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_


Validation (Fine-tuned) Evaluation Results:
- Exact Match (EM): 0.4000
- F1 Score: 61.3333
- Precision: 0.4000
- Recall: 0.4000
- BLEU Score: 0.2483
- ROUGE-1 F1: 0.5943
- ROUGE-L F1: 0.5943
- METEOR Score: 0.4071
- Mean Perplexity: 6.4883
- Mean Context Length: 113.00
Sample Raw Predictions: ['Based on the following context, answer the question in one word or a short phrase:\n\nContext: The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic iden

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_


OOD (Fine-tuned) Evaluation Results:
- Exact Match (EM): 0.2000
- F1 Score: 49.4286
- Precision: 0.2000
- Recall: 0.2000
- BLEU Score: 0.0611
- ROUGE-1 F1: 0.4832
- ROUGE-L F1: 0.4832
- METEOR Score: 0.2575
- Mean Perplexity: 7.6508
- Mean Context Length: 113.00
Sample Raw Predictions: ['Based on the following context, answer the question in one word or a short phrase:\n\nContext: They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.. Through generations of assimilation and mixing with the native Fran