# Regulatory Information Retrieval and Answer Generation (RIRAG)

This notebook solves the following task of the Regulatory Information Retrieval and Answer Generation competition.

_Using the question and the passages retrieved in Subtask 1 (See ObliQA.ipynb notebook), participants must generate a comprehensive, accurate, and coherent answer. This subtask emphasizes the ability to synthesize information from multiple sources and present it in a clear and logical manner, ensuring that the answer fully addresses the compliance and obligation requirements of the query._

The notebook demonstrates how we can leverage _Retrieval Augmented Generation_ and _Large Language Models_ to synthesize the results obtained through the hybrid (lexical and semantic) search to provide an accurate and precise answer to help professionals navigate the regulatory content.

In [1]:
# Copy RePASs repo to validate our results - ONLY RUN ONCE
#!git clone https://github.com/RegNLP/RePASs.git && cd RePASs

In [3]:
# Load passages from disk
ndocs = 40  # Number of regulatory documents to process
passages = defaultdict(str) # List to store all passages extracted from the regulatory documents

# Extract the passages in each document
for i in range(1, ndocs + 1):
    with open(os.path.join("ObliQADataset/StructuredRegulatoryDocuments", f"{i}.json")) as f:
        doc = json.load(f)  # Loads the contents of the JSON file
        for psg in doc:  # Map each passageId to the actual content
            passages[psg["ID"]] = psg["Passage"]

In [4]:
rankings_dict = defaultdict(list) # Maps a question to the relevant passage and its corresponding ranking score

# Load the rankings file in memory
with open('data/rankings_hybrid.trec', 'r') as f:
    # File format: QuestionID Q0 DocumentID Rank Score Method
    for line in f:
        parts = line.strip().split()
        question_id = parts[0]
        document_id = parts[2]
        rank = int(parts[3])
        score = float(parts[4])
        rankings_dict[question_id].append({
            'doc': document_id,
            'score': score
        })

In [19]:
def extract_passages(question_id: str, rankings_dict: dict = rankings_dict) -> list[str]:
    """
    Extracts the passages content that are relevant for answering the given question.
    Given a valid question id, it returns at least one passage and up to 10 passages
    that surpass a given relevance threshold.
    
    Args:
        question_id: The question id for which we want to extract the relevant passages
        
    Returns:
        List[str]: A list of passages that are relevant for answering the given question
    """
    retrieved_passages = []
    should_stop = False
    
    for i in range(len(rankings_dict[question_id])):
        # If there was a significant difference in relevance between two passages, don't extract more passages
        # If 10 passages have already been extracted, don't extract more
        if should_stop or len(retrieved_passages) == 10:
            break
            
        # If no passage has been extracted, extract at least one
        if len(retrieved_passages) == 0:
            retrieved_passages.append(rankings_dict[question_id][i]["doc"])
            continue
                
        # Check if there is a relevance difference between this and the next passage of more than 10%
        if i < len(rankings_dict[question_id]) - 1 and rankings_dict[question_id][i]["score"] - rankings_dict[question_id][i+1]["score"] > 0.1:
                should_stop = True

        # Don't include passages with low relevance
        if rankings_dict[question_id][i]["score"] < 0.72:
            break

        retrieved_passages.append(rankings_dict[question_id][i]["doc"])
        
    # Extract the plain text
    retrieved_passages = [passages[doc] for doc in retrieved_passages]
    
    return retrieved_passages

