### Unsloth

In [1]:
from unsloth import FastLanguageModel
import torch
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "3,4"

# SECURITY: never store API keys in the notebook source. Any key that was
# previously committed here must be treated as leaked and revoked.
# The key is read from the environment (export OPENAI_API_KEY before starting
# the kernel); when it is missing we fall back to a hidden interactive prompt
# so the secret is never written into the .ipynb file or its outputs.
if not os.environ.get("OPENAI_API_KEY"):
    from getpass import getpass  # local import: only needed when prompting

    os.environ["OPENAI_API_KEY"] = getpass("Enter OPENAI_API_KEY (input hidden): ")

# Reference list of pre-quantized checkpoints published by Unsloth.
# Nothing in the visible cells reads this list: the model actually loaded
# below is the unquantized "unsloth/Qwen3-0.6B".
fourbit_models = [
    "unsloth/Qwen3-1.7B-unsloth-bnb-4bit", # Qwen3 1.7B
    "unsloth/Qwen3-4B-unsloth-bnb-4bit",
    "unsloth/Qwen3-8B-unsloth-bnb-4bit",
    "unsloth/Qwen3-14B-unsloth-bnb-4bit",
    "unsloth/Qwen3-32B-unsloth-bnb-4bit",

    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/Phi-4",
    "unsloth/Llama-3.1-8B",
    "unsloth/Llama-3.2-3B",
    "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit" # [NEW] We support TTS models!
] # More models at https://huggingface.co/unsloth

# FULL FINETUNING MODE - Using Qwen3-0.6B
# Both quantization flags are disabled and full_finetuning is enabled, so the
# later cells train every parameter rather than LoRA adapters.
# NOTE(review): max_seq_length here and max_length in the SFTConfig cell are
# both 2048 — keep them in sync if either is changed.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-0.6B",
    max_seq_length = 2048,   # Context length - can be longer, but uses more memory
    load_in_4bit = False,    # Must be False for full finetuning
    load_in_8bit = False,    # Must be False for full finetuning - will use BF16
    full_finetuning = True,  # ENABLED: Full finetuning instead of LoRA!
    # token = "hf_...",      # use one if using gated models
)

ü¶• Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm


ü¶• Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.11.2: Fast Qwen3 patching. Transformers: 4.57.1.
   \\   /|    NVIDIA RTX A6000. Num GPUs = 4. Max memory: 47.402 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu128. CUDA: 8.6. CUDA Toolkit: 12.8. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using bfloat16 full finetuning which cuts memory usage by 50%.
To enable float32 training, use `float32_mixed_precision = True` during FastLanguageModel.from_pretrained


**FULL FINETUNING MODE** - We are NOT using LoRA adapters. All model parameters will be updated during training.

In [2]:
# Full finetuning: there is no LoRA adapter step in this notebook.
# The model returned by the loading cell is trained directly, so this cell
# only reports how many parameters exist and how many will receive gradients.

print("Model loaded with full finetuning enabled.")

total_param_count = 0
trainable_param_count = 0
for param in model.parameters():
    param_size = param.numel()
    total_param_count += param_size
    if param.requires_grad:
        trainable_param_count += param_size

print(f"Total parameters: {total_param_count:,}")
print(f"Trainable parameters: {trainable_param_count:,}")

Model loaded with full finetuning enabled.
Total parameters: 596,049,920
Trainable parameters: 596,049,920


<a name="Data"></a>
### Data Prep - Question Decomposition Dataset

We use a custom dataset of question decompositions for training. The dataset contains:
- Original multi-hop questions
- Decomposed single-hop questions with retrieval flags
- Each example is formatted as a user-assistant conversation

In [3]:
import json
from datasets import Dataset

# Read the question-decomposition training file. The JSON document is a
# mapping; only its values (one record per question) are used downstream.
dataset_path = "/home/yigit/codebase/gsw-memory/playground/question_decomp_local/q_decomp_training_5_large.json"

print(f"Loading dataset from: {dataset_path}")
with open(dataset_path, mode='r', encoding='utf-8') as dataset_file:
    decomposition_results = json.load(dataset_file)

print(f"Loaded {len(decomposition_results)} question decompositions")

# Flatten the mapping into a plain list of records
decomposition_list = [record for record in decomposition_results.values()]

# Show the schema and the first record as a sanity check
first_example = decomposition_list[0]
print("\nExample structure:")
print(f"Keys: {first_example.keys()}")
print("\nFirst example:")
print(f"Question ID: {first_example['question_id']}")
print(f"Original Question: {first_example['original_question']}")
print(f"Decomposed Questions: {first_example['decomposed_questions']}")

Loading dataset from: /home/yigit/codebase/gsw-memory/playground/question_decomp_local/q_decomp_training_5_large.json
Loaded 2527 question decompositions

Example structure:
Keys: dict_keys(['question_id', 'original_question', 'decomposed_questions'])

First example:
Question ID: 2hop__42543_20093
Original Question: What year did the writer of Crazy Little Thing Called Love die?
Decomposed Questions: [{'question': 'Who wrote the song "Crazy Little Thing Called Love"?', 'requires_retrieval': True}, {'question': 'In what year did <ENTITY_Q1> die?', 'requires_retrieval': True}]


### Convert to Chat Format

Now we'll convert each example into chat format for training

