In [None]:
import pyterrier as pt

from dotenv import dotenv_values
from multiprocessing import Manager
from tqdm import tqdm
from tabulate import tabulate

from tqdm import tqdm
import concurrent.futures

from tabulate import tabulate
import pandas as pd
from sqlalchemy import create_engine

from sqlalchemy.orm import Session

from langchain import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_openai import ChatOpenAI

import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from tqdm import tqdm
from tabulate import tabulate
import pandas as pd

import os
import json
import glob

# Initialise PyTerrier only if it has not been started yet, so re-running this cell is harmless.
if not pt.started():
    pt.init()

# Create a multiprocessing manager.
# NOTE(review): no cell visible in this notebook reads `manager` — confirm it is still needed.
manager = Manager()

# Load key/value settings from the project's .env file; a later cell reads the
# USER, PASSWORD, ADDRESS, PORT and DB_FINAL keys to build the database URL.
db_vals = dotenv_values("/workspaces/CORD19_Plus/.env")
from cord19_plus.data_model.model import Table

In [None]:
# Load the 'irds:cord19/fulltext/trec-covid' dataset through PyTerrier.
dataset = pt.get_dataset('irds:cord19/fulltext/trec-covid')
# Topics feed the labeling prompts; the labeling functions below read the
# 'qid', 'title', 'description' and 'narrative' fields of each topic row.
topics = dataset.get_topics()
# Relevance judgments of the dataset.
# NOTE(review): `qrels` is not referenced by any cell visible here.
qrels = dataset.get_qrels()

In [None]:
# Define the tables to label: load every Table row from the database named by DB_FINAL.
# NOTE(review): the credentials are interpolated verbatim, so a password containing
# characters such as '@' or '/' would produce a malformed URL.
engine = create_engine(f"postgresql+psycopg2://{db_vals['USER']}:{db_vals['PASSWORD']}@{db_vals['ADDRESS']}:{db_vals['PORT']}/{db_vals['DB_FINAL']}", echo=False)
session = Session(engine)

try:
    result = session.query(Table)
    # Copy the needed attributes into plain dicts while the session is open,
    # so the labeling code never touches live ORM objects.
    result_dict = [{"docno" : str(e.ir_tab_id),
                    "ir_id" : e.ir_id,
                    "pm_content" : e.pm_content,
                    "header" : e.header,
                    "content" : e.content,
                    "caption" : e.caption,
                    "references" : e.references,
                   } for e in result]
finally:
    # Release the database connection even if the query fails.
    session.close()



In [None]:
# Sanity check: count the distinct docnos among the loaded tables.
# Comparing this number with len(ids) reveals duplicate docnos.
ids = [res['docno'] for res in result_dict]
len(set(ids))

In [None]:
# Collect the prompt files and order them by the integer that follows the last "_" in each path,
# so that e.g. 10 sorts after 9 (a plain string sort would not guarantee that).
# NOTE(review): int() raises ValueError for any path whose part after the last "_" is not
# purely numeric (for instance when a file extension is present).
prompt_paths = sorted(glob.glob("/workspaces/CORD19_Plus/data/prompts/*"), key=lambda x: int(x.split("_")[-1]))

In [None]:
# Display the ordered prompt paths for a quick visual check.
prompt_paths

In [None]:
# Read every prompt template into memory, in the same order as prompt_paths.
prompts = []

for path in prompt_paths:
    # Explicit UTF-8 (as used for the result files in this notebook) keeps decoding
    # independent of the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as file:
        prompts.append(file.read())

In [None]:
# Display the loaded prompt templates for inspection.
prompts

In [None]:
# Labeling configuration: choose the prompt template and the chat model, then build the chain.

# 1-based position of the template inside `prompts`.
prompt_version = 5
# Guard the lookup: without it, 0 would silently select the *last* prompt through
# negative indexing instead of failing.
if not 1 <= prompt_version <= len(prompts):
    raise IndexError(
        f"prompt_version {prompt_version} is out of range; {len(prompts)} prompt(s) loaded"
    )
prompt_t = prompts[prompt_version-1]


