Sakshi Lathi, 1234549838

# STaR: Self-Taught Reasoner Implementation

## Bootstrapping Logical Reasoning Through Self-Improvement

**Model Used:** Llama-3.2-3B-Instruct
**Dataset:** GSM8K (training and evaluation sets)
**Techniques Applied:**

* Zero-Shot Chain-of-Thought (CoT)
* Vanilla Supervised Fine-Tuning (SFT)
* STaR (Self-Taught Reasoner)

---

This project focuses on reproducing the **STaR: Self-Taught Reasoner** approach, which enhances a model's reasoning ability by iteratively generating and refining rationales. Using **Llama-3.2-3B-Instruct**, we experiment with three paradigms: a baseline zero-shot reasoning model, a supervised fine-tuning setup using provided rationales, and the STaR method, where the model teaches itself through bootstrapped rationales derived from its own outputs. The GSM8K dataset serves as both the training corpus for rationale generation and the benchmark for final evaluation.




In [None]:
# Install the notebook's dependencies (IPython shell escapes, run once per environment)
!pip install -q datasets transformers torch accelerate peft bitsandbytes tqdm
# NOTE(review): torch is already in the list above, so this second install looks redundant
!pip install torch

In [None]:
import torch
# Quick sanity check: report the installed PyTorch version and whether a CUDA device is visible
print(torch.__version__)
print("CUDA available:", torch.cuda.is_available())


In [None]:
import os
import json
import torch
import re
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq
)
from tqdm import tqdm
import random
from typing import List, Dict, Tuple
import numpy as np

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) with the same value
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(42)
if torch.cuda.is_available():
    # CUDA generators are seeded separately from the CPU one
    torch.cuda.manual_seed_all(42)

# Environment report
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [None]:
# Configuration
class Config:
    """Static settings shared by every cell.

    Groups the model/dataset identifiers, training hyperparameters, STaR
    knobs, sample sizes and output locations as plain class attributes.
    """

    # Model and dataset
    MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
    DATASET_NAME = "openai/gsm8k"
    DATASET_CONFIG = "main"

    # Training parameters
    BATCH_SIZE = 2 # passed to TrainingArguments as per_device_train_batch_size
    LEARNING_RATE = 2e-5
    NUM_EPOCHS = 3
    MAX_LENGTH = 512  # token limit for both prompt truncation and training sequences
    GRADIENT_ACCUMULATION_STEPS = 16 # batches accumulated per optimizer step
    # NOTE(review): with TRAIN_SAMPLE_SIZE = 200 the run appears to have far fewer
    # than 100 optimizer steps, so warmup may never finish — confirm this is intended.
    WARMUP_STEPS = 100

    # STaR specific
    STAR_ITERATIONS = 1
    # NOTE(review): not read by the code visible here; generate_response has its
    # own max_new_tokens default of 400.
    MAX_NEW_TOKENS = 400
    TEMPERATURE = 0.7  # sampling temperature used by generate_response

    # Testing
    TRAIN_SAMPLE_SIZE = 200  # number of training examples kept; falsy keeps the full split
    TEST_SAMPLE_SIZE = 50  # number of test examples evaluated per method

    # Output directories
    # NOTE(review): these paths are relative (no leading "/") and contain "//";
    # presumably "/home/slathi/..." was meant — confirm before relying on them.
    OUTPUT_DIR = "home/slathi//star_output"
    VANILLA_SFT_DIR = "home/slathi//vanilla_sft_output"
    STAR_DIR = "home/slathi//star_iterations"
    DATA_DIR = "home/slathi//star_data"
config = Config()