In [4]:
def create_chat_messages(example):
    """
    Turn one decomposition record into a two-turn chat sample.

    The user turn is the fixed decomposition instruction prompt with the
    record's original question substituted at the end; the assistant turn is
    the reference decomposition serialised as JSON under a "questions" key.

    Args:
        example: Mapping with 'original_question' (str) and
            'decomposed_questions' (JSON-serialisable list) keys.

    Returns:
        Dict with a single 'messages' key holding
        [user message, assistant message], each a {'role', 'content'} dict.
    """
    original_question = example['original_question']
    reference_decomposition = example['decomposed_questions']

    # Instruction prompt shown to the model (kept in sync with question_decomp_lora_ft.py).
    # NOTE(review): the worked examples inside the prompt write the flag as
    # Python-style True/False (and the schema as the string "true/false"),
    # while the training target below is produced by json.dumps and therefore
    # uses JSON true/false — confirm this mismatch is intended.
    prompt_text = f"""Your task is to break down a complex multi-hop question into the most efficient sequence of single-hop, **atomic** questions.

## Your Main Goal: Build Smart Bridges, Don't Just Collect Nouns
The most critical skill is to convert complex logical clauses (like "despite," "the country where," "the year before") into a single, powerful **bridging question**. This question should use a known entity as context to find the next one. Avoid finding all the entities separately and then trying to figure out how they connect.

---
## A Simple Analogy for Efficiency

**Question:** "What is the phone number of the mother of the tallest player on the Lakers?"

** Inefficient Path:**
1.  Who are the players on the Lakers?
2.  What are all their heights?
3.  Who is the mother of the tallest player? *(This step is a logical leap)*

** Efficient Path:**
1.  Who is the tallest player on the Lakers?
2.  Who is the mother of `<ENTITY_Q1>`?
3.  What is the phone number of `<ENTITY_Q2>`?

---
## How to Decompose a Question
This process follows a logical flow from high-level analysis to the fine-tuning of your question chain.

### 1. Analyze the Query's Components
First, break down the original question into its fundamental building blocks. Identify the core **entities** (people, places, organizations), their **properties** (attributes like rank, location, date), and the **relationships** that connect them.

### 2. Construct an Atomic Chain
Next, formulate a sequence of questions where each question retrieves a single fact.
* **Isolate Comparisons:** Don't ask "who is faster?" Ask for the specific rank or time of each person involved.
* **Link with Placeholders:** Use `<ENTITY_Qn>` to pass the answer from a previous question (`Qn`) into the next one.

### 3. Optimize for Efficiency and Precision
Your final goal is the **shortest and most direct path** to the answer.
* **Embed Constraints to Build Bridges:** If a piece of information is only a filter (like a date or location), embed it as a constraint in the next question instead of asking for it directly.
**Important note for bridges:** There can be no `<ENTITY_Qn>` in the first question if the nth question DOES NOT require retrieval.

## Formatting
Format each decomposed question as follows:

Question: [the question text]
Requires retrieval: [true/false]

And provide the response in the following JSON format:
{{
  "questions": [
    {{
      "question": "the decomposed question text",
      "requires_retrieval": "true/false"
    }}
  ]
}}

Examples:

Input: "What is the birth year of the spouse of the director of Casablanca?"
Output:
{{
    "questions": [
        {{
            "question": "Who directed Casablanca?",
            "requires_retrieval": True
        }},
        {{
            "question": "Who was <ENTITY_Q1>'s spouse?",
            "requires_retrieval": True
        }},
        {{
            "question": "What is <ENTITY_Q2>'s birth year?",
            "requires_retrieval": True
        }}
    ]
}}

Input: "Which film has the director who is older, Dune or The Dark Knight?"
Output:
{{
    "questions": [
        {{
            "question": "Who directed Dune?",
            "requires_retrieval": True
        }},
        {{
            "question": "Who directed The Dark Knight?",
            "requires_retrieval": True
        }},
        {{
            "question": "Who is older, <ENTITY_Q1> or <ENTITY_Q2>?",
            "requires_retrieval": True
        }},
        {{
            "question": "Who is older, <ENTITY_Q1> or <ENTITY_Q2>?",
            "requires_retrieval": False
        }}
    ]
}}


IMPORTANT:
    AVOID over-decomposition like this:
    DON'T break "Who is John Doe?" into:
    1. Who is John Doe? ‚Üí "English"
    2. When was <ENTITY_Q1> born? ‚Üí "When was English born?"

    DO ask directly: "When was John Doe born?"

Now decompose this question:
Input: "{original_question}"
Output:
"""

    # Training target: the reference decomposition, pretty-printed as JSON
    target_json = json.dumps(
        {"questions": reference_decomposition},
        indent=4,
        ensure_ascii=False,
    )

    return {
        "messages": [
            {"role": "user", "content": prompt_text},
            {"role": "assistant", "content": target_json},
        ]
    }

print("Chat formatting function created!")

Chat formatting function created!


In [5]:
# Wrap the raw records in a HuggingFace Dataset
raw_dataset = Dataset.from_list(decomposition_list)
source_columns = raw_dataset.column_names

print("Raw dataset info:")
print(raw_dataset)
print(f"\nColumn names: {source_columns}")

# Convert every record to chat format, dropping the source columns so only
# 'messages' remains.
# NOTE: this version is built from the *unsplit* data and is only a preview;
# a later cell rebuilds `training_dataset` from the train split.
training_dataset = raw_dataset.map(
    create_chat_messages,
    remove_columns=source_columns,
    desc="Creating chat-formatted training data",
)

print("\nTraining dataset created!")
print(training_dataset)
print(f"\nColumn names: {training_dataset.column_names}")

# Peek at the first conversation: user prompt, then assistant target
preview_messages = training_dataset[0]['messages']
print("\nFirst example messages (user prompt first 500 chars):")
print(preview_messages[0]['content'][:500] + "...")
print("\nAssistant response (first 300 chars):")
print(preview_messages[1]['content'][:300] + "...")

Raw dataset info:
Dataset({
    features: ['question_id', 'original_question', 'decomposed_questions'],
    num_rows: 2527
})

Column names: ['question_id', 'original_question', 'decomposed_questions']


Creating chat-formatted training data: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2527/2527 [00:00<00:00, 7806.22 examples/s]


Training dataset created!
Dataset({
    features: ['messages'],
    num_rows: 2527
})

Column names: ['messages']

First example messages (user prompt first 500 chars):
Your task is to break down a complex multi-hop question into the most efficient sequence of single-hop, **atomic** questions.

## Your Main Goal: Build Smart Bridges, Don't Just Collect Nouns
The most critical skill is to convert complex logical clauses (like "despite," "the country where," "the year before") into a single, powerful **bridging question**. This question should use a known entity as context to find the next one. Avoid finding all the entities separately and then trying to figure o...

Assistant response (first 300 chars):
{
    "questions": [
        {
            "question": "Who wrote the song \"Crazy Little Thing Called Love\"?",
            "requires_retrieval": true
        },
        {
            "question": "In what year did <ENTITY_Q1> die?",
            "requires_retrieval": true
        }
    




### Train/Eval Split (2027/500)

We'll split the dataset stratified by question type to maintain the distribution across train and eval sets.

In [6]:
# Split dataset: 2027 train / 500 eval (stratified by question type)
from collections import Counter
from datasets import ClassLabel

# Question IDs look like "<type>__<ids>", e.g. "2hop__42543_20093" -> "2hop"
def get_question_type(question_id):
    """Return the part of the ID before the first "__"; used as the stratification label."""
    question_type, _, _ = question_id.partition("__")
    return question_type

# Tag every record with its question type so the split can be stratified on it
decomposition_list_with_types = [
    {**item, "question_type": get_question_type(item["question_id"])}
    for item in decomposition_list
]

# Count distribution
type_counts = Counter(item["question_type"] for item in decomposition_list_with_types)
print("Question type distribution:")
for qtype, count in sorted(type_counts.items()):
    print(f"  {qtype}: {count}")

# Create dataset with question types
full_dataset = Dataset.from_list(decomposition_list_with_types)

# stratify_by_column requires a ClassLabel column, so cast the string labels
unique_types = sorted(type_counts)
full_dataset = full_dataset.cast_column(
    "question_type",
    ClassLabel(names=unique_types)
)

# Hold out a fixed number of examples for evaluation and train on the rest.
# The train size is derived from the dataset length instead of being
# hardcoded, so the split keeps working if the source file grows or shrinks.
EVAL_SIZE = 500
if len(full_dataset) <= EVAL_SIZE:
    raise ValueError(
        f"Dataset has {len(full_dataset)} examples; need more than {EVAL_SIZE} "
        "to hold out an evaluation set."
    )

split_dataset = full_dataset.train_test_split(
    test_size=EVAL_SIZE,
    train_size=len(full_dataset) - EVAL_SIZE,
    stratify_by_column="question_type",
    seed=42
)

train_raw = split_dataset["train"]
eval_raw = split_dataset["test"]

# Every example must land in exactly one split
assert len(train_raw) + len(eval_raw) == len(full_dataset)

print(f"\n‚úì Split complete!")
print(f"  Training set: {len(train_raw)} examples")
print(f"  Evaluation set: {len(eval_raw)} examples")

# Verify stratification: after the cast the column holds integer class ids,
# which are mapped back to their names through unique_types
train_types = Counter(unique_types[idx] for idx in train_raw["question_type"])
eval_types = Counter(unique_types[idx] for idx in eval_raw["question_type"])

print(f"\nTraining set distribution:")
for qtype in unique_types:
    print(f"  {qtype}: {train_types[qtype]}")

print(f"\nEvaluation set distribution:")
for qtype in unique_types:
    print(f"  {qtype}: {eval_types[qtype]}")

Question type distribution:
  2hop: 500
  3hop1: 500
  3hop2: 500
  4hop1: 500
  4hop2: 127
  4hop3: 400


Casting the dataset: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2527/2527 [00:00<00:00, 284911.86 examples/s]



‚úì Split complete!
  Training set: 2027 examples
  Evaluation set: 500 examples

Training set distribution:
  2hop: 401
  3hop1: 401
  3hop2: 401
  4hop1: 401
  4hop2: 102
  4hop3: 321

Evaluation set distribution:
  2hop: 99
  3hop1: 99
  3hop2: 99
  4hop1: 99
  4hop2: 25
  4hop3: 79


In [7]:
# Convert the raw train and eval splits into chat-formatted datasets
def _as_chat_dataset(split, description):
    """Map one raw split through create_chat_messages, keeping only the new 'messages' column."""
    return split.map(
        create_chat_messages,
        remove_columns=split.column_names,
        desc=description,
    )


training_dataset = _as_chat_dataset(train_raw, "Creating chat-formatted training data")
eval_dataset = _as_chat_dataset(eval_raw, "Creating chat-formatted evaluation data")

print(f"‚úì Chat-formatted datasets ready:")
print(f"  Training: {len(training_dataset)} examples")
print(f"  Evaluation: {len(eval_dataset)} examples")

Creating chat-formatted training data: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2027/2027 [00:00<00:00, 9392.56 examples/s]
Creating chat-formatted evaluation data: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 500/500 [00:00<00:00, 11230.93 examples/s]

‚úì Chat-formatted datasets ready:
  Training: 2027 examples
  Evaluation: 500 examples





### Apply Chat Template

Convert messages to text format using the tokenizer's chat template

In [8]:
# Smoke-test the chat template on one training conversation.
# SFTTrainer does this conversion itself during training; this is only a preview.
# NOTE(review): no chat_template argument is passed here, so the tokenizer's
# default template is used, whereas formatting_function (defined further down)
# passes a custom template — the text shown here may differ from what is trained on.

print("Testing chat template formatting...")
first_conversation = training_dataset[0]["messages"]
sample_formatted = tokenizer.apply_chat_template(
    first_conversation,
    tokenize=False,
    add_generation_prompt=False,
)

head_preview = sample_formatted[:500]
tail_preview = sample_formatted[-200:]

print("\nFormatted sample (first 500 chars):")
print(head_preview)
print("\n... [content truncated] ...")
print("\nLast 200 chars:")
print(tail_preview)

print(f"\n‚úì Chat template formatting works!")
print(f"Training dataset ready with {len(training_dataset)} examples")

Testing chat template formatting...

Formatted sample (first 500 chars):
<|im_start|>user
Your task is to break down a complex multi-hop question into the most efficient sequence of single-hop, **atomic** questions.

## Your Main Goal: Build Smart Bridges, Don't Just Collect Nouns
The most critical skill is to convert complex logical clauses (like "despite," "the country where," "the year before") into a single, powerful **bridging question**. This question should use a known entity as context to find the next one. Avoid finding all the entities separately and then t

... [content truncated] ...

Last 200 chars:
           "requires_retrieval": true
        },
        {
            "question": "How many undergraduates does <ENTITY_Q3> have?",
            "requires_retrieval": true
        }
    ]
}<|im_end|>