In [6]:
def build_prompt(question: str, relevant_passages: list[str], system_prompt: str = None) -> tuple[str, str]:
    """
    Builds the prompt that will be used to synthesize the passages
    
    Args:
        question: A well formed regulatory question
        relevant_passages: A list of relevant passages that should help answer the question
        system_prompt: Optional custom system prompt. If None, uses default regulatory compliance prompt
    
    Returns:
        A tuple with both the system prompt that contains instructions on how to answer and
        the user prompt that contains the actual question and passages
    """
    
    # Default system prompt if none is provided
    default_system_prompt = ("You are a regulatory compliance assistant. Provide a **complete**, **coherent**, and"
    "**correct** response to the given question by synthesizing the information from the provided passages. "
    "Your answer should **fully integrate all relevant obligations, practices, and insights**, and directly"
    "address the question. The passages are presented in order of relevance, so **prioritize the information"
    "accordingly** and ensure consistency in your response, avoiding any contradictions. Additionally, reference"
    "**specific regulations and key compliance requirements** outlined in the regulatory content to support your"
    "answer. **Do not use any extraneous or external knowledge** outside of the provided passages when crafting"
    "your response.")
    
    
    # Use provided system prompt or fall back to default
    system_prompt = system_prompt if system_prompt is not None else default_system_prompt

    user_prompt = f"Question: {question}\n\n"
    for passage in relevant_passages:
        user_prompt += f"Passage: {passage}\n\n"
        
    return (system_prompt, user_prompt)

build_prompt("question", ["passage"])

('You are a regulatory compliance assistant. Provide a **complete**, **coherent**, and**correct** response to the given question by synthesizing the information from the provided passages. Your answer should **fully integrate all relevant obligations, practices, and insights**, and directlyaddress the question. The passages are presented in order of relevance, so **prioritize the informationaccordingly** and ensure consistency in your response, avoiding any contradictions. Additionally, reference**specific regulations and key compliance requirements** outlined in the regulatory content to support youranswer. **Do not use any extraneous or external knowledge** outside of the provided passages when craftingyour response.',
 'Question: question\n\nPassage: passage\n\n')

## Azure OpenAI - Standard deployment

First, we use Azure OpenAI standard deployment to synthesize the retrieved passages for each question using `gpt 3.5 turbo`.
Since the API has enabled rate limit for both token and requests per minute, we use a decorator to throttle the function in the client to prevent sending requests that will be blocked.

In [2]:
# Import azure open AI library to use access API-based LLMs. 
from openai import AzureOpenAI

# Import data structures used for throttling implementation for standard deployments in Azure.
from collections import (
    defaultdict,
    deque
)
from threading import Lock
import asyncio

# Import util libraries
from tqdm import tqdm
from dotenv import load_dotenv
import os
import json
import time

In [15]:
# Load environment variables to handle access to the openAI API using secrets
# Variables to be defined:
# QNA_ENDPOINT_URL: Deployment endpoint for inference/chat completion
# QNA_OPENAI_API_KEY: Key to access openAI API
#
load_dotenv()

True

In [16]:
endpoint = os.getenv('QNA_ENDPOINT_URL')
openAIKey = os.getenv('QNA_OPENAI_API_KEY')
llm_model = 'gpt-35-turbo'

if not endpoint:
    raise ValueError("The environment variable QNA_ENDPOINT_URL is not defined.")

if not openAIKey:
    raise ValueError("The environment variable QNA_OPENAI_API_KEY is not defined.")

openAI_client = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=openAIKey,
    api_version="2024-05-01-preview"
)

