In [None]:
!pip install -q -U transformers accelerate bitsandbytes

In [None]:
!pip install -q -U datasets peft trl

# QA Generation

In [None]:
#!/usr/bin/env python3
"""
Q&A Fine-Tuning Dataset Generation with Batch Processing and Checkpointing
Processes all text chunks efficiently on free Colab GPU
"""

import json
import os
from getpass import getpass

import torch
from huggingface_hub import login
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline


In [None]:
# SETUP: Authentication and Model Configuration

# Prefer the Colab secret; fall back to an interactive prompt elsewhere.
hf_token = None
try:
    from google.colab import userdata
    hf_token = userdata.get('HF_TOKEN')
except ImportError:
    pass  # not running on Colab
except Exception as exc:
    # userdata.get raises (e.g. SecretNotFoundError) when the secret is missing or
    # notebook access has not been granted; prompt instead of aborting the cell.
    print(f"Could not read HF_TOKEN from Colab secrets ({exc!r}); prompting instead.")
if not hf_token:
    hf_token = getpass("Enter your Hugging Face Access Token: ")

login(token=hf_token)

MODEL_ID = "google/gemma-2b-it"
CLEAN_TEXT_PATH = "/content/master_dataset_v4.txt"
OUTPUT_DATASET_PATH = "/content/qa_finetuning_dataset.jsonl"
CHECKPOINT_FILE = "/content/generation_checkpoint.txt"

# Configuration for batch processing
BATCH_SIZE = 100      # chunks per progress-report / GPU-cache-clear group in main()
CHUNK_SIZE = 1500     # characters per text chunk
CHUNK_OVERLAP = 150   # characters shared by consecutive chunks

print("Loading model and tokenizer...")
# Passing load_in_4bit directly to from_pretrained is deprecated; quantization
# options belong in a BitsAndBytesConfig. Compute in bfloat16 to match torch_dtype.
generation_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=generation_quant_config,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

hf_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    batch_size=4,  # only has an effect when the pipeline is called with several inputs at once
)

print("Model loaded successfully!\n")

In [None]:
# TEXT PROCESSING: Load and Chunk Text

def create_text_chunks(text, chunk_size=1500, overlap=150):
    """Split `text` into fixed-size character windows that overlap by `overlap` characters.

    Consecutive chunks start `chunk_size - overlap` characters apart. Splitting
    stops as soon as a chunk reaches the end of the text, so no trailing chunk
    made up only of already-covered overlap is emitted.

    Args:
        text: String to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by consecutive chunks; must be
            smaller than `chunk_size`, otherwise the window would never advance.

    Returns:
        List of chunk strings (empty when `text` is empty).

    Raises:
        ValueError: if `chunk_size` is not positive or `overlap >= chunk_size`.
    """
    step = chunk_size - overlap
    if chunk_size <= 0 or step <= 0:
        raise ValueError(
            f"need chunk_size > 0 and overlap < chunk_size (got {chunk_size}, {overlap})"
        )

    chunks = []
    for start in range(0, len(text), step):
        end = start + chunk_size
        chunks.append(text[start:end])
        if end >= len(text):
            # This chunk already reaches the end; a further one would only repeat overlap.
            break
    return chunks

print("Loading and chunking text...")
with open(CLEAN_TEXT_PATH, "r", encoding="utf-8") as f:
    clean_text = f.read()

# Use the configured chunking parameters rather than the function's defaults,
# so editing CHUNK_SIZE / CHUNK_OVERLAP in the setup cell actually takes effect.
all_text_chunks = create_text_chunks(clean_text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP)
print(f"Total chunks available: {len(all_text_chunks)}\n")

In [None]:
# QA GENERATION: Prompt and Generation Function


# Model-agnostic instruction text. It is sent as a user chat message so the
# pipeline wraps it in the generation model's own chat template; `{{...}}` is
# escaped for str.format and renders as `{...}`.
PROMPT_TEMPLATE = """You are an expert in creating educational data. Your task is to read the provided text from a computer science textbook and generate three high-quality question-answer pairs.

Your goal is to emulate the style of the Stanford Question Answering Dataset (SQuAD). This means:
1.  **Strictly Grounded:** The answer to every question must be a direct quote or a very close paraphrase of a sentence found within the provided text.
2.  **Natural Questions:** The questions should be phrased as a student would naturally ask them.
3.  **Concise Answers:** The answers should be as short as possible while still being comprehensive.

Your final output must be ONLY a valid JSON list `[...]` containing three JSON objects `{{...}}`, where each object has a "question" key and an "answer" key. Do not add any other text, explanations, or markdown.

**Textbook Chunk:**
---
{text_chunk}
---
"""