‚úì Chat template formatting works!
Training dataset ready with 2027 examples


### LLM Judge for Evaluation

We'll use GPT-5 (with reasoning) to evaluate question decomposition quality every 3 training steps.

In [9]:
# LLM Judge Implementation
import json
import os
from openai import OpenAI
from typing import List, Dict, Any, Optional
from datetime import datetime

class QuestionDecompositionJudge:
    """
    LLM-based judge that scores question decompositions.

    Each decomposition is sent to an OpenAI chat model together with a rubric
    (atomicity, bridge building, efficiency, correctness, retrieval flag) and
    the model's JSON verdict is returned. Failures never raise: they are
    reported as a record containing an "error" key so a long evaluation loop
    keeps running.
    """

    # Rubric criteria, matching the keys requested in the judge prompt.
    CRITERIA = ("atomicity", "bridge_building", "efficiency", "correctness", "retrieval_flag")

    def __init__(self, model="gpt-5", temperature=0.0):
        """
        Args:
            model: Name of the OpenAI chat model used as judge.
            temperature: Sampling temperature sent with each request, or None
                to omit the parameter entirely.
        """
        self.model = model
        self.temperature = temperature
        # Flipped to False the first time the API rejects the temperature
        # parameter, so later requests skip the doomed first attempt.
        self._send_temperature = temperature is not None
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def create_judge_prompt(self, original_question: str, decomposed_questions: List[Dict]) -> str:
        """Build the rubric prompt for one original question and its decomposition."""
        decomp_str = json.dumps({"questions": decomposed_questions}, indent=2)

        return f"""You are an expert evaluator of question decomposition quality for multi-hop QA systems.

**Original Question:** {original_question}

**Generated Decomposition:**
{decomp_str}

**Evaluation Criteria** (Score each decomposed question 1-5):

1. **Atomicity**: Is this a single-hop question retrieving only one piece of information?
2. **Bridge Building**: Proper use of <ENTITY_Qn> placeholders to reference previous answers?
3. **Efficiency**: Most direct path, avoiding over-decomposition?
4. **Correctness**: Logically sound and contributes to answering the original question?
5. **Retrieval Flag**: Is requires_retrieval set correctly?

**Response Format (JSON):**
{{
  "evaluations": [
    {{
      "question_index": 0,
      "question_text": "...",
      "scores": {{
        "atomicity": 5,
        "bridge_building": 5,
        "efficiency": 5,
        "correctness": 5,
        "retrieval_flag": 5
      }},
      "average": 5.0,
      "feedback": "Brief explanation"
    }}
  ],
  "overall_average": 5.0,
  "overall_feedback": "Brief summary"
}}"""

    def _request_completion(self, prompt: str):
        """
        Send the judge prompt and return the raw API response.

        The request is first sent with the configured temperature. If the API
        rejects it with an error that mentions "temperature", the request is
        retried once without that parameter and the parameter is dropped for
        all later calls. NOTE(review): some reasoning models are documented to
        accept only the default temperature — confirm against the OpenAI API
        reference for the configured model.
        """
        request = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": "You are an expert evaluator of question decomposition quality. Provide detailed, fair evaluations."},
                {"role": "user", "content": prompt},
            ],
            "response_format": {"type": "json_object"},
        }

        # getattr keeps instances created without __init__ (e.g. in tests) working
        if getattr(self, "_send_temperature", self.temperature is not None):
            try:
                return self.client.chat.completions.create(temperature=self.temperature, **request)
            except Exception as exc:
                if "temperature" not in str(exc).lower():
                    raise
                self._send_temperature = False

        return self.client.chat.completions.create(**request)

    def judge_decomposition(self, original_question: str, decomposed_questions: List[Dict]) -> Dict:
        """
        Evaluate a single decomposition.

        Returns:
            The judge's parsed JSON verdict with 'original_question' and
            'decomposed_questions' added. On any failure, a record with an
            'error' message, the same two fields, an empty 'evaluations'
            list and 'overall_average' of 0.0.
        """
        prompt = self.create_judge_prompt(original_question, decomposed_questions)

        try:
            response = self._request_completion(prompt)

            result = json.loads(response.choices[0].message.content)
            if not isinstance(result, dict):
                raise ValueError(f"Judge returned {type(result).__name__}, expected a JSON object")
            result["original_question"] = original_question
            result["decomposed_questions"] = decomposed_questions

            return result
        except Exception as e:
            # Deliberately broad: one failed judge call must not abort training.
            return {
                "error": str(e),
                "original_question": original_question,
                "decomposed_questions": decomposed_questions,
                "evaluations": [],
                "overall_average": 0.0
            }

    @staticmethod
    def _as_number(value) -> Optional[float]:
        """Return value as a float, or None when it is not a usable number."""
        if isinstance(value, bool):
            return None
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, str):
            # Judges occasionally emit scores as strings such as "4"
            try:
                return float(value.strip())
            except ValueError:
                return None
        return None

    def compute_aggregate_metrics(self, evaluation_results: List[Dict]) -> Dict[str, float]:
        """
        Compute aggregate metrics across evaluations.

        Records containing an "error" key count towards 'error_rate' only.
        Missing or non-numeric scores in the judge output are skipped rather
        than raising.

        Returns:
            {"error_rate": 1.0} when there is no valid record; otherwise a
            dict with 'overall_average', 'num_evaluated', 'error_rate' and one
            '<criterion>_avg' entry per criterion that received a score.
        """
        valid_results = [r for r in evaluation_results if "error" not in r]

        if not valid_results:
            return {"error_rate": 1.0}

        overall_scores = []
        per_criterion = {criterion: [] for criterion in self.CRITERIA}

        for result in valid_results:
            overall = self._as_number(result.get("overall_average"))
            if overall is not None:
                overall_scores.append(overall)

            for eval_item in result.get("evaluations") or []:
                if not isinstance(eval_item, dict):
                    continue
                scores = eval_item.get("scores")
                if not isinstance(scores, dict):
                    continue
                for criterion, collected in per_criterion.items():
                    score = self._as_number(scores.get(criterion))
                    if score is not None:
                        collected.append(score)

        metrics = {
            "overall_average": sum(overall_scores) / len(overall_scores) if overall_scores else 0,
            "num_evaluated": len(valid_results),
            "error_rate": (len(evaluation_results) - len(valid_results)) / len(evaluation_results)
        }

        for criterion, values in per_criterion.items():
            if values:
                metrics[f"{criterion}_avg"] = sum(values) / len(values)

        return metrics