# This class limits the number of times a function can be called in a given time interval. It guarantees that
# the functions are called all the time, but not in the same order as they are called.
# Limits the number of times we can call a function in a given time interval. 
# Guarantees that a function call will eventually happen, but it does not necessarily respect the order in which 
# the function was called
class Throttle:
    def __init__(self, rate_limit, time_window):
        self.rate_limit = rate_limit # Max number of calls allowed in the given interval
        self.time_window = time_window # The time interval
        self.calls = deque() # Stores the function calls so we can track when to remove calls as they become stale
        self.lock = Lock() # Locks to prevent concurrent access to the function
        self.queue = asyncio.Queue() # Stores function calls that need to be awaited before being executed

    def __call__(self, func):
        async def wrapped_func(*args, **kwargs):
            # Reference the global func
            nonlocal func
            # Lock concurrent access
            with self.lock:
                current_time = time.time()
                
                # Remove function calls that are outside of the time window
                while self.calls and self.calls[0] < current_time - self.time_window:
                    self.calls.popleft()

                # If we can make a call without exceeding the rate limit, then we just call it   
                if len(self.calls) < self.rate_limit:
                    self.calls.append(current_time)
                    return await func(*args, **kwargs)
                else:
                    # Otherwise, queue the function call for later
                    await self.queue.put((func, args, kwargs))
                    
                    # Process function calls in the queue
                    while not self.queue.empty():
                        # Dequeue the function call
                        func, args, kwargs = await self.queue.get()
                        current_time = time.time()

                        # Remove function calls that are outside of the time window
                        while self.calls and self.calls[0] < current_time - self.time_window:
                            self.calls.popleft()

                        # If we can make a call without exceeding the rate limit, then we just call it
                        if len(self.calls) < self.rate_limit:
                            self.calls.append(current_time)
                            result = await func(*args, **kwargs)
                            self.queue.task_done()
                            return result
                        else:
                            # Otherwise, wait a few seconds before retrying 
                            await self.queue.put((func, args, kwargs)) # Enqueue the function call again
                            await asyncio.sleep(min(10, self.time_window))
                    
        return wrapped_func

# Allow maximum 60 calls every 70 seconds
@Throttle(rate_limit=60, time_window=70)
async def summarize_answer(question: str, relevant_passages: list[str]) -> str:
    """
    Summarizes the answer based on the provided passages
    
    Args:
        question: A well formed regulatory question
        relevant_passages: A list of relevant passages that should help answer the question
    """

    (system_prompt, user_prompt) = build_prompt(question, relevant_passages)

    # Executes the LLM API call
    completion = openAI_client.chat.completions.create(
        model=llm_model, # we are using gpt-3.5-turbo
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0.25, # Controls how deterministic the response is. Higher values result in more creativity. We want more easy-to-reproduce results
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=None,
        stream=False,
        max_tokens=800,
    )

    return completion.choices[0].message.content

ValueError: The environment variable QNA_ENDPOINT_URL is not defined.

In [10]:
answers = []

# load the test dataset
with open("ObliQADataset/ObliQA_test.json") as f:
    data = json.load(f)  # Load the JSON file
    
    # For each question:
    for e in tqdm(data):  # tqdm adds a progress bar
        query = e['Question']  # Extract the actual question
        question_id = e["QuestionID"] # Extract the question id
        
        retrieved_passages = extract_passages(question_id)

        answer = await summarize_answer(query, retrieved_passages)

        answers.append({
            "QuestionID": question_id,
            "RetrievedPassages": retrieved_passages,
            "Answer": answer
        })

# Store the results in a json File
with open("data/answers.json", "w") as f:
    json.dump(answers, f, indent=2)        

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2786/2786 [1:56:08<00:00,  2.50s/it]


## Azure OpenAI - Batch deployment

Second, we use Azure OpenAI batch deployment to synthesize the retrieved passages for each question using `gpt-4o-mini`.
We leverage the batch API to send all queries at once. The response is retrieved offline from the Azure open AI portal.

In [7]:
def queue_batch_summarization_job(jobs):
    """
    Create a batch Job in Azure open AI to generate the answers for all questions with a single request
    """
    endpoint = os.getenv('QNA_ENDPOINT_URL')
    openAIKey = os.getenv('QNA_OPENAI_API_KEY')

    if not endpoint:
        raise ValueError("No se ha definido la variable de entorno QNA_ENDPOINT_URL")

    if not openAIKey:
        raise ValueError("No se ha definido la variable de entorno QNA_OPENAI_API_KEY")

    openAI_client = AzureOpenAI(
        azure_endpoint=endpoint,
        api_key=openAIKey,
        api_version="2024-08-01-preview"
    )

    # File that can either be uploaded manually or programatically
    file_name = "data/batch_questions.jsonl"

    # Save file contents using json lines format
    with open(file_name, 'w') as file:
        for job in jobs:
            file.write(json.dumps(job) + '\n')

    # Upload the file programatically
    batch_file = openAI_client.files.create(
      file=open(file_name, "rb"),
      purpose="batch"
    )
    
    # Wait until the file upload is done
    while True:
        file = openAI_client.files.retrieve(batch_file.id)
        if file.status == "processed" or file.status == "error":
            break
        time.sleep(10)
    
    # Trigger the batch job using the uploaded file
    # Result should terminate in less than 24 hours
    batch_job = openAI_client.batches.create(
      input_file_id=batch_file.id,
      endpoint="/v1/chat/completions",
      completion_window="24h"
    )
    
    return batch_job

