In [None]:
# Full dependency list, kept for reference (run with %pip, not !pip):
# %pip install openpipe-art==0.5.0 langchain-core tenacity datasets vllm faiss-cpu chromadb requests lxml numpy transformers torch gql==3.4.1 peft
#
# Install into the interpreter that runs THIS kernel. A bare `!pip install`
# invokes whichever `pip` is first on PATH, which can belong to another
# environment -- a likely cause of the `ModuleNotFoundError: No module named
# 'peft'` recorded at the end of this notebook.
# TODO: pin the peft version once the working one is known.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "peft"])

In [2]:
import os

# Credentials come from a local `secretsConfig` module when it is available;
# otherwise whatever is already set in the process environment is used
# (e.g. keys exported in the shell before starting Jupyter).
try:
    from secretsConfig import oaiKey, wandbKey
except ImportError:
    oaiKey = wandbKey = None

# os.environ only accepts str values: assigning None would raise TypeError
# before the friendlier checks below could run, and assigning "" would wipe a
# key that is already present in the environment. Only copy non-empty values.
if oaiKey:
    # Required for the OpenAI-hosted judge model (RULER / judge_correctness).
    os.environ["OPENAI_API_KEY"] = oaiKey
if wandbKey:
    # Required for Weights & Biases (serverless inference/training and logging).
    os.environ["WANDB_API_KEY"] = wandbKey

if not os.environ.get("OPENAI_API_KEY"):
    raise ValueError(
        "OPENAI_API_KEY is required for RULER functionality when using openai/o4-mini."
    )

if not os.environ.get("WANDB_API_KEY"):
    raise ValueError("WANDB_API_KEY is required for inference, training, and logging to Weights & Biases.")

In [3]:
from IBM_Z_Datathon_RAG.semantic_search import FAISSSemanticSearch
from IBM_Z_Datathon_RAG.KeywordSearch import keyword_search
from IBM_Z_Datathon_RAG.ReadDocumentPart import read_document_part



In [4]:
from dotenv import load_dotenv
import random

import art
from art.serverless.backend import ServerlessBackend

load_dotenv()

random.seed(42)

# Declare the model - CHANGED TO QWEN3-14B
model = art.TrainableModel(
    name="legal-agent-001",
    project="legal-rag",
    base_model="Qwen/Qwen2.5-14B-Instruct",  # Changed from Qwen2.5-14B-Instruct
)

# Initialize the server
# Training and inference will run on Weights & Biases servers
backend = ServerlessBackend()

# Register the model with the Serverless Backend (sets up logging, inference, and training)
await model.register(backend)

In [7]:
import json
import os
from textwrap import dedent

import art
from dotenv import load_dotenv
from langchain_core.utils.function_calling import convert_to_openai_tool
from litellm import acompletion
from openai import AsyncOpenAI
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_random_exponential

# Your tool imports
from IBM_Z_Datathon_RAG.semantic_search import FAISSSemanticSearch
from IBM_Z_Datathon_RAG.KeywordSearch import keyword_search
from IBM_Z_Datathon_RAG.ReadDocumentPart import read_document_part

# Load variables from a local .env file into os.environ.
load_dotenv()

# Upper bound on model turns (chat-completion requests) in one rollout.
MAX_TURNS: int = 4


# Structured result built by the agent's `return_final_answer` tool:
# the answer text plus the part IDs it cites. Both fields are required.
class FinalAnswer(BaseModel):
    answer: str = Field(..., description="The final answer to the legal question")
    source_ids: list[str] = Field(..., description="List of part IDs used as sources")


# A legal question for the agent, with optional reference ("gold") labels.
# `judge_correctness` is only run when `gold_answer` is set; the gold
# document/part IDs are not read by the code in this section.
class LegalScenario(BaseModel):
    id: str
    question: str
    # Optional supervision; every gold_* field defaults to None.
    gold_answer: str | None = Field(default=None)
    gold_doc_ids: list[str] | None = Field(default=None)
    gold_part_ids: list[str] | None = Field(default=None)


# Verdict schema for the LLM judge: it is passed as `response_format` in
# `judge_correctness`, and the judge's JSON reply is validated against it.
# Field order matters (reasoning is requested before the accept flag).
class CorrectnessJudgeResponse(BaseModel):
    reasoning: str = Field(..., description="Explanation of the reasoning process.")
    accept: bool = Field(..., description="Whether the AI answer should be accepted.")