def generate_qa_pairs(text_chunk, generator=None):
    """Ask the generation model for question-answer pairs grounded in one text chunk.

    The prompt is passed as a single user chat message, so a text-generation
    pipeline applies the model's chat template itself (hard-coded Mistral
    `[INST]` markers do not match Gemma's prompt format).

    Args:
        text_chunk: Source text the questions must be answered from.
        generator: Callable with the text-generation pipeline calling convention
            (`generator(messages, return_full_text=False)` returning
            `[{"generated_text": ...}]`). Defaults to the module-level `hf_pipeline`.

    Returns:
        A non-empty list of `{"question": str, "answer": str}` dicts (stripped;
        list-valued answers joined with spaces), or None when generation fails
        or the reply contains no usable JSON list.
    """
    if generator is None:
        generator = hf_pipeline
    messages = [{"role": "user", "content": PROMPT_TEMPLATE.format(text_chunk=text_chunk)}]

    try:
        # return_full_text=False yields only the new text, so the prompt does not
        # have to be split off by searching for a marker string.
        response = generator(messages, return_full_text=False)
        generated = response[0]['generated_text']
        if isinstance(generated, list):
            # Chat-format output: the last message is the model's reply.
            generated = generated[-1]["content"]
        response_text = generated.strip()
    except Exception as exc:
        # Best-effort: one failing chunk should not abort a long run, but report it.
        print(f"Generation failed for a chunk: {exc!r}")
        return None

    # Extract the outermost JSON list from the reply.
    start_index = response_text.find('[')
    end_index = response_text.rfind(']')
    if start_index == -1 or end_index <= start_index:
        return None
    try:
        parsed = json.loads(response_text[start_index:end_index + 1])
    except json.JSONDecodeError:
        return None

    # Keep only well-formed pairs with non-empty text.
    qa_pairs = []
    for item in parsed:
        if not isinstance(item, dict):
            continue
        question, answer = item.get("question"), item.get("answer")
        if isinstance(answer, list):
            answer = " ".join(map(str, answer))
        if not isinstance(question, str) or not isinstance(answer, str):
            continue
        question, answer = question.strip(), answer.strip()
        if question and answer:
            qa_pairs.append({"question": question, "answer": answer})
    return qa_pairs or None


In [None]:
# CHECKPOINT MANAGEMENT

def load_checkpoint(path=None):
    """Return the index of the next chunk to process, as recorded in the checkpoint file.

    Args:
        path: Checkpoint file to read; defaults to CHECKPOINT_FILE.

    Returns:
        The stored chunk index (clamped to be non-negative), or 0 when the file
        is missing, unreadable, or does not contain an integer.
    """
    if path is None:
        path = CHECKPOINT_FILE
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return max(0, int(f.read().strip()))
    except FileNotFoundError:
        return 0
    except (OSError, ValueError) as exc:
        # Only I/O and parse errors mean "no usable checkpoint"; unlike a bare
        # `except:`, this lets KeyboardInterrupt and real bugs propagate.
        print(f"Warning: ignoring unreadable checkpoint {path!r} ({exc!r})")
        return 0

def save_checkpoint(chunk_index, path=None):
    """Persist the index of the next chunk to process.

    The value is written to a sibling temporary file and moved into place with
    os.replace, so an interruption mid-write cannot leave an empty or partial
    checkpoint (which load_checkpoint would read as 0, making main() restart
    and truncate the output file).

    Args:
        chunk_index: Index of the next chunk to process.
        path: Checkpoint file to write; defaults to CHECKPOINT_FILE.
    """
    if path is None:
        path = CHECKPOINT_FILE
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.write(str(chunk_index))
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, path)