In [11]:
system_prompt_one_shot = ("As a regulatory compliance assistant. Provide a **complete**, **coherent**, and "
"**correct** response to the given question by synthesizing the information from the provided passages. "
"Your answer should **fully integrate all relevant obligations, practices, and insights**, and directly "
"address the question. The passages are presented in order of relevance, so **prioritize the information "
"accordingly** and ensure consistency in your response, avoiding any contradictions. Additionally, reference "
"**specific regulations and key compliance requirements** outlined in the regulatory content to support your "
"answer. **Do not use any extraneous or external knowledge** outside of the provided passages when crafting "
"your response."
"/n/nHere are a few examples."
"/n/nQuestion: What specific areas of inventory and delivery infrastructure should be covered in the independent third-party audits to satisfy the requirements of COBS Rule 22.4.2(d)?"
"/n/nPassage: REGULATORY REQUIREMENTS - SPOT COMMODITY ACTIVITIES\nDelivery & Storage\nWhen applying COBS Rule 22.4.2(d), an Authorised Person should have independent third party audits carried out at appropriate times, for the inventories and deliveries undertaken at the storage facility, as well as the facilities infrastructure itself.  Where necessary, further third-party audits will be required for the obligations of Accepted Spot Commodities, as outlined in paragraph 26 above.\n"
"/n/nPassage: REGULATORY REQUIREMENTS - SPOT COMMODITY ACTIVITIES\nDelivery & Storage\nPursuant to COBS Rule 22.4.1, a delivery and/or storage facility used by an Authorised Person can be operated from within ADGM or outside ADGM.  Specifically, for the purposes of COBS Rules 22.4.1, an Authorised Person will need to submit to the FSRA the details of how each delivery and storage facility that it proposes to use, whether located inside or outside ADGM, meets the requirements set out in Rule 22.4.2(a) to (e).\n"
"/n/nYour response should read:"
"/n/nTo satisfy the requirements of COBS Rule 22.4.2(d) for independent third-party audits, an Authorised Person should ensure that the audits cover the inventories and deliveries undertaken at the storage facility, as well as the infrastructure of the facility itself. Additionally, if the Authorised Person deals with Accepted Spot Commodities, further third-party audits will be necessary to fulfill their obligations as outlined in paragraph 26. As per COBS Rule 22.4.1, the delivery and storage facility used by the Authorised Person can be located within or outside ADGM, and the Authorised Person must submit details to the FSRA on how each facility meets the requirements set out in Rule 22.4.2(a) to (e). Therefore, the independent third-party audits should cover the areas of inventory, delivery, and infrastructure of the storage facility, as well as any obligations related to Accepted Spot Commodities and compliance with the requirements set out in COBS Rule 22.4.2(a) to (e)"
"/n/nQuestion: What percentage of the Insurer's Net Written Premium is used to determine the non-proportional reinsurance element?"
"/n/nPassage: The non proportional reinsurance element is calculated as 52% of the Insurer's Net Written Premium"
"/n/nYour response should read:"
"/n/nThe non-proportional reinsurance element is determined by calculating 52% of the Insurer's Net Written Premium."
"/n/nQuestion: Who is responsible for ensuring compliance with the obligations that apply to the Reporting Entity of a Fund under the provisions of this chapter, unless explicitly stated otherwise?"
"/n/nPassage: Where an obligation applies to a Reporting Entity of a Fund under a provision of this chapter, except where expressly provided otherwise, the Governing Body of the Listed Fund must ensure compliance with that obligation."
"/n/nYour response should read:"
"/n/nThe responsibility for ensuring compliance with the obligations that apply to the Reporting Entity of a Fund under the provisions of this chapter lies with the Governing Body of the Listed Fund. This is explicitly stated in the passage, which indicates that unless otherwise specified, it is the Governing Body that must ensure adherence to these obligations.")

