# Necessary Imports

In [40]:
from openai import OpenAI
import json
import os
import random
import sys
from typing import Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from tqdm import tqdm
from bespokelabs import curator

# Fix sys.path and import paths for notebook execution
sys.path.append("/home/yigit/codebase/gsw-memory/src/gsw_memory/")
from gsw_memory.memory.models import GSWStructure
from gsw_memory.memory.operator_utils import parse_gsw, GSWOperator
from gsw_memory.prompts.operator_prompts import PromptType, FactualExtractionPrompts
from openai.lib._parsing._completions import type_to_response_format_param

# Smoke-Test the Local vLLM Server

In [3]:
# Send one short chat request to the local OpenAI-compatible server to check it responds.
sampling_overrides = {
    "temperature": 0.6,
    "top_p": 0.95,
    "top_k": 20,
    "min_p": 0,
    "max_tokens": 4096,
    "repetition_penalty": 1.1,
    "presence_penalty": 0.3,
    "frequency_penalty": 0.3,
}
smoke_test_messages = [
    {"role": "user", "content": "Give me a short introduction to large language models."}
]

client = OpenAI(base_url="http://127.0.0.1:6380/v1", api_key="token-abc123")
completion = client.chat.completions.create(
    model="Qwen/Qwen3-8B",
    messages=smoke_test_messages,
    temperature=0.6,
    # Non-standard sampling knobs (top_k, min_p, repetition_penalty, ...) go
    # through extra_body so they are merged into the request JSON as-is.
    extra_body=sampling_overrides,
)

In [4]:
# Display the raw ChatCompletion object returned by the smoke-test request.
completion

ChatCompletion(id='chatcmpl-4e6fd883ca15494096f5aea88fec095e', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='\n\nLarge language models (LLMs) are advanced artificial intelligence systems designed to understand and generate human-like text based on vast amounts of training data. By analyzing patterns in language, they can perform tasks like answering questions, creating content, translating languages, and even coding. These models rely on complex neural network architectures (e.g., transformers) and require extensive computational resources to train. While they excel at mimicking human communication, challenges remain—including addressing biases in training data and ensuring ethical use. LLMs have revolutionized fields like customer service, education, and creative writing by enabling more intuitive interactions with technology.', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[], reason

# Load Musique Test and Train Data

In [41]:
# Load the MuSiQue evaluation set (one JSON array) and the full training split
# (JSON Lines: one record per line). `with` closes each file handle, and an
# explicit UTF-8 encoding keeps non-ASCII paragraph text independent of the
# machine's locale.
MUSIQUE_DATA_DIR = "/home/yigit/codebase/gsw-memory/playground_data"

with open(os.path.join(MUSIQUE_DATA_DIR, "musique.json"), encoding="utf-8") as test_file:
    test_musique = json.load(test_file)

with open(os.path.join(MUSIQUE_DATA_DIR, "musique_full_v1.0_train.jsonl"), encoding="utf-8") as train_file:
    # Skip blank lines (e.g. a trailing newline), which json.loads would reject.
    train_musique = [json.loads(line) for line in train_file if line.strip()]

In [42]:

def build_musique_corpus(records):
    """Flatten MuSiQue question records into a paragraph-level corpus.

    Args:
        records: iterable of MuSiQue records. Each record has an ``"id"`` and a
            ``"paragraphs"`` list whose items carry ``"idx"``, ``"title"`` and
            ``"paragraph_text"``.

    Returns:
        Dict keyed by ``"<record id>_<paragraph idx>"``. Each value holds:
        that key as ``global_id``; the paragraph ``title``; ``text`` (title,
        newline, paragraph text); the source record ``id``; and the paragraph
        ``idx``. A repeated key overwrites the earlier entry.
    """
    corpus = {}
    for record in records:
        for paragraph in record["paragraphs"]:
            global_id = f"{record['id']}_{paragraph['idx']}"
            corpus[global_id] = {
                "global_id": global_id,
                "title": paragraph["title"],
                "text": paragraph["title"] + "\n" + paragraph["paragraph_text"],
                "id": record["id"],
                "idx": paragraph["idx"],
            }
    return corpus


train_musique_corpus = build_musique_corpus(train_musique)
test_musique_corpus = build_musique_corpus(test_musique)
        

In [43]:
# Randomly sample N_TRAIN_SAMPLE training paragraphs and take the first
# N_TEST_SAMPLE test paragraphs in corpus order. A dedicated seeded RNG makes the
# training sample identical across kernel restarts without touching the global
# `random` state.
SAMPLE_SEED = 42
N_TRAIN_SAMPLE = 100
N_TEST_SAMPLE = 250

train_musique_corpus_sample = random.Random(SAMPLE_SEED).sample(
    list(train_musique_corpus.values()), N_TRAIN_SAMPLE
)
test_musique_corpus_sample = list(test_musique_corpus.values())[:N_TEST_SAMPLE]


# Process Documents with Qwen3-8B

In [22]:
# Extract a GSW for one training paragraph via the local server, constraining the
# output to the GSWStructure schema.
# The prompt template takes raw passage text, so pass the record's "text" field.
# Passing the whole corpus dict would embed its repr (global_id, idx, ...) in the prompt.
gsw_completions = []
SYSTEM_PROMPT = FactualExtractionPrompts.SYSTEM_PROMPT
USER_PROMPT_TEMPLATE = FactualExtractionPrompts.USER_PROMPT_TEMPLATE
USER_PROMPT = USER_PROMPT_TEMPLATE.format(
    input_text=train_musique_corpus_sample[1]["text"],
    background_context="",
)
client = OpenAI(base_url="http://127.0.0.1:6380/v1", api_key="token-abc123")
gsw_completions.append(
    client.chat.completions.parse(
        model="Qwen/Qwen3-8B",
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": USER_PROMPT},
        ],
        temperature=0.6,
        response_format=GSWStructure,
        extra_body={
            "top_p": 0.9,
            "top_k": 20,
            "min_p": 0.0,
            "max_tokens": 4096 * 3,
            "repetition_penalty": 1.1,
            "presence_penalty": 0.3,
            "frequency_penalty": 0.3,
        },
    )
)

# Test the GSW Operator on Curator

In [35]:
# Build the GSW operator against the local server. LiteLLM is the backend, with
# the "hosted_vllm/" model prefix, and the key comes from HOSTED_VLLM_API_KEY.
os.environ["HOSTED_VLLM_API_KEY"] = "token-abc123"

vllm_backend_params = {
    "base_url": "http://127.0.0.1:6380/v1",
    "request_timeout": 600.0,
    "max_concurrent_requests": 64,
    "max_requests_per_minute": 120,
    "max_tokens_per_minute": 200000,
    "seconds_to_pause_on_rate_limit": 5,
    "require_all_responses": False,
}
qwen3_generation_params = {
    "temperature": 0.6,
    "top_p": 0.95,
    "top_k": 20,
    "min_p": 0,
    "max_tokens": 4096 * 3,
    "repetition_penalty": 1.1,
    "presence_penalty": 0.3,
    "frequency_penalty": 0.3,
}

gsw_model = GSWOperator(
    model_name="hosted_vllm/Qwen/Qwen3-8B",
    backend="litellm",
    backend_params=vllm_backend_params,
    generation_params=qwen3_generation_params,
    prompt_type=PromptType.FACTUAL,
    response_format=GSWStructure,  # constrain decoding to the GSW schema
    batch=False,
)


  PydanticSerializationUnexpectedValue(Expected 6 fields but got 5: Expected `Message` - serialized value may not be as expected [field_name='choices', input_value=Message(content='\n\nHell...urther conversation.\n'), input_type=Message])
  PydanticSerializationUnexpectedValue(Expected `StreamingChoices` - serialized value may not be as expected [field_name='choices', input_value=Choices(finish_reason='st...rther conversation.\n')), input_type=Choices])
  return self.__pydantic_serializer__.to_python(


In [36]:
# Run the operator on a single training paragraph.
# To regenerate instead of reusing earlier results, set
# os.environ["CURATOR_DISABLE_CACHE"] = "1" first (presumably disables Curator's cache).
first_train_doc = train_musique_corpus_sample[0]
gsw_response = gsw_model([first_train_doc])

Generating train split: 0 examples [00:00, ? examples/s]

Output()

In [39]:
# Convert operator output rows into GSWStructure objects keyed by global_id.
# A row carries either a structured "gsw" payload (from constrained decoding) or
# a raw "graph" string for parse_gsw. If the "gsw" column is missing, fall back
# to "graph" rather than failing with KeyError. Rows that still fail are reported
# and skipped.
all_documents_data = {}
for row in gsw_response.dataset:  # TODO: confirm the row schema of gsw_response.dataset
    try:
        gsw_payload = row.get("gsw")
        if gsw_payload:
            gsw = GSWStructure(**gsw_payload)
        else:
            gsw = parse_gsw(row["graph"])
        all_documents_data[row["global_id"]] = gsw
    except Exception as e:
        print(
            f"Error parsing GSW for chunk {row.get('global_id', 'unknown')}: {e}"
        )
pred_gsws = list(all_documents_data.values())

In [None]:
# load first 100 documents from musique processed docs
import glob
musique_network_dir = '/mnt/SSD1/shreyas/SM_GSW/musique/networks_4_1_mini'
import re

def sort_natural_key(s):
    # Extract the numeric part from the path string
    match = re.search(r'doc_(\d+)', s)
    return int(match.group(1)) if match else s

musique_docs = sorted(
    glob.glob(f"{musique_network_dir}/doc_*"),
    key=sort_natural_key
)
musique_docs_0_99 = musique_docs[:100]

# go in each dir in musique_docs_0_99 and load each json file
import os

golden_gsws = []
for doc_dir in musique_docs_0_99:
    if os.path.isdir(doc_dir):
        json_files = sorted(glob.glob(os.path.join(doc_dir, "*.json")))
        for json_file in json_files:
            try:
                with open(json_file, 'r') as f:
                    doc_data = json.load(f)
                    doc_data = GSWStructure(**doc_data)
                    golden_gsws.append(doc_data)
            except Exception as e:
                print(f"Error loading {json_file}: {e}")
    else:
        print(f"Warning: {doc_dir} is not a directory, skipping.")



In [None]:
# Prompt for the LLM-as-a-judge comparison of a predicted GSW against a gold GSW.
# It is a str.format template whose placeholders are {input_text},
# {gold_gsw_json} and {pred_gsw_json}. Literal JSON braces in the schema are
# escaped as {{ }}, so the string must be .format()-ed before it is sent.
# It already contains both the "SYSTEM:" instructions and the "USER" payload,
# so the formatted string is self-contained.
# The output schema below is mirrored field-for-field by the Judge_Format
# pydantic model. Keep the two in sync when editing either.
LLM_AS_A_JUDGE_PROMPT ="""SYSTEM:
You are a strict evaluator (LLM-as-a-judge) for GSW factual extraction graphs used in multi-hop QA.
You will compare a PREDICTED GSW against a GOLD GSW for the SAME input text.

Your job:
1) Determine how well the predicted GSW matches the gold GSW in factual coverage and structural compliance.
2) Identify missing facts, hallucinated facts, malformed entities, malformed relations, and question-format violations.
3) Output ONLY valid JSON in the schema given below. No extra commentary.

Definitions:
- A "fact" is a (subject entity, verb phrase, object entity) triple expressed by a verb_phrase_node whose questions imply that relation.
- "Coverage" means the predicted GSW contains an equivalent fact to one in the gold GSW.
- "Equivalent" means same underlying meaning, allowing:
  - minor wording differences in the verb phrase/questions,
  - re-ordering,
  - synonyms that do not change meaning (e.g., "joined" vs "enrolled" if clearly same in context),
  - but NOT allowing changed entities, times, locations, roles, or added details not in text.
- IDs may differ across graphs; match by entity name + type/role + context.

Hard constraints from the extraction spec (judge these strictly):
A) No fabrication: predicted facts must be entailed by the input text.
B) Atomic entities: do not bundle multiple separable entities into one.
C) Abbreviation/alias: if text has "Full Name (ABBR)", predicted should create:
   - entity "Full Name (ABBR)" and entity "ABBR" (alias) connected via "also known as".
   - Do NOT expand abbreviations that were not expanded in text.
D) Two questions per verb phrase node: exactly 2 questions per verb_phrase_node.
E) Each question must have exactly one unknown (answer) and no pronouns ("he", "it", "they", etc.).
F) Questions must contain complete content (no dropping "that ..." clauses when required).
G) Temporal connectivity: when gold connects a date/time via a phrase like "on/in/during", predicted should also connect that temporal info somewhere relevant.
H) Answers must be entity IDs only and must exist in entity_nodes.
I) Do not merge multiple subjects/objects into one question unless gold does (and the rule permits it).

Evaluation procedure:
1) Parse both GSWs into:
   - entity inventory: (name, role(s), states)
   - fact inventory: for each verb_phrase_node, infer intended (S, P, O) from the questions:
     * If one question is "Who <P> <O> ...?" and the answer is X => (X, P, O)
     * If one question is "What/Which <O> did <S> <P> ...?" and the answer is Y => (S, P, Y)
   If inference is ambiguous, record it as "unscorable_relation" and penalize format/clarity, not factual mismatch.
2) Align entities between gold and pred by best match on:
   - exact name match > normalized match (case/diacritics) > alias match
   - role compatibility (person vs org vs date etc.)
3) Align facts:
   - A gold fact is "covered" if an equivalent pred fact exists.
   - A pred fact is "hallucinated" if it is not entailed by input text OR has no corresponding gold fact meaningfully.
4) Score:
   - Coverage (0–1): covered_gold_facts / total_gold_facts
   - Precision (0–1): correct_pred_facts / total_pred_facts (exclude unscorable if clearly malformed)
   - Format compliance (0–1): start from 1 and subtract for each violation category (see below).
   - Overall (0–100): weighted:
       overall = 45*coverage + 35*precision + 20*format_compliance
   Provide also pass/fail for "usable_for_QA" with threshold overall>=80 and no critical violations.
Critical violations (any => usable_for_QA=false even if score high):
   - fabrication (A)
   - pronouns in questions (E)
   - answers not IDs / missing IDs (H)
   - not exactly two questions per verb phrase node (D)

Format compliance penalties (examples):
- two_questions_violation: -0.10 each verb phrase node violating
- pronoun_violation: -0.15 each occurrence
- missing_that_content: -0.10 each occurrence
- entity_bundling: -0.10 each bundled entity
- alias_rule_violation: -0.10 each missed/incorrect alias case
- temporal_disconnection: -0.05 each missing required time link
- bad_answer_ids: critical (set compliance to 0 for that node and mark critical)

Output JSON schema (MUST follow exactly):
{{
  "overall_score": 0-100 number,
  "usable_for_QA": boolean,
  "subscores": {{
    "coverage": 0-1 number,
    "precision": 0-1 number,
    "format_compliance": 0-1 number
  }},
  "counts": {{
    "gold_entities": int,
    "pred_entities": int,
    "gold_facts": int,
    "pred_facts": int,
    "covered_gold_facts": int,
    "hallucinated_pred_facts": int,
    "unscorable_pred_relations": int
  }},
  "critical_violations": [
    {{
      "type": "fabrication|pronouns|bad_answer_ids|two_questions_violation",
      "message": "short explanation",
      "location": "entity_id or verb_phrase_id or question_id"
    }}
  ],
  "entity_alignment": [
    {{
      "gold_entity_id": "e#",
      "pred_entity_id": "e# or null",
      "match_type": "exact|normalized|alias|none",
      "notes": "brief"
    }}
  ],
  "missing_entities": [
    {{ "gold_entity_id": "e#", "gold_name": "...", "reason": "missing|role_mismatch|bundled" }}
  ],
  "extra_entities": [
    {{ "pred_entity_id": "e#", "pred_name": "...", "reason": "hallucinated|unnecessary|bundled" }}
  ],
  "fact_comparison": {{
    "missing_facts": [
      {{
        "gold_fact": {{ "subject": "...", "predicate": "...", "object": "..."}},
        "gold_verb_phrase_id": "v#",
        "reason": "missing|temporal_missing|wrong_object|wrong_subject"
      }}
    ],
    "covered_facts": [
      {{
        "gold_fact": {{ "subject": "...", "predicate": "...", "object": "..."}},
        "pred_fact": {{ "subject": "...", "predicate": "...", "object": "..."}},
        "gold_verb_phrase_id": "v#",
        "pred_verb_phrase_id": "v#",
        "match_notes": "brief"
      }}
    ],
    "hallucinated_facts": [
      {{
        "pred_fact": {{ "subject": "...", "predicate": "...", "object": "..."}},
        "pred_verb_phrase_id": "v#",
        "reason": "not_in_gold|not_entailed_by_text|over_specific"
      }}
    ]
  }},
  "format_issues": [
    {{
      "type": "two_questions_violation|pronoun_violation|missing_that_content|alias_rule_violation|entity_bundling|temporal_disconnection|question_unknown_count|other",
      "message": "short explanation",
      "location": "v# or q# or e#"
    }}
  ],
  "improvement_suggestions": [
    "bullet-like string suggestions, concrete and minimal"
  ]
}}

USER (template you will receive):
<input_text>
{input_text}
</input_text>

<gold_gsw_json>
{gold_gsw_json}
</gold_gsw_json>

<pred_gsw_json>
{pred_gsw_json}
</pred_gsw_json>

Now perform the evaluation and output ONLY the JSON object.
"""


# Structured-output schema for the LLM-as-a-judge call. These models mirror the
# JSON layout spelled out in LLM_AS_A_JUDGE_PROMPT.
# Pydantic normally copies class docstrings and Field descriptions into the
# generated JSON schema. Editing that text therefore presumably changes what the
# judge model sees.
# The three component scores; the prompt combines them as
# overall = 45*coverage + 35*precision + 20*format_compliance.
# NOTE: the 0-1 ranges are stated only in the descriptions and are not validated.
class Subscores(BaseModel):
    """Subscores for coverage, precision, and format compliance."""
    coverage: float = Field(description="Coverage score (0-1)")
    precision: float = Field(description="Precision score (0-1)")
    format_compliance: float = Field(description="Format compliance score (0-1)")


# Integer tallies for the "counts" object of the judge output.
class Counts(BaseModel):
    """Counts of gold and predicted entities, facts, and violations."""
    gold_entities: int = Field(description="Number of gold entities")
    pred_entities: int = Field(description="Number of predicted entities")
    gold_facts: int = Field(description="Number of gold facts")
    pred_facts: int = Field(description="Number of predicted facts")
    covered_gold_facts: int = Field(description="Number of covered gold facts")
    hallucinated_pred_facts: int = Field(description="Number of hallucinated predicted facts")
    unscorable_pred_relations: int = Field(description="Number of unscorable predicted relations")


# One entry of "critical_violations". The prompt limits `type` to
# fabrication|pronouns|bad_answer_ids|two_questions_violation, but this model
# accepts any string (there is no Literal/enum constraint).
class CriticalViolation(BaseModel):
    """A critical violation in the GSW."""
    type: str = Field(description="Type of violation")
    message: str = Field(description="Short explanation")
    location: str = Field(description="Entity ID, verb phrase ID, or question ID")


# One gold-entity alignment. pred_entity_id stays None when no predicted entity
# matched (the prompt's "e# or null").
class EntityAlignment(BaseModel):
    """Alignment between gold and predicted entities."""
    gold_entity_id: str = Field(description="Gold entity ID")
    pred_entity_id: Optional[str] = Field(description="Predicted entity ID or null", default=None)
    match_type: str = Field(description="Match type: exact, normalized, alias, or none")
    notes: str = Field(description="Brief notes")


# A gold entity with no acceptable counterpart in the prediction.
class MissingEntity(BaseModel):
    """An entity missing from the predicted GSW."""
    gold_entity_id: str = Field(description="Gold entity ID")
    gold_name: str = Field(description="Gold entity name")
    reason: str = Field(description="Reason: missing, role_mismatch, or bundled")


# A predicted entity with no gold counterpart.
class ExtraEntity(BaseModel):
    """An extra entity in the predicted GSW."""
    pred_entity_id: str = Field(description="Predicted entity ID")
    pred_name: str = Field(description="Predicted entity name")
    reason: str = Field(description="Reason: hallucinated, unnecessary, or bundled")


# (subject, predicate, object) triple the judge infers from a verb-phrase node's questions.
class FactDetail(BaseModel):
    """Details of a fact (subject, predicate, object triple)."""
    subject: str = Field(description="Subject entity")
    predicate: str = Field(description="Predicate/verb phrase")
    object: str = Field(description="Object entity")


# Gold fact that the prediction fails to cover.
class MissingFact(BaseModel):
    """A fact missing from the predicted GSW."""
    gold_fact: FactDetail = Field(description="Gold fact details")
    gold_verb_phrase_id: str = Field(description="Gold verb phrase ID")
    reason: str = Field(description="Reason: missing, temporal_missing, wrong_object, or wrong_subject")


# Gold/predicted fact pair judged equivalent.
class CoveredFact(BaseModel):
    """A fact covered by both gold and predicted GSW."""
    gold_fact: FactDetail = Field(description="Gold fact details")
    pred_fact: FactDetail = Field(description="Predicted fact details")
    gold_verb_phrase_id: str = Field(description="Gold verb phrase ID")
    pred_verb_phrase_id: str = Field(description="Predicted verb phrase ID")
    match_notes: str = Field(description="Brief match notes")


# Predicted fact that is absent from the gold GSW or not entailed by the text.
class HallucinatedFact(BaseModel):
    """A hallucinated fact in the predicted GSW."""
    pred_fact: FactDetail = Field(description="Predicted fact details")
    pred_verb_phrase_id: str = Field(description="Predicted verb phrase ID")
    reason: str = Field(description="Reason: not_in_gold, not_entailed_by_text, or over_specific")


# The "fact_comparison" object; each list defaults to empty.
class FactComparison(BaseModel):
    """Comparison of facts between gold and predicted GSW."""
    missing_facts: List[MissingFact] = Field(description="Facts missing from predicted GSW", default_factory=list)
    covered_facts: List[CoveredFact] = Field(description="Facts covered by both GSWs", default_factory=list)
    hallucinated_facts: List[HallucinatedFact] = Field(description="Hallucinated facts in predicted GSW", default_factory=list)


# Non-critical structural issue; `type` is free text (the prompt lists the expected values).
class FormatIssue(BaseModel):
    """A format issue in the predicted GSW."""
    type: str = Field(description="Type of format issue")
    message: str = Field(description="Short explanation")
    location: str = Field(description="Verb phrase ID, question ID, or entity ID")


# Root model passed as response_format to the judge call. Its fields mirror the
# top-level keys of the JSON schema in LLM_AS_A_JUDGE_PROMPT.
class Judge_Format(BaseModel):
    """LLM-as-a-judge evaluation format for GSW comparisons."""
    overall_score: float = Field(description="Overall score between 0 and 100")
    usable_for_QA: bool = Field(description="True if the GSW is usable for QA, False otherwise")
    subscores: Subscores = Field(description="Subscores for coverage, precision, and format compliance")
    counts: Counts = Field(description="Counts of gold and predicted entities, facts, and violations")
    critical_violations: List[CriticalViolation] = Field(description="List of critical violations", default_factory=list)
    entity_alignment: List[EntityAlignment] = Field(description="List of entity alignments", default_factory=list)
    missing_entities: List[MissingEntity] = Field(description="List of missing entities", default_factory=list)
    extra_entities: List[ExtraEntity] = Field(description="List of extra entities", default_factory=list)
    fact_comparison: FactComparison = Field(description="Comparison of facts between gold and predicted GSW")
    format_issues: List[FormatIssue] = Field(description="List of format issues", default_factory=list)
    improvement_suggestions: List[str] = Field(description="List of improvement suggestions", default_factory=list)

    # Reject unknown top-level keys. This config is set only on this class; the
    # nested models above do not declare it.
    model_config = {"extra": "forbid"}

In [None]:
# Score each predicted GSW against its gold GSW with gpt-4.1-mini as the judge.
# The API key comes from the OPENAI_API_KEY environment variable. Never paste it
# into the notebook; a key that was committed here must be revoked.
judge_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# NOTE(review): documents, gold GSWs and predicted GSWs are paired purely by
# position. zip stops at the shortest list, so a length mismatch cannot raise
# IndexError. Still verify that the three lists describe the same documents in
# the same order: pred_gsws was built from the training sample above.
all_judgements = []
for doc, gold_gsw, pred_gsw in zip(test_musique_corpus_sample, golden_gsws, pred_gsws):
    judge_prompt = LLM_AS_A_JUDGE_PROMPT.format(
        input_text=doc["text"],
        # Serialize explicitly: str() of a pydantic model is a repr, not JSON.
        gold_gsw_json=gold_gsw.model_dump_json(indent=2),
        pred_gsw_json=pred_gsw.model_dump_json(indent=2),
    )
    outputs = judge_client.chat.completions.parse(
        model="gpt-4.1-mini",
        # The formatted prompt already holds the SYSTEM instructions and the
        # USER payload. Send it once, rather than also sending the raw,
        # unformatted template (literal {{ }} and {placeholders}) as a system message.
        messages=[{"role": "user", "content": judge_prompt}],
        response_format=Judge_Format,
    )
    judgement = outputs.choices[0].message.parsed
    if judgement is None:  # refusal or output that did not match the schema
        print(f"Judge returned no parsed result for {doc['global_id']}")
        continue
    all_judgements.append(judgement.overall_score)

In [None]:
# Pretty-print the most recent judge response as indented JSON.
last_judgement = Judge_Format.model_validate_json(outputs.choices[0].message.content)
print(last_judgement.model_dump_json(indent=2))


# LoRA/DoRA Fine-Tuning for GSW Creation

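This section implements LoRA/DoRA fine-tuning that teaches a model to produce a GSW from a document. The training targets contain only the final GSW JSON, with no reasoning trace, so this does not train the model's thinking mode.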

In [None]:
# Import additional libraries for LoRA training
import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

## Prepare Training Data

Pair MuSiQue corpus documents with golden GSWs (by list position) for training.

In [None]:
# Pair corpus documents with golden GSWs to build supervised fine-tuning examples.
N_TRAIN_PAIRS = 100

# NOTE(review): pairing is purely positional (document i <-> golden GSW i). This
# assumes golden_gsws is in the same order as test_musique_corpus_sample, with
# exactly one GSW per document. Confirm before trusting the labels.
matched_pairs = []
for doc, gsw in zip(test_musique_corpus_sample[:N_TRAIN_PAIRS], golden_gsws[:N_TRAIN_PAIRS]):
    matched_pairs.append({
        "global_id": doc["global_id"],
        "document": doc,
        # Store a plain JSON-compatible dict: datasets.Dataset.from_list cannot
        # convert pydantic model instances into Arrow columns.
        "gsw": gsw.model_dump(mode="json") if isinstance(gsw, GSWStructure) else gsw,
    })

print(f"Prepared {len(matched_pairs)} document-GSW pairs for training")

## Create HuggingFace Dataset

Format the training data as chat messages for LoRA training.

In [None]:
def create_chat_messages(example):
    """Convert a document-GSW pair into a system/user/assistant chat example.

    Args:
        example: mapping with ``"document"`` and ``"gsw"``. ``"document"`` is a
            corpus record whose ``"text"`` field is the source passage.
            ``"gsw"`` is a ``GSWStructure``, a JSON-compatible dict, or an
            already-serialized JSON string.

    Returns:
        ``{"messages": [...]}`` with three turns. The system and user turns are
        the FactualExtraction prompts filled with the document text. The
        assistant turn is the GSW serialized as indented JSON.
    """
    document = example["document"]
    gsw = example["gsw"]

    if isinstance(gsw, GSWStructure):
        assistant_response = gsw.model_dump_json(indent=4)
    elif isinstance(gsw, str):
        # Already serialized: json.dumps would wrap it in quotes and escape it.
        assistant_response = gsw
    else:
        assistant_response = json.dumps(gsw, indent=4, ensure_ascii=False)

    user_prompt = FactualExtractionPrompts.USER_PROMPT_TEMPLATE.format(
        input_text=document["text"],
        background_context="",
    )

    # The assistant target is the bare GSW JSON, with no reasoning trace.
    # NOTE(review): Qwen3's chat template may render an empty <think></think>
    # block before the final assistant turn. Inspect the rendered training text.
    messages = [
        {"role": "system", "content": FactualExtractionPrompts.SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": assistant_response},
    ]
    return {"messages": messages}

# Wrap the pairs in a datasets.Dataset and map each row to chat messages,
# dropping the original columns so only "messages" remains.
raw_dataset = Dataset.from_list(matched_pairs)
columns_to_drop = raw_dataset.column_names
training_dataset = raw_dataset.map(
    create_chat_messages,
    desc="Creating chat-formatted training data",
    remove_columns=columns_to_drop,
)

print(f"Training dataset created with {len(training_dataset)} examples")
print("Sample messages preview:")
preview_turns = training_dataset[0]["messages"][:2]
preview_json = json.dumps(preview_turns, indent=2)
print(preview_json[:500], "...")

## Configure Tokenizer

Load and configure the tokenizer for the model.

In [None]:
# Configure model settings
MODEL_ID = "Qwen/Qwen3-8B"  # Change this to your desired model
USE_DORA = False  # Set to True to use DoRA instead of LoRA

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Fall back to EOS as the padding token only when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

tokenizer.padding_side = 'right'

print("Tokenizer configured:")
print(f"  Model: {MODEL_ID}")
print(f"  EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")
print(f"  PAD token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
print(f"  BOS token: {tokenizer.bos_token} (ID: {tokenizer.bos_token_id})")
print(f"  Padding side: {tokenizer.padding_side}")

# Render one training example to inspect what the chat template produces.
print("\nTesting chat template...")
sample = training_dataset[0]
formatted = tokenizer.apply_chat_template(
    sample["messages"],
    tokenize=False,
    add_generation_prompt=False
)
print(f"Formatted text length: {len(formatted)}")
print(f"First 500 chars: {formatted[:500]}")
print(f"Last 200 chars: {formatted[-200:]}")

# Check for thinking support via the <think> tag itself. Testing for
# 'add_generation_prompt' is meaningless: most chat templates reference that
# variable, so such a check reports thinking mode almost unconditionally.
template_source = str(tokenizer.chat_template or "")
if '<think>' in formatted:
    print("\n✓ Thinking mode detected in template")
elif '<think>' in template_source:
    print("\n⚠ Template supports <think> tags, but none were rendered for this training example")
else:
    print("\n⚠ Note: Thinking tags not detected, model may add them during generation")

## Configure LoRA/DoRA Training

Set up the model, LoRA config, and training arguments.

In [None]:
# Training hyperparameters
OUTPUT_DIR = "./gsw_creation_lora"
NUM_EPOCHS = 3
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 2e-4
WARMUP_RATIO = 0.1  # fraction of all optimizer steps used for learning-rate warmup

# Derive the warmup length from the actual number of optimizer steps rather than
# a fixed 100. With ~100 training examples and an effective batch of 16, the run
# has only ~18 optimizer steps. A 100-step warmup would still be ramping when
# training ends, so the LR would never reach LEARNING_RATE.
import math

# WORLD_SIZE is set by distributed launchers (e.g. torchrun); default to 1 process.
_num_processes = max(1, int(os.environ.get("WORLD_SIZE", "1")))
_batches_per_epoch = math.ceil(len(training_dataset) / (BATCH_SIZE * _num_processes))
_optimizer_steps_per_epoch = max(1, _batches_per_epoch // GRADIENT_ACCUMULATION_STEPS)
TOTAL_TRAIN_STEPS = _optimizer_steps_per_epoch * NUM_EPOCHS  # approximate
WARMUP_STEPS = max(1, int(WARMUP_RATIO * TOTAL_TRAIN_STEPS))

# LoRA/DoRA configuration: adapters on every attention and MLP projection.
# NOTE(review): r=256 / alpha=512 is a very high-capacity adapter for ~100
# examples. Watch for overfitting.
lora_config = LoraConfig(
    r=256,
    lora_alpha=512,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    bias="none",
    task_type="CAUSAL_LM",
    use_dora=USE_DORA,  # DoRA if True, LoRA if False
)

adapter_type = "DoRA" if USE_DORA else "LoRA"
print(f"Configured {adapter_type} with:")
print(f"  r={lora_config.r}")
print(f"  alpha={lora_config.lora_alpha}")
print(f"  dropout={lora_config.lora_dropout}")
print(f"  target_modules={lora_config.target_modules}")
print(f"  warmup_steps={WARMUP_STEPS} of ~{TOTAL_TRAIN_STEPS} optimizer steps")

In [None]:
# Load the base model in bfloat16. device_map="auto" spreads its layers over the
# available devices.
print(f"Loading model: {MODEL_ID}")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

first_param_device = next(model.parameters()).device
print(f"Model loaded on device: {first_param_device}")

In [None]:
# Configure training arguments
# NOTE(review): with ~100 examples the run has only a few dozen optimizer steps,
# so save_steps=500 never fires. Only the explicit trainer.save_model call
# after training writes an adapter.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    logging_steps=10,
    save_steps=500,
    save_total_limit=3,
    bf16=True,
    fp16=False,
    gradient_checkpointing=True,
    optim="adamw_torch",
    logging_dir=f"{OUTPUT_DIR}/logs",
    report_to="none",  # Change to "wandb" for W&B logging
)

# Effective batch = per-device batch x gradient accumulation x data-parallel processes.
# torch.cuda.device_count() does not give this multiplier. It is 0 on CPU-only
# machines, and with device_map="auto" the extra GPUs hold shards of a single
# model replica, not additional replicas.
effective_batch_size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * max(1, training_args.world_size)

print(f"Training configuration ({adapter_type}):")
print(f"  Output directory: {OUTPUT_DIR}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Effective batch size: {effective_batch_size}")

## Initialize Trainer and Start Training

Create the SFTTrainer and begin LoRA fine-tuning.

In [None]:
def formatting_function(example):
    """Render one example's chat ``messages`` into a single training string.

    The assistant turn is already part of ``messages``, so no trailing
    generation prompt is appended.
    """
    rendered = tokenizer.apply_chat_template(
        example["messages"],
        add_generation_prompt=False,
        tokenize=False,
    )
    return rendered


# NOTE(review): training on the fully rendered string presumably puts the loss on
# the long system/user prompts as well as the GSW answer. Also check the
# trainer's maximum sequence length: TRL's SFT default (1024 tokens in recent
# versions) may truncate the GSW JSON target. Confirm against a tokenized sample.
print("Initializing trainer...")
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=training_dataset,
    processing_class=tokenizer,
    peft_config=lora_config,
    formatting_func=formatting_function,
)

separator = "=" * 60
print(f"\n{separator}")
print(f"Starting {adapter_type} training...")
print(f"{separator}\n")

trainer.train()

print("\nTraining complete!")

## Save the Trained Model

Save the final LoRA/DoRA adapters and optionally push them to the Hugging Face Hub.

In [None]:
# Write the trained adapter to <OUTPUT_DIR>/final.
final_output_path = "{}/final".format(OUTPUT_DIR)
print(f"Saving final model to: {final_output_path}")
trainer.save_model(final_output_path)

for summary_line in (
    f"\n✓ Model saved successfully!",
    f"  Location: {final_output_path}",
    f"  Adapter type: {adapter_type}",
):
    print(summary_line)

# Optional: push to the Hugging Face Hub by uncommenting and configuring:
# HUB_MODEL_ID = "username/qwen3-gsw-creation-lora"
# print(f"\nPushing to HuggingFace Hub: {HUB_MODEL_ID}")
# trainer.push_to_hub(
#     commit_message=f"Training complete - {adapter_type} fine-tuned model for GSW creation",
#     blocking=True,
# )
# print(f"✓ Model pushed to: https://huggingface.co/{HUB_MODEL_ID}")

## Test the Fine-Tuned Model

Test the trained model on a new document to verify it generates proper GSW structures.

In [None]:
# Smoke-test the fine-tuned adapter on a document it was not trained on.
# The first len(matched_pairs) test documents were training data, so take the
# first one after them. Fall back to index 0 (with a warning) if none remain.
held_out_index = len(matched_pairs)
if held_out_index >= len(test_musique_corpus_sample):
    print("Warning: no held-out document available; testing on a training document.")
    held_out_index = 0
test_doc = test_musique_corpus_sample[held_out_index]

# Same extraction budget as the vLLM runs (4096 * 3). 2048 new tokens is
# usually too short for a full GSW JSON.
MAX_NEW_TOKENS = 4096 * 3

test_messages = [
    {"role": "system", "content": FactualExtractionPrompts.SYSTEM_PROMPT},
    {"role": "user", "content": FactualExtractionPrompts.USER_PROMPT_TEMPLATE.format(
        input_text=test_doc['text'],
        background_context=""
    )}
]

# Render the prompt and open an assistant turn for the model to complete.
test_input = tokenizer.apply_chat_template(
    test_messages,
    tokenize=False,
    add_generation_prompt=True
)
inputs = tokenizer(test_input, return_tensors="pt").to(model.device)

# The model is still in training mode after trainer.train(), which keeps the
# LoRA dropout active. Switch to eval mode before generating.
gen_model = trainer.model
gen_model.eval()

print("Generating GSW output...")
print("="*60)
with torch.no_grad():
    # Named generated_ids (not `outputs`) so the judge cell's `outputs` is not clobbered.
    generated_ids = gen_model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=0.6,
        do_sample=True,
        top_p=0.95,
        top_k=20,
        use_cache=True,  # training with gradient checkpointing may have disabled the KV cache
    )

# Decode only the newly generated tokens (everything after the prompt).
prompt_length = inputs["input_ids"].shape[1]
assistant_response = tokenizer.decode(generated_ids[0][prompt_length:], skip_special_tokens=False)

PREVIEW_CHARS = 1000
print("Generated output:")
print(assistant_response[:PREVIEW_CHARS])
if len(assistant_response) > PREVIEW_CHARS:
    print("\n... [truncated] ...")

print("\n" + "="*60)
print("Test complete!")