print("‚úì QuestionDecompositionJudge class defined")

‚úì QuestionDecompositionJudge class defined


In [10]:
# Evaluation Callback - Runs during evaluation
from transformers import TrainerCallback
import random

class LLMJudgeEvaluationCallback(TrainerCallback):
    """
    Callback that evaluates model outputs using LLM judge during evaluation.

    On every Trainer evaluation it generates a decomposition for a subset of
    the evaluation questions with the model being trained, has the judge
    score each parseable generation, prints and logs the aggregate metrics,
    and writes the detailed results to a JSON file in `logs_dir`.
    """

    def __init__(
        self,
        eval_dataset,
        judge: QuestionDecompositionJudge,
        tokenizer,
        chat_template: str,
        num_samples: int = 200,
        logs_dir: str = "./judge_logs",
        seed: Optional[int] = 42,
    ):
        """
        Args:
            eval_dataset: Indexable dataset whose items expose an
                'original_question' field (i.e. the raw split, not the
                chat-formatted one).
            judge: Judge used to score each generated decomposition.
            tokenizer: Tokenizer used to render the prompt and decode output.
            chat_template: Chat template string passed to apply_chat_template.
            num_samples: Maximum number of evaluation examples scored per run.
            logs_dir: Directory for the per-step JSON result files (created
                if missing).
            seed: Seed for choosing the evaluated subset. With a fixed seed
                the same examples are scored at every evaluation, so scores
                are comparable across steps. Pass None to draw a fresh random
                subset each time.
        """
        self.eval_dataset = eval_dataset
        self.judge = judge
        self.tokenizer = tokenizer
        self.chat_template = chat_template
        self.num_samples = num_samples
        self.logs_dir = logs_dir
        self.seed = seed

        # Create logs directory
        os.makedirs(logs_dir, exist_ok=True)

    def _select_eval_indices(self) -> List[int]:
        """Pick which evaluation examples to score in this run."""
        dataset_size = len(self.eval_dataset)
        sample_size = min(self.num_samples, dataset_size)
        # A dedicated generator avoids disturbing (or depending on) the global RNG state
        rng = random.Random(self.seed) if self.seed is not None else random
        return rng.sample(range(dataset_size), sample_size)

    @staticmethod
    def _build_user_prompt(original_question: str) -> str:
        """
        Return the exact user prompt used for training.

        The prompt is taken from create_chat_messages instead of being
        duplicated here, so the evaluation prompt cannot drift from the
        training prompt.
        """
        example = {"original_question": original_question, "decomposed_questions": []}
        return create_chat_messages(example)["messages"][0]["content"]

    def _generate(self, model, user_prompt: str) -> str:
        """Generate the model's reply to user_prompt and return only the newly generated text."""
        messages = [{"role": "user", "content": user_prompt}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            chat_template=self.chat_template
        )

        inputs = self.tokenizer(text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            # Use BF16 autocast to match training environment's dtype handling
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=512,
                    temperature=0.1,
                    top_p=0.9,
                    do_sample=True,
                    use_cache=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )

        # Drop the prompt tokens so only the completion is decoded
        prompt_length = inputs.input_ids.shape[1]
        return self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    @staticmethod
    def _parse_decomposition(generated_text: str) -> List[Dict]:
        """
        Extract the "questions" list from the generated text.

        The span from the first "{" to the last "}" is parsed as JSON.
        Returns an empty list when no such span exists, the span is not valid
        JSON, or it does not hold a JSON object with a list under "questions".
        """
        json_start = generated_text.find("{")
        json_end = generated_text.rfind("}") + 1
        if json_start == -1 or json_end <= json_start:
            return []

        try:
            parsed = json.loads(generated_text[json_start:json_end])
        except ValueError:  # json.JSONDecodeError is a ValueError
            return []

        if not isinstance(parsed, dict):
            return []
        questions = parsed.get("questions", [])
        return questions if isinstance(questions, list) else []

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Called when trainer runs evaluation."""
        current_step = state.global_step

        print(f"\n{'='*60}")
        print(f"üîç Running LLM Judge Evaluation at step {current_step}")
        print(f"{'='*60}")

        # Get the actual training model from kwargs
        model = kwargs.get('model')
        if model is None:
            print("‚ö†Ô∏è  Model not found in kwargs, skipping evaluation")
            return

        # Sample examples from eval set
        eval_samples = [self.eval_dataset[i] for i in self._select_eval_indices()]

        # Model is already in eval mode by Trainer, no need to set it
        evaluation_results = []
        num_parsed = 0  # generations that yielded a non-empty decomposition

        for i, sample in enumerate(eval_samples):
            # eval_raw has: question_id, original_question, decomposed_questions, question_type
            original_question = sample["original_question"]

            generated_text = self._generate(model, self._build_user_prompt(original_question))
            decomposed_questions = self._parse_decomposition(generated_text)

            # Only parseable generations are sent to the judge; the rest are
            # accounted for through the parse success rate below.
            if decomposed_questions:
                num_parsed += 1
                result = self.judge.judge_decomposition(original_question, decomposed_questions)
                evaluation_results.append(result)

            if (i + 1) % 5 == 0:
                print(f"  Evaluated {i + 1}/{len(eval_samples)} samples...")
                print(f"  {num_parsed} were generated successfully.")

        # Compute aggregate metrics
        if evaluation_results:
            judge_metrics = self.judge.compute_aggregate_metrics(evaluation_results)
            # Judge scores cover parseable outputs only, so report how many
            # generations were parseable alongside them.
            judge_metrics["parse_success_rate"] = num_parsed / len(eval_samples)

            print(f"\nüìä Evaluation Results (Step {current_step}):")
            print(f"  Overall Average: {judge_metrics.get('overall_average', 0):.2f}/5.0")
            print(f"  Atomicity: {judge_metrics.get('atomicity_avg', 0):.2f}/5.0")
            print(f"  Bridge Building: {judge_metrics.get('bridge_building_avg', 0):.2f}/5.0")
            print(f"  Efficiency: {judge_metrics.get('efficiency_avg', 0):.2f}/5.0")
            print(f"  Correctness: {judge_metrics.get('correctness_avg', 0):.2f}/5.0")
            print(f"  Retrieval Flag: {judge_metrics.get('retrieval_flag_avg', 0):.2f}/5.0")
            print(f"  Evaluated: {judge_metrics.get('num_evaluated', 0)} samples")
            print(f"  Parsed generations: {num_parsed}/{len(eval_samples)}")

            # Log to WandB if available (best effort: logging must not break training)
            if state.is_world_process_zero:
                try:
                    import wandb
                except ImportError:
                    wandb = None

                if wandb is not None:
                    try:
                        # Don't specify step - let WandB auto-increment
                        wandb.log({
                            "judge/overall_average": judge_metrics.get("overall_average", 0),
                            "judge/atomicity": judge_metrics.get("atomicity_avg", 0),
                            "judge/bridge_building": judge_metrics.get("bridge_building_avg", 0),
                            "judge/efficiency": judge_metrics.get("efficiency_avg", 0),
                            "judge/correctness": judge_metrics.get("correctness_avg", 0),
                            "judge/retrieval_flag": judge_metrics.get("retrieval_flag_avg", 0),
                            "judge/num_evaluated": judge_metrics.get("num_evaluated", 0),
                            "judge/parse_success_rate": judge_metrics["parse_success_rate"],
                        })
                    except Exception as exc:
                        print(f"  WandB logging skipped: {exc}")

            # Save detailed results to JSON
            log_file = os.path.join(self.logs_dir, f"judge_eval_step_{current_step}.json")
            with open(log_file, "w", encoding="utf-8") as f:
                json.dump({
                    "step": current_step,
                    "metrics": judge_metrics,
                    "detailed_results": evaluation_results
                }, f, indent=2)

            print(f"  üíæ Detailed results saved to: {log_file}")
        else:
            print(f"\n‚ö†Ô∏è  No valid evaluations generated at step {current_step}")

        print(f"{'='*60}\n")

In [11]:
from trl import SFTTrainer, SFTConfig
from typing import Union, List, Dict, Any

# Load the new chat template.
# A context manager closes the file handle deterministically, and the explicit
# encoding keeps the template text independent of the machine's locale.
with open(
    '/home/yigit/codebase/gsw-memory/playground/question_decomp_local/qwen3_nonthinking.jinja',
    encoding='utf-8',
) as template_file:
    new_chat_template = template_file.read()

def formatting_function(example):
    """
    Render chat messages into the list of training strings the trainer expects.

    `example["messages"]` may be either:
    - a batch: a non-empty list whose first element is itself a list of messages, or
    - a single conversation: anything else (typically one list of message dicts).

    Each conversation is rendered with the notebook-level `tokenizer` and
    `new_chat_template`, without tokenizing and without a generation prompt.

    Returns:
        list[str]: the rendered conversations, with non-string and blank/whitespace-only
        results dropped. A single conversation therefore yields at most one element.
    """
    def _render(conversation):
        # One conversation -> one formatted string (tokenization happens later).
        return tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
            chat_template=new_chat_template,
        )

    msgs = example["messages"]

    # A batch looks like [[{role, content}, ...], [{role, content}, ...], ...].
    looks_batched = isinstance(msgs, list) and bool(msgs) and isinstance(msgs[0], list)
    conversations = msgs if looks_batched else [msgs]

    rendered = [_render(conversation) for conversation in conversations]

    # Keep only real, non-blank strings so the result is a flat list[str].
    return [text for text in rendered if isinstance(text, str) and text.strip()]


# Whether to attach the LLM-judge callback to the trainer. It is off by default, which
# matches the previous state of this cell (its `callbacks = [judge_callback]` line was
# commented out), so turning it on is an explicit opt-in. The summary printed at the end
# of the cell reports this setting instead of implying the judge always runs.
USE_LLM_JUDGE = False

# Create SFTConfig (formatting_func goes to SFTTrainer, not here)
training_args = SFTConfig(
    output_dir="./qwen3_0.6b_question_decomp_full_ft",  # Required parameter
    per_device_train_batch_size = 8,  # Can increase for 0.6B model
    gradient_accumulation_steps = 8,   # Use GA to mimic batch size! (8 x 8 = 64 per device)
    warmup_steps = 5,
    # num_train_epochs = 1,            # Set this for 1 full training run.
    max_steps = 200,                    # For quick testing
    max_length=2048,
    learning_rate = 5e-6,              # Lower LR for full finetuning (was 2e-4 for LoRA)
    logging_steps = 1,
    optim = "adamw_8bit",
    weight_decay = 0.01,               # Increased weight decay for full finetuning
    lr_scheduler_type = "linear",
    seed = 3407,
    bf16 = True,                       # Use BF16 precision for full finetuning
    # fp16 = True,                         # Use FP16 precision for full finetuning on T4
    gradient_checkpointing = True,     # Enable to save memory
    report_to = "wandb",                # Use TrackIO/WandB etc
    # Evaluation configuration - standard eval (and the judge callback, when attached)
    # runs every `eval_steps` training steps
    eval_strategy="steps",             # Evaluate at regular step intervals
    eval_steps=20,                      # Run evaluation every 20 training steps
)

# Initialize LLM Judge
judge = QuestionDecompositionJudge(
    model="gpt-4o",
    temperature=0.0
)

# Create evaluation callback (uses on_evaluate, triggered by eval_strategy)
# Pass eval_raw instead of eval_dataset to get original_question field
judge_callback = LLMJudgeEvaluationCallback(
    eval_dataset=eval_raw,              # Use raw dataset with original_question field
    judge=judge,
    tokenizer=tokenizer,
    chat_template=new_chat_template,
    num_samples=200,         # Evaluate up to 200 samples each time
    logs_dir="./judge_logs"
)

# Create trainer with eval dataset; the judge callback is attached only when enabled
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = training_dataset,  # Question-decomposition training split
    eval_dataset = eval_dataset,        # Held-out split used for the validation loss
    formatting_func = formatting_function,
    args = training_args,
    callbacks = [judge_callback] if USE_LLM_JUDGE else None,
)

print("✓ Trainer configured with:")
print(f"  - Training examples: {len(training_dataset)}")
print(f"  - Evaluation examples: {len(eval_dataset)}")
print(f"  - Evaluation strategy: every {training_args.eval_steps} steps")
if USE_LLM_JUDGE:
    print(f"  - LLM judge samples per evaluation: {judge_callback.num_samples}")
else:
    print("  - LLM judge: disabled (set USE_LLM_JUDGE = True to attach the callback)")

Unsloth: Tokenizing ["text"] (num_proc=64): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2027/2027 [00:12<00:00, 156.44 examples/s]
Unsloth: Tokenizing ["text"] (num_proc=64): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 500/500 [00:13<00:00, 37.48 examples/s]

‚úì Trainer configured with:
  - Training examples: 2027
  - Evaluation examples: 500
  - Evaluation strategy: every 20 steps
  - LLM judge samples per evaluation: 200





<a name="Train"></a>
### Train the model with Full Finetuning

Now let's train our model using **full finetuning** (all parameters are updated, not just LoRA adapters).

We run 200 steps (`max_steps=200`) for quick testing; for a full pass over the data, set `num_train_epochs=1` and remove the `max_steps` argument.

**Important notes for full finetuning:**
- Learning rate is much lower (5e-6 vs 2e-4 for LoRA)
- Training will take longer as all parameters are updated
- Memory usage is higher, but gradient checkpointing helps
- The model quality can be better than LoRA since all parameters are trained

In [12]:
# from trl import SFTTrainer, SFTConfig
# from typing import Union, List, Dict, Any

# # Load the new chat template
# new_chat_template = open('/home/yigit/codebase/gsw-memory/playground/question_decomp_local/qwen3_nonthinking.jinja').read()

# def formatting_function(example):
#     """
#     Unsloth requires List[str].
#     - If 'example["messages"]' is a list of message-lists (batched), return one
#       formatted string per item.
#     - If it's a single list of messages (single sample), return a 1-element list.
#     """
#     msgs = example["messages"]