In [16]:
jobs = []

# Load the test dataset
with open("ObliQADataset/ObliQA_test.json") as f:
    data = json.load(f)  # Load the JSON file
    
    # For each question:
    for e in tqdm(data):  # tqdm adds a progress bar
        query = e['Question']  # Extract the actual question
        question_id = e["QuestionID"] # Extract the question id

        retrieved_passages = extract_passages(question_id)

        (system_prompt, user_prompt) = build_prompt(query, retrieved_passages, system_prompt_one_shot)
        
        jobs.append({
            "custom_id": question_id,
            "method": "POST",
            "url": "/chat/completions",
            "body": {
                "model": "gpt-4o",
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                "temperature": 0,
                "frequency_penalty": 0.0,
                "presence_penalty": 0.0,
                "max_tokens": 1000,
            }
        })
        
batch_job = queue_batch_summarization_job(jobs)

100%|████████████████████████████████████| 2786/2786 [00:00<00:00, 92037.29it/s]


In [17]:
# At this point the result has been downloaded offline and uploaded to the data folder
answers = []

with open("data/batch_result_4o.jsonl") as f:
    # Parse each line of the file as a JSON
    results = [json.loads(line) for line in f]
    
    for result in results:
        # For each result, create an entry in answers array to later create the output file
        question_id = result["custom_id"]
        # Since this function is deterministic, we can just call it again
        # Clearly, there's an optimization we can do to avoid calling this twice
        retrieved_passages = extract_passages(question_id) 
        answer = result["response"]["body"]["choices"][0]["message"]["content"]
        
        answers.append({
            "QuestionID": question_id,
            "RetrievedPassages": retrieved_passages,
            "Answer": answer
        })
        
# Save the results as a JSON file
with open("data/answers-4o-new.json", "w") as f:
    json.dump(answers, f, indent=2)

## Groq API - Regular deployment

Third, we use Groq's API to synthesize the retrieved passages for each question using `llama-3.1-70b-versatile`. We leverage Groq's high-performance infrastructure to process queries with minimal latency.

In [None]:
# Standard library imports
import json
import os
import time
import asyncio
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from collections import defaultdict
from contextlib import asynccontextmanager

# Third-party imports
from tqdm.asyncio import tqdm_asyncio
from dotenv import load_dotenv
from groq import Groq
import backoff
import nest_asyncio

# Configure minimal logging
logging.basicConfig(
    level=logging.WARNING,
    format='%(message)s'
)
logger = logging.getLogger(__name__)
logging.getLogger("httpx").setLevel(logging.WARNING)

@dataclass
class Question:
    """Represents a question with its ID and content"""
    id: str
    text: str

@dataclass
class ProcessedAnswer:
    """Represents a processed answer with metadata"""
    question_id: str
    retrieved_passages: List[str]
    answer: str
    error: Optional[str] = None

