In [49]:
import pandas as pd
import glob
from collections import defaultdict
import openai
import requests

from tqdm import tqdm
import gzip
import json
import tarfile
from random import randrange
import shutil

## Download Passages Collection

### MSMARCO Collection v1

In [7]:
# ** Should add a code to check to downlaod if and only if the file is not exsit

# url = 'https://msmarco.z22.web.core.windows.net/msmarcoranking/collection.tar.gz'
# target_path = 'msmarco-collections/collection.tar'

# response = requests.get(url, stream=True)
# if response.status_code == 200:
#     with open(target_path, 'wb') as f:
#         f.write(response.raw.read())

In [8]:
# with tarfile.open('msmarco-collections/collection.tar', 'r') as tar:
#         tar.extractall()

## Loading TREC DL Files

In [13]:
YEAR = 2022
nist_qrel_path = f"./TREC-DL-{YEAR}/qrels-pass.txt"
results_path = f"./TREC-DL-{YEAR}/passages-runs/"
test_queries_path = f"./TREC-DL-{YEAR}/test-queries.tsv"

if YEAR in [2019, 2020, 2021]:
    results_sep = '\t'
    passages_path = "msmarco-collections/collection.tsv"
else:
    results_sep = ' '
    passages_path =  "msmarco-collections/msmarco_v2_passage"

## Loading Files

In [14]:
# loading nist qrel file
nist_qrel = pd.read_csv(nist_qrel_path, sep=' ', header=None, names=['qid', 'Q0', 'docid', 'rel'])
nist_qrel.shape

(386416, 4)

In [16]:
# number of passages per query
nist_qrel.groupby('qid')['docid'].agg('count').reset_index()

Unnamed: 0,qid,docid
0,2000511,1492
1,2000719,42710
2,2001532,369
3,2001908,276
4,2001975,406
...,...,...
71,2055480,329
72,2055634,292
73,2055795,387
74,2056158,535


In [17]:
# maximum number of passages for the query
max(nist_qrel.groupby('qid')['docid'].agg('count').reset_index()['docid'])

49627

In [19]:
# number of judged queries by NIST
nist_judged_qids = set(nist_qrel['qid'])
len(nist_judged_qids)

76

## Creating Depth-10 Pooling

In [20]:
# loading all submissions
run_df_list = []
for infile in tqdm(glob.glob(f'{results_path}/*')):
    run_df = pd.read_csv(infile, sep='\s+', header=None, names=['qid', 'Q0', 'docid', 'rank', 'score', 'run_id'])
    run_df_list.append(run_df)
 
all_submissions_df = pd.concat(run_df_list)
all_submissions_df

100%|██████████| 100/100 [00:03<00:00, 27.31it/s]


Unnamed: 0,qid,Q0,docid,rank,score,run_id
0,1006826,Q0,msmarco_passage_07_865649873,1,0.094518,6systems
1,1006826,Q0,msmarco_passage_44_338456978,2,0.092599,6systems
2,1006826,Q0,msmarco_passage_07_865676760,3,0.091848,6systems
3,1006826,Q0,msmarco_passage_17_149088917,4,0.089694,6systems
4,1006826,Q0,msmarco_passage_07_865666439,5,0.089393,6systems
...,...,...,...,...,...,...
49995,997700,Q0,msmarco_passage_20_141207184,96,0.492810,SPLADE_ENSEMBLE_PP
49996,997700,Q0,msmarco_passage_04_584378445,97,0.488177,SPLADE_ENSEMBLE_PP
49997,997700,Q0,msmarco_passage_39_570179944,98,0.487391,SPLADE_ENSEMBLE_PP
49998,997700,Q0,msmarco_passage_20_142671399,99,0.487260,SPLADE_ENSEMBLE_PP


In [21]:
# removing the queries that are judged by nist assessors
all_submissions_df = all_submissions_df[~all_submissions_df['qid'].isin(nist_judged_qids)]
all_submissions_df.shape

(4133806, 6)

In [22]:
# no. of submission and no. of unjudged quereis
print(len(set(all_submissions_df['run_id'])))
print(len(set(all_submissions_df['qid'])))

100
424


In [23]:
# just to be sure that judged queries are not included.
set(all_submissions_df['qid']).intersection(nist_judged_qids)

set()

In [24]:
# creating 10-depth pool based on the submissions rank of passages for each query
depth_pool_samples = all_submissions_df[all_submissions_df['rank'].between(1, 10)]
depth_pool_samples