#     # Batched: msgs is like [ [ {role, content}, ... ], [ {role, content}, ... ], ... ]
#     if isinstance(msgs, list) and msgs and isinstance(msgs[0], list):
#         texts = [
#             tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=False, chat_template=new_chat_template)
#             for m in msgs
#         ]
#     else:
#         # Single example
#         texts = [
#             tokenizer.apply_chat_template(
#                 msgs, tokenize=False, add_generation_prompt=False, chat_template=new_chat_template
#             )
#         ]

#     # Ensure flat list[str] with no empties
#     return [t for t in texts if isinstance(t, str) and t.strip()]


# # Create SFTConfig (formatting_func goes to SFTTrainer, not here)
# training_args = SFTConfig(
#     output_dir="./qwen3_0.6b_question_decomp_full_ft",  # Required parameter
#     per_device_train_batch_size = 2,  # Can increase for 0.6B model
#     gradient_accumulation_steps = 4,   # Use GA to mimic batch size!
#     warmup_steps = 5,
#     # num_train_epochs = 1,            # Set this for 1 full training run.
#     max_steps = 200,                    # For quick testing
#     max_length=2048,
#     learning_rate = 5e-6,              # Lower LR for full finetuning (was 2e-4 for LoRA)
#     logging_steps = 1,
#     optim = "adamw_8bit",
#     weight_decay = 0.01,               # Increased weight decay for full finetuning
#     lr_scheduler_type = "linear",
#     seed = 3407,
#     # bf16 = True,                       # Use BF16 precision for full finetuning
#     gradient_checkpointing = True,     # Enable to save memory
#     report_to = "wandb",                # Use TrackIO/WandB etc
# )