class GroqRateLimiter:
    def __init__(self):
        # Rate limits
        self.tokens_per_minute = 30000
        self.requests_per_minute = 1000
        self.requests_per_day = 50000
        
        # Window durations
        self.minute_window = timedelta(minutes=1)
        self.day_window = timedelta(days=1)
        
        # Track requests with timestamps and token counts
        self.requests = []  # List of (timestamp, tokens) tuples
        self.lock = asyncio.Lock()

    async def wait_for_tokens(self, tokens_needed: int):
        async with self.lock:
            while True:
                now = datetime.now()
                minute_start = now - self.minute_window
                day_start = now - self.day_window
                
                # Remove expired requests
                self.requests = [(ts, tokens) for ts, tokens in self.requests 
                               if ts > day_start]
                
                # Calculate current usage
                minute_requests = [(ts, tokens) for ts, tokens in self.requests 
                                 if ts > minute_start]
                
                minute_request_count = len(minute_requests)
                day_request_count = len(self.requests)
                minute_token_usage = sum(tokens for _, tokens in minute_requests)
                
                # Check all limits
                if (minute_token_usage + tokens_needed <= self.tokens_per_minute and
                    minute_request_count < self.requests_per_minute and
                    day_request_count < self.requests_per_day):
                    # Add new request
                    self.requests.append((now, tokens_needed))
                    return True
                
                # Calculate wait time based on the most restrictive limit
                wait_times = []
                
                # Token limit check
                if minute_token_usage + tokens_needed > self.tokens_per_minute and minute_requests:
                    oldest_in_minute = min(ts for ts, _ in minute_requests)
                    wait_times.append((oldest_in_minute + self.minute_window - now).total_seconds())
                
                # Requests per minute check
                if minute_request_count >= self.requests_per_minute and minute_requests:
                    oldest_in_minute = min(ts for ts, _ in minute_requests)
                    wait_times.append((oldest_in_minute + self.minute_window - now).total_seconds())
                
                # Requests per day check
                if day_request_count >= self.requests_per_day and self.requests:
                    oldest_in_day = min(ts for ts, _ in self.requests)
                    wait_times.append((oldest_in_day + self.day_window - now).total_seconds())
                
                # Wait for the shortest required time
                if wait_times:
                    wait_time = max(0.1, min(wait_times))  # Ensure minimum wait of 0.1s
                    print(f"Rate limit reached. Waiting {wait_time:.2f} seconds...")
                    await asyncio.sleep(wait_time)
                else:
                    # If no wait times calculated but still hitting limits, wait a small amount
                    await asyncio.sleep(0.1)

class GroqProcessor:
    """Handles processing of questions using Groq's API with rate limiting"""
    
    def __init__(self):
        self.groq_client = self._initialize_groq()
        self.rate_limiter = GroqRateLimiter()
        
    def _initialize_groq(self) -> Groq:
        """Initialize Groq client with error handling"""
        groq_api_key = os.getenv('GROQ_API_KEY')
        if not groq_api_key:
            raise ValueError("GROQ_API_KEY environment variable is not defined")
        return Groq(api_key=groq_api_key)

    @staticmethod
    def extract_passages(question_id: str, passages: dict, rankings_dict: dict) -> list[str]:
        """Extract relevant passages for a question"""
        retrieved_passages = []
        should_stop = False
        
        for i in range(len(rankings_dict[question_id])):
            if should_stop or len(retrieved_passages) == 10:
                break
                
            if len(retrieved_passages) == 0:
                retrieved_passages.append(rankings_dict[question_id][i]["doc"])
                continue
                    
            if i < len(rankings_dict[question_id]) - 1 and rankings_dict[question_id][i]["score"] - rankings_dict[question_id][i+1]["score"] > 0.1:
                    should_stop = True

            if rankings_dict[question_id][i]["score"] < 0.72:
                break

            retrieved_passages.append(rankings_dict[question_id][i]["doc"])
            
        return [passages[doc] for doc in retrieved_passages]

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=5,
        giveup=lambda e: "rate limit" not in str(e).lower()
    )
    async def process_question(self, question: str, passages: List[str]) -> str:
        """Process a single question with rate limiting"""
        system_prompt = """As a regulatory compliance assistant, analyze the provided passages
        and answer the accompanying question. Synthesize information from all passages, which
        are presented in order of relevance.

        Key Requirements:
        1. Extract and prioritize mandatory requirements ("shall" statements)
        2. Distinguish between required vs recommended practices
        3. Identify specific deadlines and documentation needs
        4. Use ONLY information from provided passages - no external knowledge
        5. Integrate all obligations and requirements to ensure consistency

        Structure your response as follows:
        1. Brief executive summary (2-3 sentences)
        2. Core requirements with citations [Section XX]
        3. Implementation steps and timeline
        4. Documentation requirements
           - Required reports
           - Retention periods
        5. Special considerations
           - Exemptions
           - Jurisdictional variations
           - Implementation challenges
           - Potential contradictions or conflicts

        Guidelines:
        - Keep responses under 500 words unless complexity requires more
        - Use bullet points for clarity
        - Highlight critical deadlines
        - Flag any ambiguities or conflicts between requirements
        - Process passages in order of presentation
        - Ensure full integration of all relevant requirements
        - Maintain consistency across all recommendations
        """

        user_prompt = f"Question: {question}\n\nPassages:\n\n" + "\n\n".join(passages)
        
        # Estimate tokens (rough approximation)
        estimated_tokens = len(system_prompt.split()) + len(user_prompt.split()) + 800
        
        # Wait for available tokens
        await self.rate_limiter.wait_for_tokens(estimated_tokens)
        
        try:
            completion = await asyncio.to_thread(
                self.groq_client.chat.completions.create,
                model="llama-3.1-70b-versatile",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0.25,
                max_tokens=800,
                top_p=1,
                stream=False
            )
            return completion.choices[0].message.content
            
        except Exception as e:
            if "rate limit" in str(e).lower():
                retry_after = self._extract_retry_after(str(e))
                print(f"Rate limit hit, waiting {retry_after} seconds...")
                await asyncio.sleep(retry_after + 1)
                return await self.process_question(question, passages)
            raise

    @staticmethod
    def _extract_retry_after(error_message: str) -> float:
        """Extract retry-after time from error message"""
        try:
            if "try again in" in error_message:
                time_str = error_message.split("try again in")[1].split("s")[0].strip()
                if "m" in time_str:
                    minutes, seconds = time_str.split("m")
                    return float(minutes) * 60 + float(seconds)
                return float(time_str)
        except Exception:
            pass
        return 60  # Default to 60 seconds if we can't parse the time