def count_generated_pairs(path=None):
    """Count the lines (one Q&A pair each) in the JSONL output file.

    Args:
        path: File to count; defaults to OUTPUT_DATASET_PATH.

    Returns:
        Number of lines in the file, or 0 if it does not exist.
    """
    if path is None:
        path = OUTPUT_DATASET_PATH
    if not os.path.exists(path):
        return 0
    # Stream the file instead of loading every line into memory with readlines().
    with open(path, 'r', encoding='utf-8') as f:
        return sum(1 for _ in f)

In [None]:
# MAIN GENERATION LOOP WITH BATCH PROCESSING

def main():
    """Generate Q&A pairs for every chunk in `all_text_chunks`, resuming from the checkpoint.

    Chunks are grouped in BATCH_SIZE runs for progress reporting and GPU-cache
    clearing; generation itself is one chunk at a time. Each chunk's pairs are
    appended to OUTPUT_DATASET_PATH as one JSON object per line, and only then is
    the checkpoint advanced to the next chunk index, so calling main() again after
    an interruption continues where it stopped. Starting without a checkpoint
    truncates the output file. The checkpoint file is removed after all chunks
    have been processed.
    """
    start_chunk = load_checkpoint()
    existing_pairs = count_generated_pairs()

    if start_chunk > 0:
        print(f"Resuming from chunk {start_chunk}")
        print(f"Existing pairs in output: {existing_pairs}\n")
    else:
        print("Starting fresh generation\n")
        if existing_pairs:
            print(f"Warning: discarding {existing_pairs} existing lines in {OUTPUT_DATASET_PATH}\n")
        with open(OUTPUT_DATASET_PATH, 'w', encoding='utf-8'):
            pass  # truncate the output file

    total_chunks = len(all_text_chunks)
    total_pairs_generated = 0  # pairs written during this run only
    num_batches = (total_chunks + BATCH_SIZE - 1) // BATCH_SIZE  # ceiling division

    print(f"Total chunks to process: {total_chunks}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Total batches: {num_batches}\n")

    for batch_idx in range(num_batches):
        # Clamp the batch start to the checkpoint; batches that end before it are skipped.
        batch_start_idx = max(batch_idx * BATCH_SIZE, start_chunk)
        batch_end_idx = min((batch_idx + 1) * BATCH_SIZE, total_chunks)
        if batch_start_idx >= batch_end_idx:
            continue

        batch_chunks = all_text_chunks[batch_start_idx:batch_end_idx]

        print(f"\n{'='*70}")
        print(f"Batch {batch_idx + 1}/{num_batches} (chunks {batch_start_idx}-{batch_end_idx-1})")
        print(f"{'='*70}")

        batch_pairs_count = 0

        # Generate Q&A pairs for this batch
        for offset, chunk in enumerate(tqdm(batch_chunks, desc="Generating Q&A pairs")):
            qa_pairs = generate_qa_pairs(chunk)

            if qa_pairs:
                # Open the output once per chunk (not once per pair). The file is
                # closed, flushing Python's buffer, before the checkpoint below
                # moves past this chunk.
                with open(OUTPUT_DATASET_PATH, 'a', encoding='utf-8') as f:
                    for pair in qa_pairs:
                        f.write(json.dumps(pair, ensure_ascii=False) + '\n')
                batch_pairs_count += len(qa_pairs)
                total_pairs_generated += len(qa_pairs)

            save_checkpoint(batch_start_idx + offset + 1)

        print(f"\nBatch {batch_idx + 1} Summary:")
        print(f"  - Chunks processed: {batch_start_idx} to {batch_end_idx-1}")
        print(f"  - Q&A pairs generated in this batch: {batch_pairs_count}")
        print(f"  - Total pairs generated so far: {total_pairs_generated}")
        print(f"  - Output file: {OUTPUT_DATASET_PATH}")

        # Clear GPU cache between batches to prevent memory buildup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print("  - GPU cache cleared")

    # Final summary
    print(f"\n{'='*70}")
    print("GENERATION COMPLETE!")
    print(f"{'='*70}")

    final_pair_count = count_generated_pairs()
    print(f"Total Q&A pairs generated: {final_pair_count}")
    print(f"Dataset saved to: {OUTPUT_DATASET_PATH}")
    print("\nDataset Statistics:")
    print(f"  - Total chunks processed: {total_chunks}")
    print(f"  - Total Q&A pairs: {final_pair_count}")
    if total_chunks:  # avoid ZeroDivisionError on an empty input text
        print(f"  - Average pairs per chunk: {final_pair_count / total_chunks:.2f}")

    # Clean up checkpoint file after successful completion
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)
        print("\nCheckpoint file cleaned up.")

    print("\nReady for fine-tuning!")