# # Create trainer - formatting_func is passed to SFTTrainer, not SFTConfig
# # Note: We removed the custom data_collator that was causing the "no loss" error
# # SFTTrainer will use its default DataCollatorForLanguageModeling which properly handles labels
# trainer = SFTTrainer(
#     model = model,
#     tokenizer = tokenizer,
#     train_dataset = training_dataset,  # Using the question decomposition dataset
#     eval_dataset = None, # Can set up evaluation!
#     formatting_func = formatting_function,  # Pass to SFTTrainer in newer trl versions
#     args = training_args,
# )

In [13]:
# @title Show current memory stats
# Baseline snapshot of GPU 0 taken before training. The final-stats cell subtracts
# start_gpu_memory from the post-training peak and divides by max_memory.
gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / 2**30, 3)                    # bytes -> GB (2**30)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 2**30, 3)    # bytes -> GB (2**30)
print(
    f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.",
    f"{start_gpu_memory} GB of memory reserved.",
    sep="\n",
)

GPU = NVIDIA RTX A6000. Max memory = 47.402 GB.
1.135 GB of memory reserved.


In [14]:
# Free GPU memory before training. Order matters: run Python's garbage collector first
# so unreachable objects are dropped, then ask PyTorch to release its unused cached
# CUDA memory.
import gc
gc.collect()
torch.cuda.empty_cache()

Let's train the model! To resume a training run, set `trainer.train(resume_from_checkpoint = True)`

In [15]:
# Run the training loop. The returned object's `.metrics` dict (e.g. 'train_runtime')
# is read by the final memory/time stats cell.
trainer_stats = trainer.train()

