In [5]:
# !pip install openpipe-art==0.5.0 langchain-core tenacity datasets vllm faiss-cpu chromadb requests lxml numpy transformers torch gql==3.4.1 peft 
!pip install langchain-core tenacity datasets vllm

Collecting langchain-core
  Using cached langchain_core-0.3.79-py3-none-any.whl.metadata (3.2 kB)
Collecting datasets
  Using cached datasets-4.2.0-py3-none-any.whl.metadata (18 kB)
Collecting vllm
  Using cached vllm-0.11.0-cp38-abi3-manylinux1_x86_64.whl.metadata (17 kB)
Collecting langsmith<1.0.0,>=0.3.45 (from langchain-core)
  Using cached langsmith-0.4.34-py3-none-any.whl.metadata (14 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Using cached pyarrow-21.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.4.1,>=0.3.0 (from datasets)
  Using cached dill-0.4.0-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets)
  Using cached pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)
Collecting xxhash (from datasets)
  Using cached xxhash-3.6.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Using cached multi

In [1]:
import os
from secretsConfig import oaiKey, wandbKey, openRouterKey  # Add openRouterKey

# Required for RULER judge model
os.environ["OPENAI_API_KEY"] = oaiKey

# Required for Weights & Biases
os.environ["WANDB_API_KEY"] = wandbKey

# Required for OpenRouter (Gemini judge)
os.environ["OPENROUTER_API_KEY"] = openRouterKey  # ADD THIS LINE

if not os.environ.get("OPENAI_API_KEY"):
    raise ValueError("OPENAI_API_KEY is required for RULER functionality.")

if not os.environ.get("WANDB_API_KEY"):
    raise ValueError("WANDB_API_KEY is required for W&B.")

if not os.environ.get("OPENROUTER_API_KEY"):
    raise ValueError("OPENROUTER_API_KEY is required for Gemini judge.")

In [2]:
from IBM_Z_Datathon_RAG.semantic_search import FAISSSemanticSearch
from IBM_Z_Datathon_RAG.KeywordSearch import keyword_search
from IBM_Z_Datathon_RAG.ReadDocumentPart import read_document_part



In [3]:
from dotenv import load_dotenv
import random

import art
from art.serverless.backend import ServerlessBackend

load_dotenv()

random.seed(42)

# Declare the model - CHANGED TO QWEN3-14B
model = art.TrainableModel(
    name="legal-agent-001",
    project="legal-rag",
    base_model="Qwen/Qwen2.5-14B-Instruct",  # Changed from Qwen2.5-14B-Instruct
)

# Initialize the server
# Training and inference will run on Weights & Biases servers
backend = ServerlessBackend()

# Register the model with the Serverless Backend (sets up logging, inference, and training)
await model.register(backend)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from textwrap import dedent
from pydantic import BaseModel, Field
from openai import AsyncOpenAI
from langchain_core.utils.function_calling import convert_to_openai_tool
import art

MAX_TURNS = 4


class FinalAnswer(BaseModel):
    answer: str
    source_ids: list[str]


class LegalScenario(BaseModel):
    id: str
    question: str
    gold_answer: str | None = None
    gold_part_ids: list[str] | None = None


class LegalScenarioStep(BaseModel):
    step: int
    scenario: LegalScenario