In [None]:
# SCRIPT ENTRY POINT

if __name__ == "__main__":
    def _print_resume_hint():
        """Tell the user that progress is on disk and which chunk a rerun starts from."""
        print("Progress saved. Run again to resume from checkpoint.")
        print(f"Current checkpoint: {load_checkpoint()}")

    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C / "Interrupt kernel": report and stop quietly.
        print("\n\nGeneration interrupted by user.")
        _print_resume_hint()
    except Exception as err:
        # Any other failure: report the resume point, then re-raise for the traceback.
        print(f"\n\nError occurred: {err}")
        _print_resume_hint()
        raise

# Fine-Tune with QLoRA

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
from getpass import getpass
from huggingface_hub import login

# Prefer the Colab secret; fall back to an interactive prompt elsewhere.
hf_token = None
try:
    from google.colab import userdata
    hf_token = userdata.get('HF_TOKEN')
except ImportError:
    pass  # not running on Colab
except Exception as exc:
    # userdata.get raises (e.g. SecretNotFoundError) when the secret is missing or
    # notebook access has not been granted; prompt instead of aborting the cell.
    print(f"Could not read HF_TOKEN from Colab secrets ({exc!r}); prompting instead.")
if not hf_token:
    hf_token = getpass("Enter your Hugging Face Access Token: ")

login(token=hf_token)

BASE_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
DATASET_PATH = "/content/qa_finetuning_dataset.jsonl"
NEW_ADAPTERS_ID = "algo-tutor-mistral-7b-adapters"

In [None]:
import re
import json
from datasets import load_dataset

INPUT_FILE = "/content/qa_finetuning_dataset.jsonl"
OUTPUT_FILE = "/content/qa_finetuning_dataset_cleaned.jsonl"

_JSON_DECODER = json.JSONDecoder()


def _iter_qa_objects(line):
    """Yield every JSON object (dict) that can be decoded from one line of raw output.

    A line that is a complete JSON value is decoded directly: a dict is yielded
    as-is, dicts inside a list are yielded one by one, and a JSON string is
    searched recursively. A line that does not decode as a whole is scanned with
    `JSONDecoder.raw_decode` from each '{'. Unlike a non-greedy regex that stops
    at the first closing brace, this is not cut short by braces inside string
    values (common in computer-science text).
    """
    try:
        value = json.loads(line)
    except json.JSONDecodeError:
        pos = line.find("{")
        while pos != -1:
            try:
                obj, end = _JSON_DECODER.raw_decode(line, pos)
            except json.JSONDecodeError:
                pos = line.find("{", pos + 1)
                continue
            if isinstance(obj, dict):
                yield obj
            pos = line.find("{", end)
        return

    if isinstance(value, dict):
        yield value
    elif isinstance(value, list):
        for item in value:
            if isinstance(item, dict):
                yield item
    elif isinstance(value, str):
        yield from _iter_qa_objects(value)


def _clean_text(value):
    """Strip `value` and turn literal backslash-n sequences emitted by the model into newlines.

    This runs on already-decoded strings. Rewriting escapes in the raw JSON text
    instead would put raw newlines inside JSON strings, which json.loads rejects.
    """
    return value.replace("\\n", "\n").strip()


valid_count = 0
seen_pairs = set()  # drop exact duplicates, e.g. from overlapping source chunks
with open(INPUT_FILE, "r", encoding="utf-8") as fin, open(OUTPUT_FILE, "w", encoding="utf-8") as fout:
    for raw_line in fin:
        for obj in _iter_qa_objects(raw_line.strip()):
            q, a = obj.get("question"), obj.get("answer")
            if isinstance(a, list):
                a = " ".join(map(str, a))
            if not isinstance(q, str) or not isinstance(a, str):
                continue
            q, a = _clean_text(q), _clean_text(a)
            if not q or not a or (q, a) in seen_pairs:
                continue
            seen_pairs.add((q, a))
            fout.write(json.dumps({"question": q, "answer": a}, ensure_ascii=False) + "\n")
            valid_count += 1

