# Baseline Response Generation (Unsloth Version)

Uses **Unsloth** for faster inference (Unsloth advertises roughly 2–5× speedups).

**Environment**: Google Colab T4 GPU

## 1. Install Unsloth

In [None]:
%%capture
# %%capture suppresses this cell's output; remove it to see pip errors.
!pip install unsloth xformers

In [None]:
import torch

# Sanity-check the runtime: report whether CUDA is usable and, if so, which
# GPU is attached and how much memory it has.
cuda_ok = torch.cuda.is_available()
print(f"CUDA: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"Memory: {gpu_props.total_memory / 1024**3:.1f} GB")

In [None]:
import json
import os
from datetime import datetime
from tqdm import tqdm
import gc

## 2. Configuration

In [None]:
CONFIG = {
    "model_name": "unsloth/Meta-Llama-3.1-8B-Instruct",
    "input_file": "data/processed/counsel_chat_augmented.jsonl",
    "output_file": "data/baseline/responses_augmented.jsonl",
    # Name the checkpoint after the output it tracks. A generic
    # "checkpoint.json" left behind by a run that wrote a different output
    # file would make this run resume at that run's index and skip records.
    "checkpoint_file": "data/baseline/checkpoint_augmented.json",
    "batch_size": 8,
    "max_new_tokens": 256,
    "temperature": 0.7,
    "top_p": 0.9,
    "do_sample": True,
    "checkpoint_freq": 100,
}

# Create the directory of every configured path, so editing a path above does
# not also require editing hard-coded makedirs calls.
for _cfg_path in (CONFIG["input_file"], CONFIG["output_file"], CONFIG["checkpoint_file"]):
    os.makedirs(os.path.dirname(_cfg_path) or ".", exist_ok=True)
print(f"Config loaded. batch_size={CONFIG['batch_size']}")

## 3. Load Dataset

In [None]:
!git clone https://github.com/yuchangyuan1/6895_project_Agent.git temp_repo
!cp temp_repo/data/processed/counsel_chat_augmented.jsonl data/processed/
!rm -rf temp_repo
print("Augmented dataset loaded!")

In [None]:
def load_dataset(filepath: str) -> list:
    """Read a JSONL file into a list of parsed records.

    Blank (whitespace-only) lines are skipped, so an empty line between
    records or a trailing blank line does not abort the load.

    Args:
        filepath: Path to a UTF-8 JSONL file, one JSON value per line.

    Returns:
        The parsed records, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    records = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            records.append(json.loads(line))
    print(f"Loaded {len(records)} records")
    return records

# Load the JSONL file configured as CONFIG["input_file"]; `dataset` is consumed
# by the generation cell below.
dataset = load_dataset(CONFIG["input_file"])

## 4. Load Model

In [None]:
from unsloth import FastLanguageModel

# Load the configured checkpoint in 4-bit to keep GPU memory use low.
model, tokenizer = FastLanguageModel.from_pretrained(
    load_in_4bit=True,
    dtype=None,
    max_seq_length=2048,
    model_name=CONFIG["model_name"],
)

# Put the model into Unsloth's inference mode before any generation.
FastLanguageModel.for_inference(model)
print("Model loaded! GPU: {:.2f} GB".format(torch.cuda.memory_allocated() / 1024**3))

## 5. Generation Functions

In [None]:
def create_baseline_prompt(question: str) -> str:
    """Wrap a raw question in the minimal, generic baseline instruction.

    The instruction line and the question are separated by one blank line.
    """
    preamble = "You are a helpful assistant. Please respond to the following question:"
    return f"{preamble}\n\n{question}"


def format_for_llama(prompt: str, tokenizer) -> str:
    """Render ``prompt`` as a single user turn via the tokenizer's chat template.

    Returns the templated text (``tokenize=False``), with the generation
    prompt for the assistant turn appended (``add_generation_prompt=True``).
    """
    user_turn = {"role": "user", "content": prompt}
    return tokenizer.apply_chat_template(
        [user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )


def generate_batch(prompts: list, model, tokenizer, config: dict) -> list:
    """Generate one completion per prompt with a single batched ``generate`` call.

    Args:
        prompts: Prompt strings already rendered by the chat template (see
            ``format_for_llama``), so they carry their own special tokens.
        model: Causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Tokenizer for ``model``. Its ``pad_token`` and
            ``padding_side`` are overwritten in place on every call.
        config: Reads ``max_new_tokens``, ``temperature``, ``top_p`` and
            ``do_sample``.

    Returns:
        The decoded continuations (prompt tokens removed, special tokens
        skipped, surrounding whitespace stripped), in the order of ``prompts``.
    """
    # Batched generation needs left padding so every prompt ends at the same
    # position, right where the new tokens start.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        # Text produced by `apply_chat_template` already contains the
        # template's special tokens (e.g. BOS); adding them again on
        # tokenization would duplicate them.
        add_special_tokens=False,
        # Deliberately no truncation: with the default right-side truncation,
        # a long prompt would lose its trailing assistant header (added by
        # add_generation_prompt=True). The model would then continue the
        # question instead of answering it.
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=config["max_new_tokens"],
            temperature=config["temperature"],
            top_p=config["top_p"],
            do_sample=config["do_sample"],
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # With left padding every row has the same (padded) prompt width, so the
    # generated tokens are everything after it.
    prompt_len = inputs["input_ids"].shape[1]
    return [
        tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True).strip()
        for sequence in outputs
    ]


def load_checkpoint(checkpoint_file: str) -> int:
    """Return the saved resume index, or 0 if there is no checkpoint file.

    A checkpoint file that lacks a ``last_index`` entry also yields 0.
    """
    if not os.path.exists(checkpoint_file):
        return 0
    with open(checkpoint_file, "r") as fh:
        state = json.load(fh)
    return state.get("last_index", 0)


def save_checkpoint(checkpoint_file: str, last_index: int):
    """Persist the resume index and the current time to ``checkpoint_file``.

    The JSON is first written to a temporary sibling file and then moved over
    the checkpoint with ``os.replace``. An interruption mid-write therefore
    leaves the previous checkpoint intact rather than a truncated file that
    ``json.load`` cannot parse on the next resume.

    Args:
        checkpoint_file: Destination path of the checkpoint JSON.
        last_index: Dataset index the next run should resume from.
    """
    tmp_path = f"{checkpoint_file}.tmp"
    with open(tmp_path, "w") as f:
        json.dump({"last_index": last_index, "timestamp": str(datetime.now())}, f)
    os.replace(tmp_path, checkpoint_file)

## 6. Run Generation

In [None]:
# Optional: set START_FRESH = True to delete the configured checkpoint and
# output files, so the next generation run starts again from index 0.
START_FRESH = False
if START_FRESH:
    for stale_path in (CONFIG["checkpoint_file"], CONFIG["output_file"]):
        if os.path.exists(stale_path):
            os.remove(stale_path)
    print("Old files deleted!")

In [None]:
def run_baseline_generation(dataset: list, model, tokenizer, config: dict):
    """Generate baseline responses for every record and stream them to JSONL.

    Records are processed in order, ``config["batch_size"]`` at a time. Each
    batch is written and flushed to ``config["output_file"]`` (one JSON object
    per line). The index of the next unprocessed record is saved to
    ``config["checkpoint_file"]`` once at least ``config["checkpoint_freq"]``
    records have been written since the last save, and again at the end.

    Args:
        dataset: Records with ``id``, ``question`` and ``answer`` keys and an
            optional ``topic``.
        model, tokenizer: Passed through to ``generate_batch``.
        config: Reads ``output_file``, ``checkpoint_file``, ``batch_size`` and
            ``checkpoint_freq``; ``generate_batch`` reads the sampling keys.

    Raises:
        Any error raised while generating, serializing or writing a batch is
        re-raised after saving a checkpoint that points at that batch's start.

    Note:
        If the process dies without raising (e.g. the runtime is killed),
        records written after the last saved checkpoint are generated and
        appended again on resume.
    """
    output_file = config["output_file"]
    checkpoint_file = config["checkpoint_file"]
    batch_size = config["batch_size"]
    checkpoint_freq = config["checkpoint_freq"]

    start_index = load_checkpoint(checkpoint_file)
    if start_index > 0 and not os.path.exists(output_file):
        # A checkpoint without its output file (e.g. the output was deleted
        # but the checkpoint was not) would otherwise skip the first
        # `start_index` records entirely.
        print(f"Checkpoint points to index {start_index} but {output_file} is missing; starting from 0")
        start_index = 0
    if start_index > 0:
        print(f"Resuming from index {start_index}")

    mode = "a" if start_index > 0 else "w"
    total = len(dataset)
    last_saved = start_index

    with open(output_file, mode, encoding="utf-8") as f:
        for i in tqdm(range(start_index, total, batch_size), desc="Generating"):
            batch = dataset[i:i + batch_size]
            next_index = i + len(batch)

            prompts = [
                format_for_llama(create_baseline_prompt(r["question"]), tokenizer)
                for r in batch
            ]

            try:
                responses = generate_batch(prompts, model, tokenizer, config)

                # Serialize the whole batch before writing any of it, so a
                # malformed record cannot leave a partial batch in the file
                # that would be duplicated when resuming from index `i`.
                lines = [
                    json.dumps(
                        {
                            "id": record["id"],
                            "question": record["question"],
                            "original_answer": record["answer"],
                            "baseline_response": response,
                            "topic": record.get("topic", "general"),
                        },
                        ensure_ascii=False,
                    ) + "\n"
                    for record, response in zip(batch, responses)
                ]
                f.writelines(lines)
                f.flush()

            except Exception as e:
                print(f"Error at batch {i}: {e}")
                save_checkpoint(checkpoint_file, i)
                raise

            # Count records written since the last save. Testing
            # `index % checkpoint_freq` only fires when a batch boundary lands
            # exactly on a multiple of checkpoint_freq, which is rarer than
            # intended and may never happen after resuming from an
            # unaligned index.
            if next_index - last_saved >= checkpoint_freq:
                save_checkpoint(checkpoint_file, next_index)
                last_saved = next_index

    save_checkpoint(checkpoint_file, total)
    print(f"Done! Output: {output_file}")

run_baseline_generation(dataset, model, tokenizer, CONFIG)

## 7. Verify Results

In [None]:
# Count the records in the configured output file (not a hard-coded name) and
# compare with the dataset size to confirm the run completed.
with open(CONFIG["output_file"], "r", encoding="utf-8") as f:
    num_written = sum(1 for _ in f)
print(f"{num_written} / {len(dataset)} records in {CONFIG['output_file']}")

In [None]:
def inspect_results(filepath: str, n: int = 2):
    """Print a short preview of the first ``n`` records of a results JSONL file.

    Each record is printed under a 50-character divider, showing the first
    150 characters of ``question`` and the first 200 characters of
    ``baseline_response``, each followed by "...".
    """
    divider = "=" * 50
    with open(filepath, "r", encoding="utf-8") as handle:
        for idx, raw in enumerate(handle):
            if idx >= n:
                break
            rec = json.loads(raw)
            print("\n" + divider)
            question_preview = rec["question"][:150]
            print(f"Q: {question_preview}...")
            response_preview = rec["baseline_response"][:200]
            print(f"Baseline: {response_preview}...")

# Preview the first two generated records (inspect_results defaults to n=2).
inspect_results(CONFIG["output_file"])

## 8. Download

In [None]:
# Download the generated responses file to the local machine (Colab-only API).
from google.colab import files
files.download(CONFIG["output_file"])

## 9. Cleanup

In [None]:
# Drop the notebook's references to the model, run the garbage collector so
# any objects still kept alive by reference cycles are actually freed, and
# only then ask PyTorch to release its now-unused cached CUDA memory.
del model, tokenizer
gc.collect()
torch.cuda.empty_cache()
print("Cleanup complete!")