model_name = "gpt-4o-mini"
#model_name = "gpt-4o"

# The template placeholders are filled per (topic, table) pair by handle_request.
prompt = PromptTemplate(template=prompt_t, input_variables=["QUERY",
                                                            "QUESTION",
                                                            "NARRATIVE",
                                                            "CONTENT",
                                                            "CAPTION",
                                                            "REFERENCES"])
#temperature = 0 to suppress creativity
model = ChatOpenAI(temperature=0, model=model_name)
output_parser = StrOutputParser()

# prompt -> model -> parser; the parser is the last step, so handle_request
# receives the chain's output and converts it with int().
chain = prompt | model | output_parser

# `manager` is already created in the setup cell; it is not re-created here.

In [None]:

# Assuming 'topics' is a pandas DataFrame with 'qid', 'title', 'description', and 'narrative' columns
# 'result_dict' is a list of dictionaries with keys like 'docno', 'content', 'header', 'caption', 'references'
# 'pool' is a pandas DataFrame with 'qid' and 'docno' columns
# 'chain' is an object with an 'invoke' method that processes input_data

def handle_request(topic, table, shared_qrels, lock):
    """
    Label one table for one topic with the LLM chain and record the label.

    The prompt inputs are built from the topic and table fields, the
    module-level ``chain`` is invoked, and ``(docno, label)`` is appended to
    ``shared_qrels[topic['qid']]``. Any failure (bad table data, chain error,
    non-integer answer) is printed and the table is skipped; nothing is raised.

    Args:
        topic (pd.Series): A row from the topics DataFrame; 'qid', 'title',
            'description' and 'narrative' are read.
        table (dict): Document information; 'docno', 'content', 'header',
            'caption' and 'references' are read. 'references' may be None or empty.
        shared_qrels (dict): Shared results; must already contain a list under
            ``topic['qid']``.
        lock (threading.Lock): Guards updates to ``shared_qrels``.
    """
    try:
        # Building the inputs happens inside the try block so that malformed
        # table data is reported here instead of surfacing in the caller.
        # A table without references is rendered as an empty string.
        references = table['references'] or []
        input_data = {
            "QUERY": topic['title'],
            "QUESTION": topic['description'],
            "NARRATIVE": topic['narrative'],
            "CONTENT": tabulate(table['content'], headers=table['header'], tablefmt="github"),
            "CAPTION": table['caption'],
            "REFERENCES": " | ".join(references)
        }

        # Invoke the chain's method to process the input_data
        res = chain.invoke(input_data)

        # int() tolerates surrounding whitespace; any other non-numeric answer
        # raises ValueError and is reported below.
        label = int(res)

        # Safely update the shared_qrels dictionary
        with lock:
            shared_qrels[topic['qid']].append((table['docno'], label))
    except Exception as e:
        # Best-effort labeling: report the failure and move on to the next table
        print(f"Error processing docno {table['docno']} in topic {topic['qid']}: {e}")