async def rollout(model: art.Model, legal_scenario_step: LegalScenarioStep) -> art.Trajectory:
    """Execute one trajectory rollout"""
    scenario = legal_scenario_step.scenario
    
    traj = art.Trajectory(
        reward=0.0,
        messages_and_choices=[],
        metadata={"scenario_id": scenario.id, "step": legal_scenario_step.step},
    )

    # YOUR CUSTOM PROMPT HERE
    system_prompt = dedent(
        f"""
        You are a legal research assistant that can search legal documents to answer questions.

        You have access to the following tools:

        - search_keyword(query: str, num: int) -> str: Search using keyword/BM25 search for exact term matches.
        - search_semantic(query: str, num: int) -> str: Search using semantic/vector search for conceptual similarity.
        - read_document_part(part_id: str) -> str: Read a document part by ID. Part IDs use hierarchical format (e.g., A:B:C). To access parent parts, remove the last segment (e.g., A:B:C → parent is A:B).

        You may call one tool per turn, for up to {MAX_TURNS} turns, before giving your final answer.

        In each turn, you should analyze what information you need and respond with EITHER a tool call OR your final answer.

        For tool calls, use this format:
        <think>
        [your reasoning for what to search for and why]
        </think>
        <tool>
        {{"name": "tool_name", "args": {{"query": "search query"}}}}
        </tool>

        When you have enough information, give your final answer in this format:

        <think>
        [your reasoning for the answer]
        </think>
        <answer>
        [your comprehensive answer citing the evidence you found or "I don't know" if you didn't get enough information]

        <sources>
        <source>doc_id_1</source>
        </sources>
        </answer>
        """
    )

    traj.messages_and_choices = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": scenario.question},
    ]

    # Define tools
    def search_keyword_tool(query: str, num: int = 5) -> str:
        return keyword_search(query, num)

    def search_semantic_tool(query: str, num: int = 5) -> str:
        searcher = FAISSSemanticSearch()
        return searcher.search(query, num)

    def read_document_part_tool(part_id: str) -> str:
        return read_document_part(part_id)

    def return_final_answer(answer: str, source_ids: list[str]) -> FinalAnswer:
        return FinalAnswer(answer=answer, source_ids=source_ids)

    tools = [search_keyword_tool, search_semantic_tool, read_document_part_tool, return_final_answer]
    tools_by_name = {t.__name__: t for t in tools}
    traj.tools = [convert_to_openai_tool(t) for t in tools]

    client = AsyncOpenAI(
        base_url=model.inference_base_url,
        api_key=model.inference_api_key,
    )

    for _ in range(MAX_TURNS):
        response = await client.chat.completions.create(
            model=model.get_inference_name(),
            temperature=1,
            messages=traj.messages(),
            tools=traj.tools,
        )

        response_message = response.choices[0].message
        traj.messages_and_choices.append(response.choices[0])

        if not response_message.tool_calls:
            return traj

        try:
            for tool_call in response_message.tool_calls:
                tool_name = tool_call.function.name
                if tool_name in tools_by_name:
                    tool_args = json.loads(tool_call.function.arguments)
                    result = tools_by_name[tool_name](**tool_args)
                    traj.messages_and_choices.append({
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "name": tool_name,
                        "content": str(result),
                    })

                    if tool_name == "return_final_answer":
                        return traj
        except Exception as e:
            print(f"Error: {e}")
            return traj

    return traj


print("✅ Rollout function defined!")

In [6]:
import json
import os
from litellm import acompletion

# Load your training data
DATA_FILE = "./snippet_data.json"

print(f"Loading data from {DATA_FILE}...")
with open(DATA_FILE, 'r') as f:
    data = json.load(f)

# Convert to LegalScenario objects
training_scenarios = []
for item in data.get("items", []):
    for row in item.get("rows", []):
        sources = row.get("sources", [])
        gold_part_ids = sources if sources else []
        
        training_scenarios.append(
            LegalScenario(
                id=str(row["row_index"]),
                question=row["question"],
                gold_answer=row.get("model_answer", ""),
                gold_part_ids=gold_part_ids
            )
        )

print(f"✅ Loaded {len(training_scenarios)} scenarios")


# Custom RULER function using OpenRouter
async def gemini_ruler_score_group(group: art.TrajectoryGroup) -> art.TrajectoryGroup:
    """Score trajectories using Gemini 2.5 Flash via OpenRouter"""
    
    trajectories = group.trajectories
    if len(trajectories) <= 1:
        for traj in trajectories:
            traj.reward = 0.0
        return group
    
    # Extract responses
    responses = []
    for traj in trajectories:
        messages = traj.messages()
        if messages:
            last_msg = messages[-1].get("content", "")
            responses.append(last_msg)
        else:
            responses.append("")
    
    # Build comparison prompt
    comparison_text = "\n\n".join([
        f"**Response {i+1}:**\n{resp[:500]}"  # Truncate for API limits
        for i, resp in enumerate(responses)
    ])
    
    judge_prompt = f"""Compare these {len(responses)} legal research responses and rank them.

Criteria:
1. Correctness and accuracy (most important)
2. Proper citation of sources with part_ids
3. Completeness of answer

Responses:
{comparison_text}

Return ONLY a JSON array of scores from 0.0 to 2.0, one score per response in order.
Higher scores = better responses.
Example: [2.0, 0.5, 1.5]

Your scores:"""
    
    try:
        # Call Gemini via OpenRouter
        response = await acompletion(
            model="google/gemini-2.0-flash-exp:free",  # Free tier
            messages=[{"role": "user", "content": judge_prompt}],
            api_base="https://openrouter.ai/api/v1",
            api_key=os.environ["OPENROUTER_API_KEY"],
            max_tokens=100,
        )
        
        # Parse scores
        result_text = response.choices[0].message.content.strip()
        
        # Extract JSON array
        import re
        json_match = re.search(r'\[[\d\.,\s]+\]', result_text)
        if json_match:
            scores = json.loads(json_match.group())
        else:
            scores = json.loads(result_text)
        
        # Assign scores
        for traj, score in zip(trajectories, scores):
            traj.reward = float(score)
        
        print(f"  Scores: {scores}")
        
    except Exception as e:
        print(f"  Error in judge: {e}")
        # Fallback: random variation
        import random
        for traj in trajectories:
            traj.reward = random.uniform(0.5, 1.5)
    
    return group