async def main():
    # Load environment variables
    load_dotenv()
    
    # Initialize processor
    processor = GroqProcessor()
    answers = []
    
    try:
        # Load necessary data
        print("Loading passages...")
        passages = {}
        for i in range(1, 41):
            with open(os.path.join("ObliQADataset/StructuredRegulatoryDocuments", f"{i}.json")) as f:
                doc = json.load(f)
                for psg in doc:
                    passages[psg["ID"]] = psg["Passage"]

        print("Loading rankings...")
        rankings_dict = defaultdict(list)
        with open('data/rankings_hybrid.trec', 'r') as f:
            for line in f:
                parts = line.strip().split()
                rankings_dict[parts[0]].append({
                    'doc': parts[2],
                    'score': float(parts[4])
                })

        # Process questions
        print("Processing questions...")
        with open("ObliQADataset/ObliQA_test.json") as f:
            questions = json.load(f)
            
            for q in tqdm_asyncio(questions):
                try:
                    retrieved_passages = processor.extract_passages(
                        q["QuestionID"], 
                        passages, 
                        rankings_dict
                    )
                    
                    answer = await processor.process_question(
                        q["Question"], 
                        retrieved_passages
                    )
                    
                    answers.append({
                        "QuestionID": q["QuestionID"],
                        "RetrievedPassages": retrieved_passages,
                        "Answer": answer
                    })
                    
                    # Save progress every 10 questions
                    if len(answers) % 10 == 0:
                        with open("data/answers-llama3.1.json", "w") as f:
                            json.dump(answers, f, indent=2)
                            
                except Exception as e:
                    print(f"Error processing question {q['QuestionID']}: {e}")
                    # Save progress on error
                    if answers:
                        with open("data/answers-llama3.1-partial.json", "w") as f:
                            json.dump(answers, f, indent=2)

        # Save final results
        print("Saving final results...")
        with open("data/answers-llama3.1.json", "w") as f:
            json.dump(answers, f, indent=2)
            
        print("Processing complete!")
        
    except Exception as e:
        print(f"Fatal error during processing: {e}")
        # Save partial results if available
        if answers:
            with open("data/answers-llama3.1-partial.json", "w") as f:
                json.dump(answers, f, indent=2)
        raise