@retry(
    stop=stop_after_attempt(3),
    # Without a wait strategy tenacity retries immediately, so a rate-limit
    # (429) or transient network error would use up all three attempts within
    # milliseconds. Use randomized exponential backoff, capped at 10 seconds.
    wait=wait_random_exponential(multiplier=1, max=10),
)
async def judge_correctness(
    scenario: LegalScenario, answer: str
) -> CorrectnessJudgeResponse:
    """Grade ``answer`` against ``scenario.gold_answer`` with an LLM judge.

    The judge (``openai/gpt-4o-mini`` via litellm) is asked for a structured
    ``CorrectnessJudgeResponse``. Exceptions from the completion call are
    retried up to 3 attempts in total, with backoff; after the final failure
    tenacity raises ``RetryError``.

    Args:
        scenario: Provides the question and the reference answer. Callers
            should only invoke this when ``gold_answer`` is set; otherwise the
            judge is shown ``Reference answer: None``.
        answer: The agent's final answer text.

    Returns:
        The judge's verdict. If the reply cannot be parsed into
        ``CorrectnessJudgeResponse``, a verdict with ``accept=False`` and the
        parse error in ``reasoning`` is returned instead of raising (so parse
        failures are not retried).
    """
    system_prompt = dedent(
        """
        You are given a legal question, the reference answer, and an AI answer.
        Decide whether the AI answer is correct and should be accepted.
        """
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": (
                f"Question: {scenario.question}\n"
                f"Reference answer: {scenario.gold_answer}\n"
                f"AI answer: {answer}"
            ),
        },
    ]

    response = await acompletion(
        model="openai/gpt-4o-mini",
        messages=messages,
        response_format=CorrectnessJudgeResponse,
    )

    # An empty/None content falls through to validation of "{}", which fails
    # and is reported via the parse-error verdict below.
    raw_content = response.choices[0].message.content or "{}"

    try:
        return CorrectnessJudgeResponse.model_validate_json(raw_content)
    except Exception as e:
        return CorrectnessJudgeResponse(
            reasoning=f"Parse error: {e}", accept=False
        )


# Trajectory type used by `rollout`: an art.Trajectory that additionally
# carries the agent's parsed FinalAnswer. `final_answer` stays None when the
# episode ends without a successful `return_final_answer` tool call.
# NOTE(review): assumes art.Trajectory treats annotated class attributes as
# model fields (pydantic-style) -- confirm against the installed `art` version.
class ProjectTrajectory(art.Trajectory):
    final_answer: FinalAnswer | None = None


# Input to `rollout`: a scenario paired with an integer step index, which
# `rollout` copies into the trajectory's metadata under "step".
class LegalScenarioStep(BaseModel):
    step: int = Field(...)
    scenario: LegalScenario = Field(...)


def _execute_tool_call(
    tools_by_name: dict, tool_name: str, raw_arguments: str | None
) -> tuple[object, str]:
    """Run one model-requested tool call and return ``(result, message_content)``.

    Never raises. An unknown tool name, malformed JSON arguments, wrong keyword
    arguments, or an exception inside the tool all yield
    ``(None, "Error ...")`` so the text can be sent back to the model as that
    tool call's response and the model can correct itself.
    """
    tool_fn = tools_by_name.get(tool_name)
    if tool_fn is None:
        available = ", ".join(tools_by_name)
        return None, f"Error: unknown tool '{tool_name}'. Available tools: {available}"
    try:
        tool_args = json.loads(raw_arguments or "{}")
        # A non-dict JSON payload makes ** raise TypeError, handled below.
        result = tool_fn(**tool_args)
    except Exception as e:  # any tool failure is reported to the model, not raised
        return None, f"Error calling {tool_name}: {e}"
    return result, str(result)