# Test the judge
print("\n🧪 Testing Gemini judge via OpenRouter...")

test_scenario = training_scenarios[0]
base_messages = [
    {"role": "system", "content": "You are a legal research agent."},
    {"role": "user", "content": test_scenario.question},
]

good_traj = art.Trajectory(
    messages_and_choices=[
        *base_messages,
        {"role": "assistant", "content": test_scenario.gold_answer},
    ],
    reward=0,
)

bad_traj = art.Trajectory(
    messages_and_choices=[
        *base_messages,
        {"role": "assistant", "content": "I don't know anything about this legal question."},
    ],
    reward=0,
)

test_group = art.TrajectoryGroup(trajectories=[good_traj, bad_traj])

# Score using custom function
judged_group = await gemini_ruler_score_group(test_group)

# Display results
sorted_trajs = sorted(judged_group.trajectories, key=lambda t: t.reward, reverse=True)
for rank, traj in enumerate(sorted_trajs, 1):
    msgs = traj.messages()
    print(f"\nRank {rank}: Score {traj.reward:.3f}")
    print(f"  Response: {msgs[-1]['content'][:80]}...")

print("\n✅ Gemini judge working!")

Loading data from ./snippet_data.json...
✅ Loaded 100 scenarios

🧪 Testing Gemini judge via OpenRouter...

[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m

  Error in judge: litellm.BadRequestError: LLM Provider NOT provided. Pass in the LLM provider you are trying to call. You passed model=google/gemini-2.0-flash-exp:free
 Pass model as E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/starcoder',..)` Learn more: https://docs.litellm.ai/docs/providers

Rank 1: Score 1.139
  Response: The Marshall Court reasoned that a land grant from a state constitutes a binding...

Rank 2: Score 0.525
  Response: I don't know anything about this legal question....

✅ Gemini judge working!


In [7]:
from art.utils import iterate_dataset

# Training config
training_config = {
    "groups_per_step": 2,
    "num_epochs": 3,
    "rollouts_per_group": 6,
    "learning_rate": 1e-5,
    "max_steps": 50,
}

# Create training iterator
training_iterator = iterate_dataset(
    training_scenarios,
    groups_per_step=training_config["groups_per_step"],
    num_epochs=training_config["num_epochs"],
    initial_step=await model.get_step(),
)

print("🚀 Starting training loop...\n")

for batch in training_iterator:
    print(f"=== Step {batch.step} | Epoch {batch.epoch} | Epoch Step {batch.epoch_step} ===")
    print(f"Batch: {len(batch.items)} scenarios")
    
    # Create trajectory groups
    groups = []
    for scenario in batch.items:
        groups.append(
            art.TrajectoryGroup(
                (
                    rollout(model, LegalScenarioStep(step=batch.step, scenario=scenario))
                    for _ in range(training_config["rollouts_per_group"])
                )
            )
        )
    
    # Gather trajectories
    finished_groups = await art.gather_trajectory_groups(
        groups,
        pbar_desc="Gathering trajectories",
        max_exceptions=training_config["rollouts_per_group"] * len(batch.items),
    )
    
    # Judge with RULER (Gemini 2.5 Flash)
    judged_groups = []
    for group in finished_groups:
        judged_group = await ruler_score_group(
            group,
            "google/gemini-2.5-flash",
            debug=True
        )
        judged_groups.append(judged_group)
    
    # Train on judged trajectories
    await model.delete_checkpoints()
    await model.train(
        judged_groups,
        config=art.TrainConfig(learning_rate=training_config["learning_rate"]),
    )
    
    # Calculate metrics
    all_rewards = [t.reward for g in judged_groups for t in g.trajectories]
    avg_reward = sum(all_rewards) / len(all_rewards)
    
    print(f"✅ Step {batch.step} complete | Avg Reward: {avg_reward:.3f}\n")
    
    # Stop after max_steps
    if batch.step >= training_config["max_steps"]:
        break

print("🎉 Training complete!")

🚀 Starting training loop...



Iterating dataset:   0%|          | 0/150 [00:00<?, ?batch/s]

=== Step 0 | Epoch 0 | Epoch Step 0 ===
Batch: 2 scenarios


[ERROR] Failed to parse XML: [Errno 2] No such file or directory: 'Plessy v. Ferguson separate but equal doctrine Justice Harlan dissent'


: 