def process_topic(topic, result_dict, shared_qrels, pool, lock, results_dir):
    """
    Label all pooled tables of one topic in parallel and save the labels as JSON.

    If ``qid_<qid>.json`` already exists in ``results_dir`` the topic is
    skipped. Otherwise ``shared_qrels[qid]`` is reset to an empty list, each
    selected table is passed to ``handle_request`` on a thread pool, and the
    collected labels are written to the result file. The file is written even
    when some tables failed; a warning is printed in that case, because the
    existing file makes later runs skip the topic.

    Args:
        topic (pd.Series): A row from the topics DataFrame. ``topic['qid']`` must
            be convertible with ``int()`` whenever ``pool`` is non-empty.
        result_dict (list): Document dictionaries; each needs a 'docno' key.
        shared_qrels (dict): Shared results, updated in place.
        pool (pd.DataFrame): 'qid' and 'docno' columns used to select the tables
            of this topic. An empty DataFrame selects every table in ``result_dict``.
        lock (threading.Lock): Guards updates to ``shared_qrels``.
        results_dir (str): Directory for the per-topic result files.
    """
    qid = topic['qid']
    result_file = os.path.join(results_dir, f"qid_{qid}.json")

    # Check if the result file already exists to avoid reprocessing
    if os.path.exists(result_file):
        print(f"Results for qid {qid} already exist. Skipping processing.")
        return

    # Initialize the list for this topic
    with lock:
        shared_qrels[qid] = []

    # Filter documents based on the pool if provided
    if not pool.empty:
        # A set gives constant-time membership tests for the filter below
        pool_topic_ids = set(pool[pool['qid'] == int(qid)]['docno'].to_list())
        tables_to_process = [table for table in result_dict if table['docno'] in pool_topic_ids]
    else:
        tables_to_process = result_dict

    if not tables_to_process:
        # No documents to process for this topic
        print(f"No documents to process for qid {qid}.")
        return

    # Use a ThreadPoolExecutor to process documents in parallel
    with ThreadPoolExecutor(max_workers=5) as executor:  # Adjust max_workers as needed
        futures = [
            executor.submit(handle_request, topic, table, shared_qrels, lock)
            for table in tables_to_process
        ]

        # Ensure all document processing tasks complete
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()  # Retrieve result to catch exceptions
            except Exception as e:
                # handle_request reports its own failures; anything that still
                # escapes is reported here rather than dropped silently.
                print(f"Unhandled error while processing a document for qid {qid}: {e}")

    # After processing all documents for the topic, save the results to a file
    try:
        with lock:
            # Snapshot the list so serialisation does not depend on later mutation
            data_to_save = list(shared_qrels[qid])
        if len(data_to_save) < len(tables_to_process):
            print(
                f"Warning: only {len(data_to_save)} of {len(tables_to_process)} "
                f"documents were labeled for qid {qid}."
            )
        # Make the function usable on its own, without a prior directory setup
        os.makedirs(results_dir, exist_ok=True)
        with open(result_file, 'w', encoding='utf-8') as f:
            json.dump(data_to_save, f, ensure_ascii=False, indent=4)
        print(f"Results for qid {qid} saved successfully.")
    except Exception as e:
        print(f"Error saving results for qid {qid}: {e}")

def load_existing_results(results_dir):
    """
    Read previously saved per-topic result files back into memory.

    When ``results_dir`` does not exist it is created and an empty dictionary
    is returned. Otherwise every file named ``qid_<n>.json`` is parsed; files
    that cannot be read or whose ``<n>`` is not an integer are reported on
    stdout and skipped.

    Args:
        results_dir (str): Directory path where results are saved.

    Returns:
        dict: Maps the integer qid to a list of tuples, one per saved entry.
    """
    prefix, suffix = "qid_", ".json"
    loaded = {}

    if not os.path.exists(results_dir):
        os.makedirs(results_dir)
        return loaded

    for entry in os.listdir(results_dir):
        if not (entry.startswith(prefix) and entry.endswith(suffix)):
            continue
        # The qid is whatever sits between the fixed prefix and suffix
        raw_qid = entry[len(prefix):-len(suffix)]
        try:
            key = int(raw_qid)
            with open(os.path.join(results_dir, entry), 'r', encoding='utf-8') as handle:
                payload = json.load(handle)
            # JSON stores the pairs as lists; turn them back into tuples
            loaded[key] = list(map(tuple, payload))
        except Exception as e:
            print(f"Error loading {entry}: {e}")

    return loaded