Unnamed: 0,qid,Q0,docid,rank,score,run_id
0,1006826,Q0,msmarco_passage_07_865649873,1,0.094518,6systems
1,1006826,Q0,msmarco_passage_44_338456978,2,0.092599,6systems
2,1006826,Q0,msmarco_passage_07_865676760,3,0.091848,6systems
3,1006826,Q0,msmarco_passage_17_149088917,4,0.089694,6systems
4,1006826,Q0,msmarco_passage_07_865666439,5,0.089393,6systems
...,...,...,...,...,...,...
49905,997700,Q0,msmarco_passage_18_538727342,6,1.695422,SPLADE_ENSEMBLE_PP
49906,997700,Q0,msmarco_passage_39_454078645,7,1.659401,SPLADE_ENSEMBLE_PP
49907,997700,Q0,msmarco_passage_58_304621249,8,1.649428,SPLADE_ENSEMBLE_PP
49908,997700,Q0,msmarco_passage_58_304622289,9,1.647338,SPLADE_ENSEMBLE_PP


In [25]:
# a dict from qid to docids
qid_to_docids = defaultdict(set)

for eachsample in depth_pool_samples.itertuples(index=True):
    qid_to_docids[eachsample.qid].add(eachsample.docid)

In [29]:
qid_to_num_passages = {k:len(v) for k, v in qid_to_docids.items()}

In [35]:
# number of unique passages/docids
docids = set([item for sublist in qid_to_docids.values() for item in sublist])
len(docids)

56443

In [37]:
# no. of queries and passages
print(len(qid_to_num_passages.keys()))
print(sum(qid_to_num_passages.values()))

424
56475


### Reading Passages/Dcouments

In [40]:
# def read_bundles(bundlenum):
def read_bundles_v1():
    collection = open(passages_path, 'r').readlines()
    for eachline in collection:
        docid, passage = eachline.split('\t')
        docid = int(docid)
        if docid in docids:
            passages_bundles[docid] = passage

In [41]:
def read_bundles_v2(bundlenum):
     with gzip.open(f'{passages_path}/msmarco_passage_{bundlenum}.gz','r') as fpassage:
          for passage in fpassage:
            json_passage = json.loads(passage.decode('utf8'))
            if json_passage['pid'] in docids:
                passages_bundles[json_passage['pid']] = json_passage['passage']

In [None]:
passages_bundles = {}

if YEAR in [2019, 2020, 2021]:
    # read TREC 2019, 2020 passagess
    read_bundles_v1()
else:
    for bundlenum in tqdm(range(0, 70)):
        if bundlenum < 10:
            bundlenum = f'0{str(bundlenum)}'
        read_bundles_v2(bundlenum=bundlenum)   

In [42]:
test_queries = pd.read_csv(test_queries_path, sep='\t', header=None, names=['qid', 'query'])
test_queries.head()

Unnamed: 0,qid,query
0,588,1099 b cost basis i sell specific shares
1,9141,a boiled egg is how many calories
2,43905,average single person income
3,49712,behavior define
4,58376,calculate btu per natural gas flow


In [44]:
# dict: qid to query
queries = dict(zip(test_queries['qid'], test_queries['query']))
len(queries)

500

In [45]:
def get_gpt_response(system_context: str, text: str):
    """
    Generate a chat completion using OpenAI's chat completion API.
 
    :param system_context: some context and/or instructions to the model"
    :param text: user message (aka prompt)
    :return:
    """
    openai.api_type = "API-Type"
    openai.api_version = 'API-Version'
    openai.api_base = 'API-Base-URL'
    openai.api_key = 'Your-Key'
 
    response = openai.ChatCompletion.create(
        engine="gpt-4-32k", # The deployment name you chose when you deployed the ChatGPT or GPT-4 model.
        messages=[
            {
                "role": "system",
                "content": """{}""".format(system_context)
            },
            {
                "role": "user",
                "content": text
            },
        ],
        temperature=0,
        top_p=1,
        frequency_penalty=0.5,
        presence_penalty=0,
    )
 
    return response['choices'][0]['message']['content']

In [46]:
def auto_eval(system_context, user_prompt):

    system_context = system_context
    user_prompt = user_prompt
    
    try:
        gpt_4_res = get_gpt_response(system_context=system_context, text=user_prompt)
    except:
        gpt_4_res = "ERROR"
        # print(f"gpt-4 ERROR: {e}")

    return gpt_4_res