The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 2,027 | Num Epochs = 7 | Total steps = 200
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 596,049,920 of 596,049,920 (100.00% trained)
[34m[1mwandb[0m: Currently logged in as: [33mmyigitturali[0m ([33mmyigitturali-UCLA[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Weave is installed but not imported. Add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,Validation Loss
20,1.5468,1.517792
40,1.0358,1.012272
60,0.5141,0.490573
80,0.241,0.233057
100,0.1746,0.168657
120,0.1485,0.150111
140,0.1459,0.145187
160,0.1446,0.14333
180,0.1588,0.142778
200,0.1377,0.142741


Unsloth: Not an error, but Qwen3Model does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient


In [16]:
# @title Show final memory and time stats
# Compare the peak reserved GPU memory after training with the baseline captured
# before trainer.train(). The "_lora" names are misnomers here: this notebook does a
# full finetune, so they simply mean "attributable to training".
used_memory = round(torch.cuda.max_memory_reserved() / 2**30, 3)  # bytes -> GB (2**30)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

train_runtime = trainer_stats.metrics['train_runtime']
report_lines = [
    f"{train_runtime} seconds used for training.",
    f"{round(train_runtime/60, 2)} minutes used for training.",
    f"Peak reserved memory = {used_memory} GB.",
    f"Peak reserved memory for training = {used_memory_for_lora} GB.",
    f"Peak reserved memory % of max memory = {used_percentage} %.",
    f"Peak reserved memory for training % of max memory = {lora_percentage} %.",
]
print("\n".join(report_lines))

2109.1378 seconds used for training.
35.15 minutes used for training.
Peak reserved memory = 10.734 GB.
Peak reserved memory for training = 9.599 GB.
Peak reserved memory % of max memory = 22.645 %.
Peak reserved memory for training % of max memory = 20.25 %.


<a name="Inference"></a>
### Inference - Question Decomposition

Let's test the model on question decomposition! The model should break down complex multi-hop questions into atomic single-hop questions.

The model will output JSON format with:
- `questions`: List of decomposed sub-questions
- `requires_retrieval`: Boolean flag for each question indicating if it needs retrieval

For structured JSON output, use lower temperature (0.0-0.3) for consistency.

In [17]:
import torch
from transformers import TextStreamer

# Load the new chat template for inference.
# A context manager closes the file handle deterministically (the bare open().read()
# left that to the garbage collector), and the encoding is pinned so the template text
# does not depend on the machine's locale.
# NOTE(review): this is an absolute, machine-specific path - consider making it configurable.
with open('/home/yigit/codebase/gsw-memory/playground/question_decomp_local/qwen3_nonthinking.jinja', encoding="utf-8") as template_file:
    new_chat_template = template_file.read()

# Example 1: Test question decomposition with a 2-hop question
# NOTE(review): the JSON schema line in this cell's prompt spells the flag "True/False",
# while the second example cell's prompt uses "true/false" - confirm which spelling
# matches the training data before relying on either.
test_question = "Where did Lothair II's mother die?"

# Create the prompt using the same format as training
user_prompt = f"""Your task is to break down a complex multi-hop question into the most efficient sequence of single-hop, **atomic** questions.

## Your Main Goal: Build Smart Bridges, Don't Just Collect Nouns
The most critical skill is to convert complex logical clauses (like "despite," "the country where," "the year before") into a single, powerful **bridging question**. This question should use a known entity as context to find the next one. Avoid finding all the entities separately and then trying to figure out how they connect.

---
## A Simple Analogy for Efficiency

**Question:** "What is the phone number of the mother of the tallest player on the Lakers?"

** Inefficient Path:**
1.  Who are the players on the Lakers?
2.  What are all their heights?
3.  Who is the mother of the tallest player? *(This step is a logical leap)*

** Efficient Path:**
1.  Who is the tallest player on the Lakers?
2.  Who is the mother of `<ENTITY_Q1>`?
3.  What is the phone number of `<ENTITY_Q2>`?

---
## How to Decompose a Question
This process follows a logical flow from high-level analysis to the fine-tuning of your question chain.

### 1. Analyze the Query's Components
First, break down the original question into its fundamental building blocks. Identify the core **entities** (people, places, organizations), their **properties** (attributes like rank, location, date), and the **relationships** that connect them.

### 2. Construct an Atomic Chain
Next, formulate a sequence of questions where each question retrieves a single fact.
* **Isolate Comparisons:** Don't ask "who is faster?" Ask for the specific rank or time of each person involved.
* **Link with Placeholders:** Use `<ENTITY_Qn>` to pass the answer from a previous question (`Qn`) into the next one.

### 3. Optimize for Efficiency and Precision
Your final goal is the **shortest and most direct path** to the answer.
* **Embed Constraints to Build Bridges:** If a piece of information is only a filter (like a date or location), embed it as a constraint in the next question instead of asking for it directly.
**Important note for bridges:** There can be no `<ENTITY_Qn>` in the first question if the nth question DOES NOT require retrieval.

## Formatting
Format each decomposed question as follows:

Question: [the question text]
Requires retrieval: [true/false]

And provide the response in the following JSON format:
{{
  "questions": [
    {{
      "question": "the decomposed question text",
      "requires_retrieval": "True/False"
    }}
  ]
}}

Examples:

Input: "What is the birth year of the spouse of the director of Casablanca?"
Output:
{{
    "questions": [
        {{
            "question": "Who directed Casablanca?",
            "requires_retrieval": True
        }},
        {{
            "question": "Who was <ENTITY_Q1>'s spouse?",
            "requires_retrieval": True
        }},
        {{
            "question": "What is <ENTITY_Q2>'s birth year?",
            "requires_retrieval": True
        }}
    ]
}}

Input: "Which film has the director who is older, Dune or The Dark Knight?"
Output:
{{
    "questions": [
        {{
            "question": "Who directed Dune?",
            "requires_retrieval": True
        }},
        {{
            "question": "Who directed The Dark Knight?",
            "requires_retrieval": True
        }},
        {{
            "question": "Who is older, <ENTITY_Q1> or <ENTITY_Q2>?",
            "requires_retrieval": True
        }},
        {{
            "question": "Who is older, <ENTITY_Q1> or <ENTITY_Q2>?",
            "requires_retrieval": False
        }}
    ]
}}


IMPORTANT:
    AVOID over-decomposition like this:
    DON'T break "Who is John Doe?" into:
    1. Who is John Doe? ‚Üí "English"
    2. When was <ENTITY_Q1> born? ‚Üí "When was English born?"

    DO ask directly: "When was John Doe born?"

Now decompose this question:
Input: "{test_question}"
Output:
"""

messages = [
    {"role": "user", "content": user_prompt}
]

# Render the single-turn conversation with the custom template. add_generation_prompt=True
# asks the template to end with the opening of the assistant turn so the model answers.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    chat_template=new_chat_template # Pass the new template here
)


# Tokenize and move the tensors to the device the model lives on, rather than a
# hardcoded "cuda". No dtype cast is applied to the inputs.
inputs = tokenizer(text, return_tensors="pt").to(model.device)

print(f"Question: {test_question}\n")
print("Decomposition:")

# Put the model in evaluation mode. This does not change its dtype: it keeps the
# precision it was loaded and trained with.
model.eval()

with torch.no_grad():
    # Keep autocast disabled so generation runs in the model's own dtype instead of
    # having individual ops re-cast by an autocast context.
    with torch.amp.autocast('cuda', enabled=False):
        _ = model.generate(
            **inputs,
            max_new_tokens=512,  # Enough for JSON output
            do_sample=True,      # Make sampling explicit: temperature/top_p only apply when sampling
            temperature=0.1,     # Low temperature for consistent JSON
            top_p=0.9,
            streamer=TextStreamer(tokenizer, skip_prompt=True),  # Stream tokens to stdout, omitting the prompt
        )


Question: Where did Lothair II's mother die?

Decomposition:
{
    "questions": [
        {
            "question": "Who was Lothair II's mother?",
            "requires_retrieval": true
        },
        {
            "question": "Where did <ENTITY_Q1> die?",
            "requires_retrieval": true
        }
    ]
}<|im_end|>


In [18]:
import torch
from transformers import TextStreamer

# Load the new chat template for inference.
# A context manager closes the file handle deterministically (the bare open().read()
# left that to the garbage collector), and the encoding is pinned so the template text
# does not depend on the machine's locale.
# NOTE(review): this is an absolute, machine-specific path - consider making it configurable.
with open('/home/yigit/codebase/gsw-memory/playground/question_decomp_local/qwen3_nonthinking.jinja', encoding="utf-8") as template_file:
    new_chat_template = template_file.read()

# Example 2: Test with a more complex 3-hop question
# NOTE(review): the JSON schema line in this cell's prompt spells the flag "true/false",
# while the first example cell's prompt uses "True/False" - confirm which spelling
# matches the training data before relying on either.
test_question_2 = "What is the birth year of the spouse of the director of Casablanca?"

user_prompt_2 = f"""Your task is to break down a complex multi-hop question into the most efficient sequence of single-hop, **atomic** questions.

## Your Main Goal: Build Smart Bridges, Don't Just Collect Nouns
The most critical skill is to convert complex logical clauses (like "despite," "the country where," "the year before") into a single, powerful **bridging question**. This question should use a known entity as context to find the next one. Avoid finding all the entities separately and then trying to figure out how they connect.

---
## A Simple Analogy for Efficiency

**Question:** "What is the phone number of the mother of the tallest player on the Lakers?"

** Inefficient Path:**
1.  Who are the players on the Lakers?
2.  What are all their heights?
3.  Who is the mother of the tallest player? *(This step is a logical leap)*

** Efficient Path:**
1.  Who is the tallest player on the Lakers?
2.  Who is the mother of `<ENTITY_Q1>`?
3.  What is the phone number of `<ENTITY_Q2>`?

---
## How to Decompose a Question
This process follows a logical flow from high-level analysis to the fine-tuning of your question chain.

### 1. Analyze the Query's Components
First, break down the original question into its fundamental building blocks. Identify the core **entities** (people, places, organizations), their **properties** (attributes like rank, location, date), and the **relationships** that connect them.

### 2. Construct an Atomic Chain
Next, formulate a sequence of questions where each question retrieves a single fact.
* **Isolate Comparisons:** Don't ask "who is faster?" Ask for the specific rank or time of each person involved.
* **Link with Placeholders:** Use `<ENTITY_Qn>` to pass the answer from a previous question (`Qn`) into the next one.

### 3. Optimize for Efficiency and Precision
Your final goal is the **shortest and most direct path** to the answer.
* **Embed Constraints to Build Bridges:** If a piece of information is only a filter (like a date or location), embed it as a constraint in the next question instead of asking for it directly.
**Important note for bridges:** There can be no `<ENTITY_Qn>` in the first question if the nth question DOES NOT require retrieval.

## Formatting
Format each decomposed question as follows:

Question: [the question text]
Requires retrieval: [true/false]

And provide the response in the following JSON format:
{{
  "questions": [
    {{
      "question": "the decomposed question text",
      "requires_retrieval": "true/false"
    }}
  ]
}}

Examples:

Input: "What is the birth year of the spouse of the director of Casablanca?"
Output:
{{
    "questions": [
        {{
            "question": "Who directed Casablanca?",
            "requires_retrieval": True
        }},
        {{
            "question": "Who was <ENTITY_Q1>'s spouse?",
            "requires_retrieval": True
        }},
        {{
            "question": "What is <ENTITY_Q2>'s birth year?",
            "requires_retrieval": True
        }}
    ]
}}

Input: "Which film has the director who is older, Dune or The Dark Knight?"
Output:
{{
    "questions": [
        {{
            "question": "Who directed Dune?",
            "requires_retrieval": True
        }},
        {{
            "question": "Who directed The Dark Knight?",
            "requires_retrieval": True
        }},
        {{
            "question": "Who is older, <ENTITY_Q1> or <ENTITY_Q2>?",
            "requires_retrieval": True
        }},
        {{
            "question": "Who is older, <ENTITY_Q1> or <ENTITY_Q2>?",
            "requires_retrieval": False
        }}
    ]
}}


IMPORTANT:
    AVOID over-decomposition like this:
    DON'T break "Who is John Doe?" into:
    1. Who is John Doe? ‚Üí "English"
    2. When was <ENTITY_Q1> born? ‚Üí "When was English born?"

    DO ask directly: "When was John Doe born?"

Now decompose this question:
Input: "{test_question_2}"
Output:
"""

messages_2 = [
    {"role": "user", "content": user_prompt_2}
]

# Render the single-turn conversation with the custom template. add_generation_prompt=True
# asks the template to end with the opening of the assistant turn so the model answers.
text_2 = tokenizer.apply_chat_template(
    messages_2,
    tokenize=False,
    add_generation_prompt=True,
    chat_template=new_chat_template # Pass the new template here
)

# Tokenize and move the tensors to the device the model lives on, rather than a
# hardcoded "cuda". No dtype cast is applied to the inputs.
inputs_2 = tokenizer(text_2, return_tensors="pt").to(model.device)

print(f"\nQuestion: {test_question_2}\n")
print("Decomposition:")

# Put the model in evaluation mode. This does not change its dtype: it keeps the
# precision it was loaded and trained with.
model.eval()

with torch.no_grad():
    # Keep autocast disabled so generation runs in the model's own dtype instead of
    # having individual ops re-cast by an autocast context.
    with torch.amp.autocast('cuda', enabled=False):
        _ = model.generate(
            **inputs_2,
            max_new_tokens=512,  # Enough for JSON output
            do_sample=True,      # Make sampling explicit: temperature/top_p only apply when sampling
            temperature=0.1,     # Low temperature for consistent JSON
            top_p=0.9,
            streamer=TextStreamer(tokenizer, skip_prompt=True),  # Stream tokens to stdout, omitting the prompt
        )



Question: What is the birth year of the spouse of the director of Casablanca?

Decomposition:
{
    "questions": [
        {
            "question": "Who directed Casablanca?",
            "requires_retrieval": true
        },
        {
            "question": "Who was <ENTITY_Q1>'s spouse?",
            "requires_retrieval": true
        },
        {
            "question": "What is <ENTITY_Q2>'s birth year?",
            "requires_retrieval": true
        }
    ]
}<|im_end|>


### Saving the fully finetuned model
For full finetuning, we save the **entire model** (not just adapters like LoRA). You can use Huggingface's `push_to_hub` for online save or `save_pretrained` for local save.

**[NOTE]** This saves the complete model with all trained parameters.

In [19]:
import os

# SECURITY: never paste access tokens into the notebook. The Hugging Face token that was
# hardcoded in this cell was stored in plain text and must be treated as leaked: revoke
# it at https://huggingface.co/settings/tokens and create a new one.
# Supply credentials out of band instead: export HF_TOKEN in the environment that starts
# Jupyter, or log in once with the Hugging Face CLI. When HF_TOKEN is unset or empty,
# token=None is passed so the Hub client can use its stored login.
hf_token = os.environ.get("HF_TOKEN") or None
hf_repo_id = "yigitturali/qwen3-0.6b-gsw-q-decomp-finetuned-large"

# Upload the model weights and the tokenizer files to the same Hub repository.
model.push_to_hub(hf_repo_id, token = hf_token)
tokenizer.push_to_hub(hf_repo_id, token = hf_token)

Processing Files (1 / 1): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1.19GB / 1.19GB, 83.8MB/s  
New Data Upload: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1.19GB / 1.19GB, 83.8MB/s  


Saved model to https://huggingface.co/yigitturali/qwen3-0.6b-gsw-q-decomp-finetuned-large


Processing Files (1 / 1): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11.4MB / 11.4MB,  0.00B/s  
New Data Upload: |          |  0.00B /  0.00B,  0.00B/s  
No files have been modified since last commit. Skipping to prevent empty commit.


To reload the fully finetuned model we just saved and use it for inference, change `False` to `True` in the cell below:

In [20]:
# Reload the finetuned checkpoint for inference. Disabled by default; change `False`
# to `True` to run it.
if False:
    from unsloth import FastLanguageModel
    model, tokenizer = FastLanguageModel.from_pretrained(
        # Hub repository that the push_to_hub cell above uploads to. The previous value,
        # "full_finetuned_model", named a local directory this notebook never creates.
        model_name = "yigitturali/qwen3-0.6b-gsw-q-decomp-finetuned-large",
        max_seq_length = 2048,
        load_in_4bit = True,  # Can use 4bit for inference to save memory
    )

### Saving to float16 for VLLM

For full finetuning, you can save the model in different formats. Select `merged_16bit` for float16 or `merged_4bit` for int4. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.

**Note:** For full finetuning, the model is already "merged" (no adapters to merge), so these methods will save the complete model in the specified format.

In [21]:
# Optional export recipes. Every branch below is disabled; change the relevant
# `if False` to `if True` to run it. "model" is the local output directory, and
# "hf/model" plus the empty `token` strings are placeholders to replace before pushing.
# Prefer reading the token from an environment variable over pasting it into the cell.
# Save full model to 16bit
if False:
    model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: # Pushing to HF Hub
    model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# Save full model to 4bit
if False:
    model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: # Pushing to HF Hub
    model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Save full model in standard format
# (model and tokenizer are written/pushed separately to the same location)
if False:
    model.save_pretrained("model")
    tokenizer.save_pretrained("model")
if False: # Pushing to HF Hub
    model.push_to_hub("hf/model", token = "")
    tokenizer.push_to_hub("hf/model", token = "")


### GGUF / llama.cpp Conversion
To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.

For full finetuning, the entire model will be converted to GGUF format.

Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):
* `q8_0` - Fast conversion. High resource use, but generally acceptable.
* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.

[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)

In [None]:
# Optional GGUF export recipes. Every branch below is disabled; change the relevant
# `if False` to `if True` to run it. "hf/model" and the empty `token` strings are
# placeholders. Prefer reading the token from an environment variable over pasting it.
# Save to 8bit Q8_0
# (no quantization_method is passed here, so the library default is used)
if False:
    model.save_pretrained_gguf("model", tokenizer,)
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!
if False:
    model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# Save to 16bit GGUF
if False:
    model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False: # Pushing to HF Hub
    model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")

# Save to q4_k_m GGUF
if False:
    model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: # Pushing to HF Hub
    model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# Save to multiple GGUF options - much faster if you want multiple!
# (quantization_method is given as a list so one call produces several quantizations)
if False:
    model.push_to_hub_gguf(
        "hf/model", # Change hf to your username!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
        token = "", # Get a token at https://huggingface.co/settings/tokens
    )

: 