print(f"Cleaned {valid_count} valid Q&A pairs.")
print(f"Saved cleaned dataset to: {OUTPUT_FILE}")

In [None]:
# Load the cleaned JSONL as a single "train" split.
dataset = load_dataset(
    "json",
    split="train",
    data_files=OUTPUT_FILE,
)
num_examples = len(dataset)
print(f"Loaded dataset with {num_examples} examples.")

def format_instruction(sample):
    """Render one Q&A record as a Mistral-Instruct training string.

    No leading `<s>` is written: Mistral's tokenizer prepends the BOS token
    itself when called with default settings, so a literal `<s>` here would
    produce two BOS tokens. The trailing `</s>` teaches the model to stop
    after the answer.
    """
    return f"[INST] {sample['question']} [/INST]\n{sample['answer']}</s>"

In [None]:
!head qa_finetuning_dataset_cleaned.jsonl

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig

# 4-bit NF4 quantization with double quantization (QLoRA setup).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
)
model.config.use_cache = False  # the KV cache is not needed while training
model.config.pretraining_tp = 1

# Mistral is natively supported by transformers, so no remote code needs to be trusted.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Pad with a token other than EOS where possible: collators that mask labels by
# pad-token id would otherwise also mask the real `</s>`, and the model would
# not learn to stop.
tokenizer.pad_token = tokenizer.unk_token if tokenizer.unk_token is not None else tokenizer.eos_token
tokenizer.padding_side = "right"

# The LoRA config is handed to SFTTrainer (peft_config=...), which wraps the model
# itself. Calling get_peft_model here as well would give the trainer an
# already-adapted model plus a second adapter config.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
)

training_arguments = SFTConfig(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,  # effective batch size 8 per device
    optim="paged_adamw_32bit",
    save_steps=50,
    logging_steps=10,
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=False,
    bf16=True,
    max_grad_norm=0.3,
    max_steps=-1,  # train for num_train_epochs instead of a fixed step count
    warmup_ratio=0.03,
    group_by_length=True,
    # The plain "constant" scheduler ignores warmup_ratio; this variant applies
    # the 3% linear warmup and then holds the learning rate constant.
    lr_scheduler_type="constant_with_warmup",
)

def format_instruction(example):
    """Render one Q&A record as a Mistral-Instruct training string.

    No leading `<s>` is written: tokenize_function calls the tokenizer with
    default settings, and Mistral's tokenizer then prepends the BOS token
    itself, so a literal `<s>` here would produce two BOS tokens. The trailing
    `</s>` teaches the model to stop after the answer.
    """
    return f"[INST] {example['question']} [/INST]\n{example['answer']}</s>"

def tokenize_function(example):
    """Format one Q&A record as an instruction string and tokenize it (truncated to 512 tokens)."""
    text = format_instruction(example)
    return tokenizer(text, truncation=True, max_length=512)

# Drop the raw "question"/"answer" string columns so the trainer receives only
# the tokenizer outputs and does not depend on its own unused-column pruning.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=False,
    remove_columns=dataset.column_names,
)

trainer = SFTTrainer(
    model=model,
    train_dataset=tokenized_dataset,
    args=training_arguments,
    peft_config=peft_config,
    # Use the tokenizer configured in this notebook (pad token, right padding)
    # rather than letting the trainer load a fresh one from the hub.
    processing_class=tokenizer,
)

print("Starting the QLoRA fine-tuning process...")
trainer.train()
print("Training complete.")

import shutil

print("Saving the trained LoRA adapters...")
trainer.model.save_pretrained(NEW_ADAPTERS_ID)
# Keep the tokenizer next to the adapters: its pad token and padding side are
# set in this notebook, and inference should load the same configuration.
tokenizer.save_pretrained(NEW_ADAPTERS_ID)

shutil.make_archive("algo-tutor-lora-adapters", 'zip', NEW_ADAPTERS_ID)
print("Adapters saved and zipped to 'algo-tutor-lora-adapters.zip'.")