In [47]:
context = """You are a search quality rater evaluating the relevance of passages. Given a query and a passage, you must provide a score on an integer scale of 0 to 3 with the following meanings:

3 = Perfectly relevant: The passage is dedicated to the query and contains the exact answer.
2 = Highly relevant: The passage has some answer for the query, but the answer may be a bit unclear, or hidden amongst extraneous information.
1 = Related: The passage seems related to the query but does not answer it.
0 = Irrelevant: The passage has nothing to do with the query

Assume that you are writing an answer to the query. If the passage seems to be related to the query but does not include any answer to the query, mark it 1. If you would use any of the information contained in the passage in such an asnwer, mark it 2. If the passage is primarily about the query, or contains vital information about the topic, mark it 3. Otherwise, mark it 0.
"""

In [48]:
def get_user_prompt(query, passage):
    return f"""Query
    A person has typed [{query}] into a search engine.
    
    Result
    Consider the following passage.
    —BEGIN Passage CONTENT—
    {passage}
    —END Passage CONTENT—
    
    Instructions
    Consider the underlying intent of the search, and decide on a final score of the relevancy of query to the passage given the context.
    Score:"""

In [None]:
# no need during summarisation:
gpt4_judgments = open(f"all-queries-with-error/gpt4_judgments_all_queries_dl_{YEAR}.txt", 'w')

for qid, docids in tqdm(qid_to_docids.items()):
    for docid in docids:
        score = auto_eval(system_context=context, user_prompt=get_user_prompt(query=queries[qid], passage=passages_bundles[docid]))
        print(score)
        gpt4_judgments.write(f"{qid} 0 {docid} {score}\n")
    
gpt4_judgments.close()

### Summarisation Prompt: Error Cases for the Long Passages

In [50]:
summary_context = """You are a summariser that summarises the passage to make it shorter but as much as relevant to the input passage."""

In [51]:
def get_summary_prompt(passage):
    return f"""Return a very short relevant summary of the give passage:
    Passage: {passage}
    Summary:"""

In [33]:
gpt4_qrel = open("gpt4_judgments_all_queries_dl_2022.txt", 'r').readlines()
rnd_index = open("gpt4_judgments_dl2022_random_index.txt", 'w')

write_counter = 0
indices = set()

with open("gpt4_judgments_dl2022_processed.txt", 'w') as out_file:
    for index, eachline in enumerate(tqdm(gpt4_qrel)):
        qid, q0, docid, score = eachline.split(' ', 3)
        score = score.strip('\n')
        
        if len(score) == 1:
            out_file.write(f"{qid} 0 {docid} {score}\n")
            write_counter += 1
            indices.add(index)

        elif score == "ERROR":
            # passage = passages_bundles[int(docid)]
            passage = passages_bundles[docid]
            res_first = passage[0:len(passage)//2]
            res_second = passage[len(passage)//2 if len(passage)%2 == 0 else ((len(passage)//2)+1):]

            summarised_passage_1 = auto_eval(summary_context, get_summary_prompt(passage=res_first))
            summarised_passage_2 = auto_eval(summary_context, get_summary_prompt(passage=res_second))
            
            if summarised_passage_1 == "" or summarised_passage_2 == "":
                print(f"NoSummarisedPassage - INDEX: {index}")
                rnd_index.write(f"NoSummarisedPassage {index}\n")
                out_file.write(f"{qid} 0 {docid} {randrange(4)}\n")
                write_counter += 1
                indices.add(index)
            else:
                summarised_passage = summarised_passage_1 + ' ' + summarised_passage_2
                new_score = auto_eval(system_context=context, user_prompt=get_user_prompt(query=queries[int(qid)], passage=summarised_passage))
                if new_score == "ERROR":
                    print(f"NoLabelContextLength - INDEX: {index}")
                    rnd_index.write(f"NoLabelContextLength {index}\n")
                    out_file.write(f"{qid} 0 {docid} {randrange(4)}\n")
                    write_counter += 1
                    indices.add(index)
                else:
                    out_file.write(f"{qid} 0 {docid} {new_score}\n")
                    write_counter += 1
                    indices.add(index)
        else:
            print(f"NoConditionApplied. Score: {score}")
            print(len(score))
            out_file.write(f"{qid} 0 {docid} {randrange(4)}\n")
            write_counter += 1
            indices.add(index)

rnd_index.close()
out_file.close()

print(f"The write counter is: {write_counter}")

100%|██████████| 56475/56475 [01:03<00:00, 896.05it/s] 

The write counter is: 56475