def label_full_pool(pool, topics, result_dict, results_dir):
    """
    Label the pooled tables of every topic, resuming from previously saved results.

    Results already present in ``results_dir`` are loaded first and their
    topics are not submitted again, so the progress bar ends exactly at the
    number of topics. Loaded results are re-keyed to the qid representation
    used in ``topics`` so that a resumed run returns the same key type as a
    fresh run. The remaining topics are handled by ``process_topic`` on a
    thread pool, which saves one result file per topic.

    Args:
        pool (pd.DataFrame): A DataFrame containing 'qid' and 'docno' columns for filtering.
        topics (pd.DataFrame): Topic information with 'qid', 'title', 'description' and 'narrative'.
        result_dict (list): A list of document dictionaries to process.
        results_dir (str): Directory path where results are saved (required, no default).

    Returns:
        dict: Maps each qid to a list of (docno, score) tuples. Loaded results
        whose qid does not occur in ``topics`` keep their integer key.
    """
    loaded_qrels = load_existing_results(results_dir)
    lock = Lock()  # To ensure thread-safe operations on shared_qrels

    # load_existing_results keys by int(qid) while process_topic keys by
    # topic['qid'] as-is; align the loaded keys with the topics' representation.
    qid_by_text = {str(qid): qid for qid in topics['qid'].tolist()}
    shared_qrels = {
        qid_by_text.get(str(key), key): docs
        for key, docs in loaded_qrels.items()
    }

    # Only topics without loaded results still need work. A topic whose result
    # file exists but could not be loaded stays pending; process_topic skips it.
    pending_topics = [
        topic for _, topic in topics.iterrows()
        if topic['qid'] not in shared_qrels
    ]

    num_topics = len(topics)
    remaining_topics = len(pending_topics)
    processed_topics = num_topics - remaining_topics

    print(f"Total topics: {num_topics}")
    print(f"Already processed topics: {processed_topics}")
    print(f"Remaining topics to process: {remaining_topics}")

    # Initialize the overall progress bar for topics
    with tqdm(total=num_topics, desc="Overall Progress") as overall_pbar:
        # Account for the topics that were loaded from disk
        overall_pbar.update(processed_topics)

        with ThreadPoolExecutor(max_workers=10) as executor:  # Adjust max_workers based on your system
            # Submit the pending topic processing tasks
            futures = [
                executor.submit(
                    process_topic,
                    topic,
                    result_dict,
                    shared_qrels,
                    pool,
                    lock,
                    results_dir
                )
                for topic in pending_topics
            ]

            # Iterate over the completed futures to update the progress bar
            for future in concurrent.futures.as_completed(futures):
                try:
                    future.result()  # Retrieve result to catch exceptions
                except Exception as e:
                    # Keep going: one failing topic must not abort the others
                    print(f"Error processing topic: {e}")
                finally:
                    overall_pbar.update(1)  # Update the overall progress bar after each topic

    return shared_qrels

In [None]:
# Load the retrieval pool. Later cells read its 'qid', 'docno' and 'rank' columns.
pool = pd.read_json("/workspaces/CORD19_Plus/retrieval/pool.json")
# Display the pool for inspection.
pool

In [None]:
# Keep only the pool entries whose rank is at most 5.
# NOTE(review): `top5_pool` is only displayed in this notebook; the labeling run uses the full pool.
top5_pool = pool[pool['rank'] <= 5]

In [None]:
# Display the rank-filtered pool.
top5_pool

In [None]:
# Re-read the pool so this cell does not depend on the inspection cells above.
pool = pd.read_json("/workspaces/CORD19_Plus/retrieval/pool.json")
# Run the labeling over the full pool. The results directory is named after prompt_version;
# topics that already have a result file in it are skipped by the labeling code.
pool_qrels = label_full_pool(pool, topics, result_dict=result_dict, results_dir=f'/workspaces/CORD19_Plus/data/labeling/labeling_results{prompt_version}-final')

In [None]:
# Flatten {qid: [(docno, label), ...]} into one record per (qid, docno) pair.
rows = [
    {'qid': qid, 'docno': docno, 'label': label}
    for qid, docs in pool_qrels.items()
    for docno, label in docs
]

# Naming the columns keeps the qrels schema stable even when no labels were collected;
# a DataFrame built from a list already has a fresh 0..n-1 index.
table_qrels = pd.DataFrame(rows, columns=['qid', 'docno', 'label'])

In [None]:
# Display the flattened qrels table.
table_qrels

In [None]:
# Export the labels as a JSON list of records; the file name carries prompt_version.
qrels_path = f"/workspaces/CORD19_Plus/data/labeling/table_pool_qrels{prompt_version}-final.json"
table_qrels.to_json(qrels_path, orient="records")