async def rollout(model: art.Model, legal_scenario_step: LegalScenarioStep) -> ProjectTrajectory:
    """Run one agent episode for ``legal_scenario_step`` and return its trajectory.

    The policy ``model`` is queried through its OpenAI-compatible inference
    endpoint for at most ``MAX_TURNS`` turns. On each turn it may call the
    search/read tools; the episode ends when it successfully calls
    ``return_final_answer``, replies without any tool call, or runs out of
    turns. Tool failures (unknown tool, bad arguments, exceptions raised by a
    tool) are returned to the model as that call's tool message instead of
    aborting the episode.

    If a final answer is produced and the scenario has a ``gold_answer``, the
    answer is graded by ``judge_correctness`` and stored as
    ``traj.metrics["correct"]`` (1.0 or 0.0). ``traj.reward`` is left at 0.0;
    NOTE(review): presumably rewards are assigned later (e.g. by RULER) -- confirm.

    Args:
        model: ART model whose inference endpoint serves the policy.
        legal_scenario_step: The scenario to solve plus its step index, which
            is recorded in ``traj.metadata``.

    Returns:
        The trajectory with all messages/choices, the tool schemas, and the
        parsed ``final_answer`` (``None`` if none was produced).

    Raises:
        Errors from the chat-completions request itself are not caught.
    """
    scenario = legal_scenario_step.scenario

    traj = ProjectTrajectory(
        reward=0.0,
        messages_and_choices=[],
        metadata={
            "scenario_id": scenario.id,
            "step": legal_scenario_step.step,
        },
    )

    # Tool names listed in the prompt must match the registered function names
    # below (tools_by_name is keyed by __name__), including the tool that ends
    # the episode.
    system_prompt = dedent(
        f"""
        You are a legal research agent. Use the available tools to search legal documents and answer the question.
        You may take up to {MAX_TURNS} turns.

        Tools:
        - search_keyword_tool: Exact term matches
        - search_semantic_tool: Conceptual searches
        - read_document_part_tool: Read full text using part_id
        - return_final_answer: Submit the final answer and the part_ids it is based on

        Always cite sources using part_ids.
        """
    )

    traj.messages_and_choices = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": scenario.question},
    ]

    # Tools exposed to the model. Each function's name, signature and docstring
    # are converted into the tool schema the model sees, so editing them
    # changes the model's prompt.
    def search_keyword_tool(query: str, num: int = 5) -> str:
        """Search using keyword/BM25 for exact term matches."""
        return keyword_search(query, num)

    # Built on first use and reused for the rest of this rollout instead of
    # being reconstructed on every query.
    # NOTE(review): assumes FAISSSemanticSearch keeps no per-query state -- confirm.
    semantic_searcher = None

    def search_semantic_tool(query: str, num: int = 5) -> str:
        """Search using semantic/vector search."""
        nonlocal semantic_searcher
        if semantic_searcher is None:
            semantic_searcher = FAISSSemanticSearch()
        return semantic_searcher.search(query, num)

    def read_document_part_tool(part_id: str) -> str:
        """Read a document part by ID."""
        return read_document_part(part_id)

    def return_final_answer(answer: str, source_ids: list[str]) -> FinalAnswer:
        """Return final answer with sources."""
        return FinalAnswer(answer=answer, source_ids=source_ids)

    tools = [search_keyword_tool, search_semantic_tool, read_document_part_tool, return_final_answer]
    tools_by_name = {t.__name__: t for t in tools}
    traj.tools = [convert_to_openai_tool(t) for t in tools]

    # `async with` closes the client's HTTP connection pool when the episode ends.
    async with AsyncOpenAI(
        base_url=model.inference_base_url,
        api_key=model.inference_api_key,
    ) as client:
        for _ in range(MAX_TURNS):
            response = await client.chat.completions.create(
                model=model.get_inference_name(),
                temperature=1,
                messages=traj.messages(),
                tools=traj.tools,
            )

            choice = response.choices[0]
            # The whole Choice (not just its message) is stored, as before;
            # presumably ART needs it for training -- keep as-is.
            traj.messages_and_choices.append(choice)

            tool_calls = choice.message.tool_calls
            if not tool_calls:
                # Plain-text reply: the episode ends without a final answer.
                return traj

            for tool_call in tool_calls:
                tool_name: str = tool_call.function.name
                result, content = _execute_tool_call(
                    tools_by_name, tool_name, tool_call.function.arguments
                )
                # OpenAI-style chat APIs expect every tool_call_id in the
                # assistant message to be answered by a "tool" message -- failed
                # calls included -- or the next request is malformed.
                traj.messages_and_choices.append(
                    {
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "name": tool_name,
                        "content": content,
                    }
                )

                if tool_name == "return_final_answer" and result is not None:
                    traj.final_answer = result
                    if scenario.gold_answer:
                        try:
                            judge_response = await judge_correctness(
                                scenario, result.answer
                            )
                        except Exception as e:
                            # A judge failure must not discard the trajectory.
                            print(f"Error: {e}")
                        else:
                            traj.metrics["correct"] = float(judge_response.accept)
                    return traj

    # Turn budget exhausted without a final answer.
    return traj


# LOCAL TRAINING - uses YOUR A100s
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
import torch

model_name = "Qwen/Qwen2.5-14B-Instruct"

# NOTE(review): trust_remote_code=True lets the Hub repository execute
# arbitrary Python while loading. Qwen2.5 is supported natively by recent
# transformers releases, so this flag can probably be dropped -- confirm
# against the installed version before removing it.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# The Hugging Face model gets its own names (base_model / peft_model).
# Rebinding `model` here would overwrite the art.TrainableModel registered with
# the ServerlessBackend earlier, so any later `rollout(model, ...)` would get a
# PeftModel without `inference_base_url` / `get_inference_name()`.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # spread layers across all visible GPUs (CPU if they run out)
    trust_remote_code=True,
)

# Apply LoRA to the attention projections only (alpha == r -> scaling 1.0).
lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()

print("✅ Model loaded on YOUR A100s with LoRA!")

ModuleNotFoundError: No module named 'peft'