# Make sure every output location exists before any cell writes to it
for _out_dir in (config.OUTPUT_DIR, config.VANILLA_SFT_DIR, config.STAR_DIR, config.DATA_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# Echo the settings that matter most for interpreting a run
print(
    "Configuration:",
    f"  Model: {config.MODEL_NAME}",
    f"  Dataset: {config.DATASET_NAME}",
    f"  STaR Iterations: {config.STAR_ITERATIONS}",
    f"  Training Epochs: {config.NUM_EPOCHS}",
    f"  Train Sample Size: {config.TRAIN_SAMPLE_SIZE}",
    f"  Test Sample Size: {config.TEST_SAMPLE_SIZE}",
    sep="\n",
)

## Prompts

### Rationale Generation Prompt (without hint):
```
Question: {question}
Let's think step by step to solve this problem.
```

### Rationale Generation with Hint (rationalization):
```
Question: {question}
The answer is {answer}. Let's think step by step to explain how we get this answer.
```

### Zero-Shot CoT Prompt:
```
Question: {question}
Let's think step by step.
```

In [None]:
# Prompt templates (filled in with str.format)

# Plain rationale generation: the model sees only the question.
PROMPT_WITHOUT_HINT = (
    "Question: {question}\n"
    "Let's think step by step to solve this problem."
)

# Rationalization: the known answer is supplied as a hint.
PROMPT_WITH_HINT = (
    "Question: {question}\n"
    "The answer is {answer}. Let's think step by step to explain how we get this answer."
)

# Baseline zero-shot chain-of-thought prompt.
ZERO_SHOT_COT_PROMPT = (
    "Question: {question}\n"
    "Let's think step by step."
)

In [None]:
# Log in to Hugging Face
# (presumably required because the Llama checkpoint is access-gated — confirm)
from huggingface_hub import notebook_login

print("Logging in to Hugging Face...")
# Interactive login for notebook sessions
notebook_login()

In [None]:
# Load GSM8K dataset
print("Loading GSM8K dataset...")
dataset = load_dataset(config.DATASET_NAME, config.DATASET_CONFIG)
train_dataset = dataset['train']
test_dataset = dataset['test']

# Optionally shrink the training split. The size is clamped to the split
# length so an oversized TRAIN_SAMPLE_SIZE keeps the whole split instead of
# failing on out-of-range indices (the evaluators clamp the same way).
if config.TRAIN_SAMPLE_SIZE:
    train_dataset = train_dataset.select(
        range(min(config.TRAIN_SAMPLE_SIZE, len(train_dataset)))
    )

# The test split is reported at full size here; the evaluation functions
# apply TEST_SAMPLE_SIZE themselves.
print(f"Train size: {len(train_dataset)}")
print(f"Test size: {len(test_dataset)}")
print(f"\nExample from train set:")
print(f"Question: {train_dataset[0]['question']}")
print(f"Answer: {train_dataset[0]['answer']}")

In [None]:
# Load tokenizer and model
print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
# Reuse the end-of-sequence token for padding and pad on the right
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load model with 8-bit quantization
# NOTE(review): passing load_in_8bit directly is presumably deprecated in recent
# transformers releases in favour of quantization_config — confirm against the
# installed version.
model = AutoModelForCausalLM.from_pretrained(
    config.MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    load_in_8bit=True
)

print("Model loaded successfully!")
print(f"Device: {model.device}")

In [None]:
# Utility Functions
def extract_answer(text: str) -> str:
    """Pull the final numeric answer out of a GSM8K solution or model response.

    Lookup order:
      1. The first number after the last ``####`` marker (GSM8K reference format).
      2. The last match of an "answer is ..." style phrase.
      3. The last dollar amount.
      4. A number at the very end of the text (trailing periods/whitespace allowed).

    Thousands separators are removed first, so ``"1,250"`` is read as ``"1250"``
    rather than ``"1"``. A sentence-ending period is not treated as part of the
    number.

    Args:
        text: Reference answer or generated response.

    Returns:
        The number as a string (e.g. ``"35"`` or ``"12.5"``), or ``""`` when
        nothing could be extracted.
    """
    # Remove commas used as thousands separators: a comma between digits that
    # is followed by exactly three digits. Lists such as "3,4" stay untouched.
    text = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', text)

    # Integer or decimal; the fractional part requires at least one digit so
    # the period that ends a sentence is never captured.
    number = r'-?\d+(?:\.\d+)?'

    # GSM8K format: answer is after ####
    if '####' in text:
        tail = text.split('####')[-1].strip()
        numbers = re.findall(number, tail)
        if numbers:
            return numbers[0]

    # Fall back to common phrasings, most specific first
    patterns = [
        r'(?:the answer is|answer:|final answer is|answer =)\s*\$?\s*(' + number + r')',
        r'\$\s*(' + number + r')',
        r'(' + number + r')[\s.]*$',
    ]

    lowered = text.lower()
    for pattern in patterns:
        matches = re.findall(pattern, lowered)
        if matches:
            # The last occurrence is taken as the final answer
            return matches[-1]

    return ""

def check_answer_correctness(predicted: str, ground_truth: str) -> bool:
    """Tell whether a response and a reference solution share the same final answer.

    Both texts go through ``extract_answer``. The extracted values are compared
    as floats, with an exact string comparison as fallback when either value
    cannot be parsed as a float.

    Args:
        predicted: Model response.
        ground_truth: Reference solution.

    Returns:
        True when both answers were extracted and match, otherwise False.
    """
    pred_value = extract_answer(predicted)
    gold_value = extract_answer(ground_truth)

    # A missing answer on either side can never count as correct
    if not (pred_value and gold_value):
        return False

    try:
        numerically_equal = float(pred_value) == float(gold_value)
    except ValueError:
        return pred_value == gold_value
    return numerically_equal

def generate_response(model, tokenizer, prompt: str, max_new_tokens: int = 400) -> str:
    """Sample a continuation of ``prompt`` and return only the newly generated text.

    The prompt is truncated to ``config.MAX_LENGTH`` tokens and moved to the
    model's device. Generation samples with ``config.TEMPERATURE``.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation, without special tokens and stripped of
        surrounding whitespace.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=config.MAX_LENGTH)
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    prompt_length = encoded['input_ids'].shape[1]

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            temperature=config.TEMPERATURE,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # The output sequence starts with the prompt tokens; keep only what follows
    continuation = generated[0][prompt_length:]
    decoded = tokenizer.decode(continuation, skip_special_tokens=True)
    return decoded.strip()

# Smoke-test the helpers on a GSM8K-style string ("####" precedes the final answer)
test_text = "So the total is 25 + 10 = 35. #### 35"
print(f"Test extract_answer: '{extract_answer(test_text)}'")
print(f"Test correctness check: {check_answer_correctness(test_text, '#### 35')}")

In [None]:
def evaluate_zero_shot_cot(model, tokenizer, test_data, num_samples=None):
    """Measure zero-shot chain-of-thought accuracy and save per-example results.

    Each question is formatted with ``ZERO_SHOT_COT_PROMPT``, answered with
    ``generate_response`` and scored with ``check_answer_correctness``. The
    results are written to ``zero_shot_cot_results.json`` in
    ``config.OUTPUT_DIR``.

    Args:
        model: Causal LM used for generation.
        tokenizer: Tokenizer matching ``model``.
        test_data: Dataset whose examples carry ``'question'`` and ``'answer'``;
            must support ``select`` when ``num_samples`` is given.
        num_samples: Evaluate only the first ``num_samples`` examples
            (clamped to the dataset size); falsy means all of them.

    Returns:
        Accuracy in percent. ``0.0`` when there is nothing to evaluate.
    """
    print("\n" + "="*50)
    print("Evaluating Zero-Shot CoT...")
    print("="*50)

    if num_samples:
        test_data = test_data.select(range(min(num_samples, len(test_data))))

    correct = 0
    total = len(test_data)
    results = []

    for i, example in enumerate(tqdm(test_data, desc="Zero-Shot CoT")):
        question = example['question']
        ground_truth = example['answer']

        prompt = ZERO_SHOT_COT_PROMPT.format(question=question)
        response = generate_response(model, tokenizer, prompt)

        is_correct = check_answer_correctness(response, ground_truth)
        if is_correct:
            correct += 1

        results.append({
            'question': question,
            'ground_truth': ground_truth,
            'response': response,
            'correct': is_correct
        })

        if (i + 1) % 50 == 0:
            print(f"Progress: {i+1}/{total}, Accuracy: {correct/(i+1)*100:.2f}%")

    # Guard against an empty evaluation set, which would divide by zero
    accuracy = correct / total * 100 if total else 0.0
    print(f"\nZero-Shot CoT Accuracy: {accuracy:.2f}% ({correct}/{total})")

    # Save results
    results_path = os.path.join(config.OUTPUT_DIR, 'zero_shot_cot_results.json')
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump({'accuracy': accuracy, 'correct': correct, 'total': total, 'results': results}, f, indent=2)

    return accuracy

In [None]:
# Data Preparation for Vanilla SFT
def prepare_vanilla_sft_data(train_data):
    """Turn GSM8K training examples into prompt/completion records for plain SFT.

    The completion is the dataset's own ``'answer'`` field (reasoning, ``####``
    and the final answer). The records are also written to
    ``vanilla_sft_data.json`` in ``config.DATA_DIR``.

    Args:
        train_data: Iterable of examples with ``'question'`` and ``'answer'``.

    Returns:
        List of dicts with ``'prompt'``, ``'completion'`` and ``'text'``
        (prompt followed directly by completion).
    """
    print("\nPreparing Vanilla SFT dataset...")

    def _to_record(example):
        # One training record per dataset example
        prompt = f"Question: {example['question']}\nLet's think step by step to solve this problem.\n"
        completion = example['answer']
        return {
            'prompt': prompt,
            'completion': completion,
            'text': prompt + completion
        }

    sft_data = [
        _to_record(example)
        for example in tqdm(train_data, desc="Processing training data")
    ]

    print(f"Prepared {len(sft_data)} training examples for Vanilla SFT")

    # Persist the prepared records for inspection
    out_path = os.path.join(config.DATA_DIR, 'vanilla_sft_data.json')
    with open(out_path, 'w') as f:
        json.dump(sft_data, f, indent=2)

    return sft_data

In [None]:
# STaR Data Generation
def generate_star_dataset(model, tokenizer, train_data, iteration: int = 0):
    """Build one iteration's STaR training set from the model's own rationales.

    For every example the model first answers without a hint. A rationale that
    reaches the correct answer is kept. Otherwise the model is re-prompted with
    the correct answer as a hint (rationalization). That second rationale is
    kept only if it also reaches the correct answer, and it is stored under the
    hint-free prompt so training never sees the hint. Examples that fail both
    attempts are dropped, following the filtering step of the STaR procedure.

    The dataset and statistics are also written to
    ``star_data_iter_{iteration}.json`` in ``config.DATA_DIR``.

    Args:
        model: Causal LM used to generate rationales.
        tokenizer: Tokenizer matching ``model``.
        train_data: Sized iterable of examples with ``'question'`` and ``'answer'``.
        iteration: Iteration index, used for logging and the output file name.

    Returns:
        ``(star_data, stats)``. ``star_data`` is a list of dicts with
        ``'prompt'``, ``'completion'``, ``'text'`` and ``'source'``. ``stats``
        counts ``'correct_first_attempt'``, ``'correct_with_hint'``,
        ``'rationalization_failed'`` and ``'total'``.
    """
    print(f"\n" + "="*50)
    print(f"Generating STaR Dataset - Iteration {iteration}")
    print("="*50)

    star_data = []
    stats = {
        'correct_first_attempt': 0,
        'correct_with_hint': 0,
        'rationalization_failed': 0,
        'total': len(train_data)
    }

    for i, example in enumerate(tqdm(train_data, desc=f"STaR Iteration {iteration}")):
        question = example['question']
        ground_truth = example['answer']
        gt_answer = extract_answer(ground_truth)

        # Step 1: Generate rationale without hint
        prompt_no_hint = PROMPT_WITHOUT_HINT.format(question=question)
        rationale = generate_response(model, tokenizer, prompt_no_hint)
        source = 'correct_first_attempt'

        if not check_answer_correctness(rationale, ground_truth):
            # Step 2: Answer is wrong, generate with hint (rationalization)
            prompt_with_hint = PROMPT_WITH_HINT.format(question=question, answer=gt_answer)
            rationale = generate_response(model, tokenizer, prompt_with_hint)
            source = 'rationalization'

            # A rationalization that still misses the answer would teach
            # wrong reasoning, so it is discarded rather than trained on.
            if not check_answer_correctness(rationale, ground_truth):
                stats['rationalization_failed'] += 1
                rationale = None

        if rationale is not None:
            # CRITICAL: always store the hint-free prompt, so a rationalized
            # example is trained as if the model produced it without the hint
            star_data.append({
                'prompt': prompt_no_hint,
                'completion': rationale,
                'text': prompt_no_hint + '\n' + rationale,
                'source': source
            })
            if source == 'rationalization':
                stats['correct_with_hint'] += 1
            else:
                stats['correct_first_attempt'] += 1

        if (i + 1) % 100 == 0:
            print(f"Progress: {i+1}/{len(train_data)}")
            print(f"  Correct first: {stats['correct_first_attempt']}, With hint: {stats['correct_with_hint']}")
            print(f"  Discarded: {stats['rationalization_failed']}")

    print(f"\nSTaR Dataset Generation Complete!")
    print(f"Total examples: {len(star_data)}")
    print(f"Statistics: {stats}")

    # Save dataset
    save_path = os.path.join(config.DATA_DIR, f'star_data_iter_{iteration}.json')
    with open(save_path, 'w') as f:
        json.dump({'data': star_data, 'stats': stats}, f, indent=2)

    return star_data, stats

In [None]:
# Training
def prepare_dataset_for_training(data_list, tokenizer):
    """Tokenize training records into fixed-length causal-LM examples.

    Each record's ``'text'`` gets the tokenizer's EOS token appended, is
    truncated/padded to ``config.MAX_LENGTH``, and receives ``labels`` equal to
    ``input_ids`` with every padding position replaced by ``-100`` so padding
    does not contribute to the loss.

    Args:
        data_list: List of dicts; only the ``'text'`` field is read.
        tokenizer: Tokenizer used for encoding. Expected to return an
            ``attention_mask`` alongside ``input_ids``.

    Returns:
        A ``datasets.Dataset`` holding the tokenizer outputs plus ``labels``.
    """
    # Appending EOS gives the model one real end-of-sequence target. It is
    # needed because padding (which may share the EOS id) is masked out below.
    eos_text = tokenizer.eos_token or ""

    def tokenize_function(examples):
        # Tokenize the full text (prompt + completion + EOS)
        model_inputs = tokenizer(
            [text + eos_text for text in examples['text']],
            max_length=config.MAX_LENGTH,
            truncation=True,
            padding='max_length'
        )

        # Labels mirror input_ids, except padding positions, which get -100.
        # The attention mask is used (not the pad id) because the pad token
        # may be the same as EOS.
        model_inputs['labels'] = [
            [token if keep else -100 for token, keep in zip(ids, mask)]
            for ids, mask in zip(model_inputs['input_ids'], model_inputs['attention_mask'])
        ]

        return model_inputs

    # Convert to dataset format
    dataset = Dataset.from_dict({
        'text': [item['text'] for item in data_list]
    })

    # Tokenize
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        desc="Tokenizing dataset"
    )

    return tokenized_dataset

from peft import LoraConfig, get_peft_model

def train_model(train_data, output_dir, model_name="model"):
    """Fine-tune a fresh copy of the base model with LoRA and return it.

    A new 8-bit copy of ``config.MODEL_NAME`` is loaded on every call, wrapped
    with LoRA adapters on ``q_proj``/``v_proj``, trained with the Hugging Face
    ``Trainer`` and saved to ``output_dir``. Tokenization uses the module-level
    ``tokenizer``.

    Args:
        train_data: List of dicts; only each item's ``'text'`` field is used.
        output_dir: Directory for checkpoints and the final saved model.
        model_name: Label used in log messages only.

    Returns:
        The trained PEFT model (``trainer.model``), switched to eval mode so
        later generation runs without dropout.
    """
    print(f"\n" + "="*50)
    print(f"Training {model_name}...")
    print("="*50)

    # Load fresh model from base checkpoint
    # Load model with 8-bit quantization
    model = AutoModelForCausalLM.from_pretrained(
        config.MODEL_NAME,
        torch_dtype=torch.float16,
        device_map={'': 0},
        load_in_8bit=True
    )

    # Configure PEFT (LoRA)
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Get PEFT model
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Prepare dataset
    tokenized_dataset = prepare_dataset_for_training(train_data, tokenizer)
    print(f"Training on {len(tokenized_dataset)} examples")

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=config.NUM_EPOCHS,
        per_device_train_batch_size=config.BATCH_SIZE,
        gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS,
        learning_rate=config.LEARNING_RATE,
        warmup_steps=config.WARMUP_STEPS,
        logging_steps=50,
        save_strategy="epoch",
        save_total_limit=2,
        fp16=True,
        report_to="none",
        remove_unused_columns=False,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
    )

    # Train
    print(f"Starting training...")
    trainer.train()

    # Save model
    trainer.save_model(output_dir)
    print(f"Model saved to {output_dir}")

    trained_model = trainer.model
    # Training leaves the network in train mode. Switch to eval so LoRA
    # dropout is inactive when the caller generates with this model.
    trained_model.eval()

    # The model is returned and must stay alive, so only the trainer is
    # released here (presumably along with the optimizer state it holds).
    del trainer
    torch.cuda.empty_cache()

    return trained_model

In [None]:
# Evaluation
def evaluate_model(model, tokenizer, test_data, model_name="model", num_samples=None):
    """Measure a model's accuracy on GSM8K-style data with the hint-free prompt.

    Each question is formatted with ``PROMPT_WITHOUT_HINT``, answered with
    ``generate_response`` and scored with ``check_answer_correctness``.

    Args:
        model: Causal LM to evaluate. It is switched to eval mode and left
            there.
        tokenizer: Tokenizer matching ``model``.
        test_data: Dataset whose examples carry ``'question'`` and ``'answer'``;
            must support ``select`` when ``num_samples`` is given.
        model_name: Label used in log messages only.
        num_samples: Evaluate only the first ``num_samples`` examples
            (clamped to the dataset size); falsy means all of them.

    Returns:
        ``(accuracy, results)``: accuracy in percent (``0.0`` for an empty
        set) and one dict per example with ``'question'``, ``'ground_truth'``,
        ``'response'`` and ``'correct'``.
    """
    print(f"\n" + "="*50)
    print(f"Evaluating {model_name}...")
    print("="*50)

    if num_samples:
        test_data = test_data.select(range(min(num_samples, len(test_data))))

    # A model coming straight from training is still in train mode; dropout
    # must be off for evaluation.
    model.eval()

    correct = 0
    total = len(test_data)
    results = []

    for i, example in enumerate(tqdm(test_data, desc=f"Evaluating {model_name}")):
        question = example['question']
        ground_truth = example['answer']

        prompt = PROMPT_WITHOUT_HINT.format(question=question)
        response = generate_response(model, tokenizer, prompt)

        is_correct = check_answer_correctness(response, ground_truth)
        if is_correct:
            correct += 1

        results.append({
            'question': question,
            'ground_truth': ground_truth,
            'response': response,
            'correct': is_correct
        })

        if (i + 1) % 50 == 0:
            print(f"Progress: {i+1}/{total}, Accuracy: {correct/(i+1)*100:.2f}%")

    # Guard against an empty evaluation set, which would divide by zero
    accuracy = correct / total * 100 if total else 0.0
    print(f"\n{model_name} Accuracy: {accuracy:.2f}% ({correct}/{total})")

    return accuracy, results

In [None]:
# Evaluate Zero-Shot CoT
# Baseline: the untouched base model on the first TEST_SAMPLE_SIZE test examples
zero_shot_accuracy = evaluate_zero_shot_cot(
    model,
    tokenizer,
    test_dataset,
    num_samples=config.TEST_SAMPLE_SIZE
)

In [None]:
# Prepare Vanilla SFT data
# (prompt/completion records built from the dataset's own reference solutions)
vanilla_sft_data = prepare_vanilla_sft_data(train_dataset)

# Train Vanilla SFT model
# train_model loads its own fresh copy of the base checkpoint
vanilla_model = train_model(
    vanilla_sft_data,
    config.VANILLA_SFT_DIR,
    model_name="Vanilla SFT"
)

# Evaluate Vanilla SFT
# Same test subset size as the zero-shot baseline
vanilla_accuracy, vanilla_results = evaluate_model(
    vanilla_model,
    tokenizer,
    test_dataset,
    model_name="Vanilla SFT",
    num_samples=config.TEST_SAMPLE_SIZE
)

# Save results
# 'correct' and 'total' are recomputed from the per-example results
with open(os.path.join(config.OUTPUT_DIR, 'vanilla_sft_results.json'), 'w') as f:
    json.dump({
        'accuracy': vanilla_accuracy,
        'correct': sum(1 for r in vanilla_results if r['correct']),
        'total': len(vanilla_results),
        'results': vanilla_results
    }, f, indent=2)

# Clear memory
# Drop the fine-tuned model before the STaR cell trains its own
del vanilla_model
torch.cuda.empty_cache()

In [None]:
# STaR Iterative Training
# Every iteration (1) generates rationales with the current model, (2) trains
# a new model on them — train_model always restarts from the base checkpoint —
# and (3) evaluates the new model, which then generates the next iteration's data.
star_results = []
current_model = model  # Start with base model

for iteration in range(config.STAR_ITERATIONS):
    print(f"\n\n{'='*60}")
    print(f"STaR ITERATION {iteration + 1}/{config.STAR_ITERATIONS}")
    print(f"{'='*60}\n")

    # Step 1: Generate STaR dataset using current model
    star_data, stats = generate_star_dataset(
        current_model,
        tokenizer,
        train_dataset,
        iteration=iteration
    )

    # Step 2: Train on the generated dataset
    iter_output_dir = os.path.join(config.STAR_DIR, f"iteration_{iteration}")
    new_model = train_model(
        star_data,
        iter_output_dir,
        model_name=f"STaR Iteration {iteration}"
    )

    # Clear previous model from memory. At iteration 0 the "previous" model is
    # the shared base model, which stays loaded.
    if iteration > 0:
        del current_model
        torch.cuda.empty_cache()

    current_model = new_model

    # Step 3: Evaluate
    accuracy, results = evaluate_model(
        current_model,
        tokenizer,
        test_dataset,
        model_name=f"STaR Iteration {iteration}",
        num_samples=config.TEST_SAMPLE_SIZE
    )

    star_results.append({
        'iteration': iteration,
        'accuracy': accuracy,
        'stats': stats,
        'results': results
    })

    # Save iteration results
    iteration_results_path = os.path.join(config.OUTPUT_DIR, f'star_iteration_{iteration}_results.json')
    with open(iteration_results_path, 'w', encoding='utf-8') as f:
        json.dump(star_results[-1], f, indent=2)

    # The check mark was stored as mis-decoded bytes; restored to the intended character
    print(f"\n✓ Iteration {iteration} completed!")
    print(f"  Accuracy: {accuracy:.2f}%")

In [None]:
# Final results
# One accuracy per baseline, plus a list with one accuracy per STaR iteration
final_results = {
    'Zero-Shot CoT': zero_shot_accuracy,
    'Vanilla SFT': vanilla_accuracy,
    'STaR': [result['accuracy'] for result in star_results]
}

# Plain-text summary
print("\n" + "="*60)
print("FINAL RESULTS - GSM8K Test Set")
print("="*60)
print(f"\nTest samples: {config.TEST_SAMPLE_SIZE if config.TEST_SAMPLE_SIZE else 'Full test set'}")
print("\n" + "-"*60)
print(f"Zero-Shot CoT:  {final_results['Zero-Shot CoT']:>6.2f}%")
print(f"Vanilla SFT:    {final_results['Vanilla SFT']:>6.2f}%")
for i, acc in enumerate(final_results['STaR']):
    print(f"STaR Iter {i}:   {acc:>6.2f}%")
print("-"*60)

# Create results table
# Same numbers again, laid out as a Markdown table
print("\n" + "="*60)
print("Results Table (Exact Match Accuracy)")
print("="*60)
print(f"\n| Method              | Accuracy (%) |")
print(f"|---------------------|--------------|")
print(f"| Zero-Shot CoT       | {final_results['Zero-Shot CoT']:>12.2f} |")
print(f"| Vanilla SFT         | {final_results['Vanilla SFT']:>12.2f} |")
for i, acc in enumerate(final_results['STaR']):
    print(f"| STaR Iteration {i}   | {acc:>12.2f} |")
print("\n")

# Save final results
with open(os.path.join(config.OUTPUT_DIR, 'final_results.json'), 'w') as f:
    json.dump(final_results, f, indent=2)

print(f"All results saved to {config.OUTPUT_DIR}")

In [None]:
import matplotlib.pyplot as plt

# Plot results
# Bar chart: one bar per method, with every STaR iteration as its own bar
methods = ['Zero-Shot CoT', 'Vanilla SFT'] + [f'STaR Iter {i}' for i in range(len(final_results['STaR']))]
accuracies = [final_results['Zero-Shot CoT'], final_results['Vanilla SFT']] + final_results['STaR']

plt.figure(figsize=(12, 6))
colors = ['#3498db', '#2ecc71'] + ['#e74c3c'] * len(final_results['STaR'])
bars = plt.bar(methods, accuracies, color=colors, alpha=0.8, edgecolor='black')

# Add value labels on bars
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
             f'{height:.2f}%',
             ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.xlabel('Method', fontsize=12, fontweight='bold')
plt.ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
plt.title('GSM8K Test Set Performance Comparison', fontsize=14, fontweight='bold')
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3, linestyle='--')
plt.tight_layout()
plt.savefig(os.path.join(config.OUTPUT_DIR, 'results_comparison.png'), dpi=300, bbox_inches='tight')
plt.show()



# Plot STaR progression (only meaningful with more than one iteration)
if len(final_results['STaR']) > 1:
    plt.figure(figsize=(10, 6))
    iterations = list(range(len(final_results['STaR'])))
    # The line is the per-iteration STaR accuracy; the two baselines are
    # single numbers and are drawn as horizontal reference lines.
    plt.plot(iterations, final_results['STaR'], marker='o', linewidth=2, markersize=10, color='#e74c3c', label='STaR')
    plt.axhline(y=final_results['Vanilla SFT'], color='#2ecc71', linestyle='--', label='Vanilla SFT', linewidth=2)
    plt.axhline(y=final_results['Zero-Shot CoT'], color='#3498db', linestyle='--', label='Zero-Shot CoT', linewidth=2)

    # Iterations are integers, so tick only at those positions
    plt.xticks(iterations)
    plt.xlabel('STaR Iteration', fontsize=12, fontweight='bold')
    plt.ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
    plt.title('STaR Accuracy Progression', fontsize=14, fontweight='bold')
    plt.legend(fontsize=10)
    plt.grid(alpha=0.3, linestyle='--')
    plt.tight_layout()
    plt.savefig(os.path.join(config.OUTPUT_DIR, 'star_progression.png'), dpi=300, bbox_inches='tight')
    plt.show()