# Run the processor
nest_asyncio.apply()
await main()

## Results Evaluation

To evaluate and compare the results obtained from our three different processing methods:
1. Standard Azure OpenAI deployment with GPT-3.5-Turbo
2. Batch Azure OpenAI deployment with GPT-4O-Mini
3. Groq deployment with Llama-3.1-70B-Versatile

Run the following scripts using the RePASs virtual environment to evaluate each model's performance. Make sure you have activated the correct environment before running these commands.

In [15]:
## Script to evaluate the results. Results are placed in /RePASs/data/hybrid or /RePASs/data/hybrid-4o
## These scripts must be run using the virtual env in RePASs

#python scripts/evaluate_model.py --input_file ./../data/answers.json --group_method_name hybrid
#python scripts/evaluate_model.py --input_file ./../data/answers-4o.json --group_method_name hybrid-4o
#python scripts/evaluate_model.py --input_file ./../data/answers-llama3.1.json --group_method_name hybrid-llama

## Unseen questions

In [18]:
rankings_dict_unseen = defaultdict(list) # Maps a question to the relevant passage and its corresponding ranking score

# Load the rankings file in memory
with open('data/rankings_hybrid_unseen_test.trec', 'r') as f:
    # File format: QuestionID Q0 DocumentID Rank Score Method
    for line in f:
        parts = line.strip().split()
        question_id = parts[0]
        document_id = parts[2]
        rank = int(parts[3])
        score = float(parts[4])
        rankings_dict_unseen[question_id].append({
            'doc': document_id,
            'score': score
        })

In [22]:
jobs = []

# Load the test dataset
with open("ObliQADataset/RIRAG_Unseen_Questions.json") as f:
    data = json.load(f)  # Load the JSON file
    
    # For each question:
    for e in tqdm(data):  # tqdm adds a progress bar
        query = e['Question']  # Extract the actual question
        question_id = e["QuestionID"] # Extract the question id

        retrieved_passages = extract_passages(question_id, rankings_dict_unseen)

        (system_prompt, user_prompt) = build_prompt(query, retrieved_passages, system_prompt_one_shot)
        
        jobs.append({
            "custom_id": question_id,
            "method": "POST",
            "url": "/chat/completions",
            "body": {
                "model": "gpt-35-turbo-2",
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                "temperature": 0,
                "frequency_penalty": 0.0,
                "presence_penalty": 0.0,
                "max_tokens": 1000,
            }
        })
        
batch_job = queue_batch_summarization_job(jobs)

100%|██████████████████████████████████████| 446/446 [00:00<00:00, 56871.05it/s]


In [23]:
questions = {}

# Load the test dataset
with open("ObliQADataset/RIRAG_Unseen_Questions.json") as f:
    data = json.load(f)  # Load the JSON file
    
    # For each question:
    for e in tqdm(data):  # tqdm adds a progress bar
        query = e['Question']  # Extract the actual question
        question_id = e["QuestionID"] # Extract the question id
        
        questions[question_id] = query

100%|█████████████████████████████████████| 446/446 [00:00<00:00, 955881.24it/s]


In [25]:
# At this point the result has been downloaded offline and uploaded to the data folder
answers = []

with open("data/batch_result_unseen.jsonl") as f:
    # Parse each line of the file as a JSON
    results = [json.loads(line) for line in f]
    
    for result in results:
        # For each result, create an entry in answers array to later create the output file
        question_id = result["custom_id"]
        # Since this function is deterministic, we can just call it again
        # Clearly, there's an optimization we can do to avoid calling this twice
        retrieved_passages = extract_passages(question_id, rankings_dict_unseen) 
        answer = result["response"]["body"]["choices"][0]["message"]["content"]
        
        answers.append({
            "QuestionID": question_id,
            "Question": questions[question_id],
            "RetrievedPassages": retrieved_passages,
            "Answer": answer
        })
        
# Save the results as a JSON file
with open("data/answers-unseen.json", "w") as f:
    json.dump(answers, f, indent=2)