In [14]:
import os
import re
import copy
import random
import collections
import torch
import numpy as np
import pandas as pd
import json
import pickle
import nltk

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from rank_bm25 import BM25Okapi
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import Dataset
from transformers import TrainerCallback, AutoTokenizer

os.chdir('/home/s2310409/workspace/coliee-2024/')

def load_data(dir):
    with open(dir, 'r') as fp:
        train_data = json.load(fp)

    data = []
    for key in train_data.keys():
        data.append([key, train_data[key]])

    return pd.DataFrame(data, columns=['source', 'target'])

# BM25 Indexing

In [15]:
def chunking(sentences, window_size=10):
    chunks = []
    for i in range(0, len(sentences) - window_size, window_size//2):
        chunks.append("\n".join(sentences[i:i+window_size]))
    return chunks

with open('dataset/all_data.json') as f:
    all_data_dict = json.load(f)

word_tokenizer = nltk.tokenize.WordPunctTokenizer()
# file_list = sorted(list(all_data_dict.keys()))

mode = 'test'

file_list = [f for f in os.listdir(f'/home/s2310409/workspace/coliee-2024/dataset-2023/task1/{mode}_files') if f.endswith('.txt')]
file_list = [f for f in file_list if f in all_data_dict.keys()]
file_list = sorted(file_list)

processed_file_dict = {}
for file in [f for f in os.listdir("dataset/processed") if not f.startswith('.')]:
    processed_file = f"dataset/processed/{file}"
    with open(processed_file, 'r') as fp:
        processed_document = fp.read()
        processed_file_dict[file] = {
            'sentences': processed_document.split('\n\n'),
            'processed_document': processed_document
        }

chunk_dict = {}
for file in file_list:
    chunks = chunking(processed_file_dict[file]['sentences'])
    for i, chunk in enumerate(chunks):
        if len(chunk) > 0:
            chunk_dict[f"{file}_{i}"] = chunk

use_chunk = False

if use_chunk:
    # bm25 for chunks
    corpus = []
    chunk_list = sorted(list(chunk_dict.keys()))
    for chunk in chunk_list:
        corpus.append(chunk_dict[chunk])
    tokenized_corpus = [word_tokenizer.tokenize(doc) for doc in corpus]
    bm25 = BM25Okapi(tokenized_corpus)
else:
    # bm25 for whole document
    corpus = []
    prcessed_list = sorted(file_list)
    for file in prcessed_list:
        corpus.append(processed_file_dict[file]['processed_document'])
    tokenized_corpus = [word_tokenizer.tokenize(doc) for doc in corpus]
    bm25 = BM25Okapi(tokenized_corpus)

# Query with TF-IDF keywords

In [16]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer

stopwords = nltk.corpus.stopwords.words('english')

docs = []
for file in file_list:
    docs.append(processed_file_dict[file]['processed_document'])

count_vec = CountVectorizer(stop_words=stopwords)
word_count_vector = count_vec.fit_transform(docs)


tfidf_transformer=TfidfTransformer(smooth_idf=True,use_idf=True)
tfidf_transformer.fit(word_count_vector)

features = count_vec.get_feature_names_out()

In [17]:
def sort_coo(coo_matrix):
    tuples = zip(coo_matrix.col, coo_matrix.data)
    return sorted(tuples, key=lambda x: (x[1], x[0]), reverse=True)

def extract_topn_from_vector(feature_names, sorted_items, topn=10):
    """get the feature names and tf-idf score of top n items"""
    
    #use only topn items from vector
    sorted_items = sorted_items[:topn]

    score_vals = []
    feature_vals = []
    
    # word index and corresponding tf-idf score
    for idx, score in sorted_items:
        
        #keep track of feature name and its corresponding score
        score_vals.append(round(score, 3))
        feature_vals.append(feature_names[idx])

    #create a tuples of feature,score
    #results = zip(feature_vals,score_vals)
    results= {}
    for idx in range(len(feature_vals)):
        results[feature_vals[idx]]=score_vals[idx]
    
    return results

## Query on chunks

In [18]:
# n_keywords = 25

# def extract_query(doc):
#     tf_idf_vector=tfidf_transformer.transform(count_vec.transform([doc]))
#     sorted_items = sort_coo(tf_idf_vector.tocoo())
#     keywords=extract_topn_from_vector(features,sorted_items, n_keywords)
#     return " ".join(list(keywords.keys()))

# query_dict = {}

# for file in tqdm(file_list):
#     query_dict[file] = extract_query(processed_file_dict[file]['processed_document'])

# n_candidates = 150
# chunk_candidate_dict = {}
# candidate_dict = {}

# for file in tqdm(file_list):
#     query = query_dict[file]
#     tokenized_query = word_tokenizer.tokenize(query)
#     results = bm25.get_scores(tokenized_query)
#     max_ids = np.argsort(results)[-n_candidates:]

#     chunk_candidates = [chunk_list[idx] for idx in max_ids]
#     chunk_candidate_dict[file] = chunk_candidates

#     document_candidates = [chunk.split('_')[0] for chunk in chunk_candidates]
#     candidate_dict[file] = list(set(document_candidates))

# data_df = load_data(f'dataset/json/{mode}.json')

# data_df['chunk_candidates'] = data_df['source'].apply(lambda x: chunk_candidate_dict[x])
# data_df['candidates'] = data_df['source'].apply(lambda x: candidate_dict[x])
# data_df['query'] = data_df['source'].apply(lambda x: query_dict[x])

# # calculate accuracy metrics for BM25 + TF-IDF
# correct = 0
# n_retrived = 0
# n_relevant = 0

# coverages = []

# for index, row in data_df.iterrows():
#     source = row['source']
#     target = row['target']
#     preds = row['candidates']
#     coverages.append(len(preds))
#     n_retrived += len(preds)
#     n_relevant += len(target)
#     for prediction in preds:
#         if prediction in target:
#             correct += 1

# precision = correct / n_retrived
# recall = correct / n_relevant

# print(f"Average # candidates: {np.mean(coverages)}")
# print(f"Precision: {precision}")
# print(f"Recall: {recall}")
# print(f"F1: {2 * precision * recall / (precision + recall)}")

## Query on all documents

In [19]:
n_keywords = 25

def extract_query(doc):
    tf_idf_vector=tfidf_transformer.transform(count_vec.transform([doc]))
    sorted_items = sort_coo(tf_idf_vector.tocoo())
    keywords=extract_topn_from_vector(features,sorted_items, n_keywords)
    return " ".join(list(keywords.keys()))

query_dict = {}

for file in tqdm(file_list):
    query_dict[file] = extract_query(processed_file_dict[file]['processed_document'])

n_candidates = 150
candidate_dicts = {}

for file in tqdm(file_list):
    query = query_dict[file]
    tokenized_query = word_tokenizer.tokenize(query)
    results = bm25.get_scores(tokenized_query)
    max_ids = np.argsort(results)[-n_candidates:]
    document_candidates = [file_list[idx] for idx in max_ids]
    candidate_dicts[file] = list(set(document_candidates))

data_df = load_data(f'dataset/json/{mode}.json')

data_df['candidates'] = data_df['source'].apply(lambda x: candidate_dicts[x])
data_df['query'] = data_df['source'].apply(lambda x: query_dict[x])

# calculate accuracy metrics for BM25 + TF-IDF
correct = 0
n_retrived = 0
n_relevant = 0

coverages = []

for index, row in data_df.iterrows():
    source = row['source']
    target = row['target']
    preds = row['candidates']
    coverages.append(len(preds))
    n_retrived += len(preds)
    n_relevant += len(target)
    for prediction in preds:
        if prediction in target:
            correct += 1

precision = correct / n_retrived
recall = correct / n_relevant

print(f"Average # candidates: {np.mean(coverages)}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1: {2 * precision * recall / (precision + recall)}")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1217/1217 [00:03<00:00, 320.70it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1217/1217 [00:08<00:00, 135.45it/s]


Average # candidates: 150.0
Precision: 0.014566353187042842
Recall: 0.838748495788207
F1: 0.028635401902179496


# Mono T5

In [20]:
data_df

Unnamed: 0,source,target,candidates,query
0,070318.txt,[015076.txt],"[037739.txt, 067598.txt, 082009.txt, 040046.tx...",rpo board guideline adamidis applicant france ...
1,077960.txt,"[009054.txt, 040860.txt]","[044662.txt, 027552.txt, 028766.txt, 019572.tx...",removal child children custody order dunn cour...
2,042319.txt,"[093691.txt, 075956.txt, 084953.txt, 022987.txt]","[044662.txt, 010137.txt, 091415.txt, 048470.tx...",beyer cross affidavit prothonotary examination...
3,041766.txt,[039269.txt],"[017973.txt, 069618.txt, 091415.txt, 014381.tx...",drug clinical nds health data 002 omitted 08 n...
4,077407.txt,[038669.txt],"[007118.txt, 048470.txt, 067598.txt, 056351.tx...",communication 23 privilege counsel litigation ...
...,...,...,...,...
314,085079.txt,"[044669.txt, 003144.txt]","[007118.txt, 025490.txt, 056351.txt, 028766.tx...",cso promotions shephard cst adjudicator commis...
315,031370.txt,"[096341.txt, 060602.txt, 047107.txt, 084522.tx...","[098691.txt, 027552.txt, 082009.txt, 019572.tx...",removal peru applicant irreparable 3d spouse p...
316,085828.txt,"[004301.txt, 074887.txt, 088994.txt]","[029810.txt, 017973.txt, 044930.txt, 098691.tx...",officer applicants india singh riots risk prra...
317,024957.txt,"[015009.txt, 080348.txt]","[049299.txt, 019620.txt, 002842.txt, 053682.tx...",seizure annuity civil 224 code debtor chattels...


In [21]:
def load_summary_data(dir):
    summary_data = {}
    files = os.listdir(dir)
    files = [f for f in files if not f.startswith('.')]
    for file in files:
        f_path = os.path.join(dir, file)
        with open(f_path, 'r') as fp:
            summary_data[file] = fp.read()
    return summary_data

summary_data = load_summary_data('dataset/summarized')

In [22]:
def prompt(document, query):
    return f'##Query: {query} ##Document: {document} ##Relevant:'

In [27]:
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    TrainingArguments
)
ckpt_dir = os.path.join('./train_logs/monot5-large-10k_hns/ckpt/checkpoint-834')
tokenizer = AutoTokenizer.from_pretrained('castorini/monot5-large-msmarco-10k')
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_dir).to('cuda')
model.eval()

T5ForConditionalGeneration(
  (shared): Embedding(32128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=1024, bias=False)
              (k): Linear(in_features=1024, out_features=1024, bias=False)
              (v): Linear(in_features=1024, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=4096, bias=False)
              (wo): Linear(in_features=4096, out_features=1024, bias=False)
              (d

In [28]:
e_id = 0
source = data_df.iloc[e_id].source
candidates = data_df.iloc[e_id].candidates
query = data_df.iloc[e_id].query
for candidate in candidates:
    candidate_summary = summary_data['015076.txt']
    text = prompt(document=tokenizer.decode(tokenizer.encode(candidate_summary, add_special_tokens=False, max_length=450, truncation=True)), query=query)
    print(text)
    break

##Query: rpo board guideline adamidis applicant france hearing bias ldk actions refugee apprehension member protection immigration recuse port kosovo cvv officer negative findings inference members mosley ##Document: Pelletier, J. : This is an application under section 82.1 of the Immigration Act, R.S.C. 1985, c. The CRDD found that the applicant was not a credible witness. The applicant argues that the CRDD's finding of implausibility was unreasonable. Application for judicial review of the decision of the Convention Refugee Determination Division, dated December 16, 1999. The application for judicial review of ten applicants were heard together because of certain common issues, one of which was whether the applicants had become refugee sur place. Each of the applicants made a claim before the Convention Refugee Determination Division ("CRDD") on the basis of well-founded fear of persecution of imputed political opinion The CRDD found that there was insufficient objective grounds to f

In [29]:
inputs = tokenizer(text, return_tensors='pt').to('cuda')
outputs = model.generate(**inputs, output_scores=True, return_dict_in_generate=True, max_new_tokens=10)
outputs

GreedySearchEncoderDecoderOutput(sequences=tensor([[   0, 1176,    1]], device='cuda:0'), scores=(tensor([[-31.0093, -17.6671, -17.9352,  ..., -46.2695, -46.2710, -46.3849]],
       device='cuda:0'), tensor([[-127.4821,  -40.6877,  -88.1471,  ..., -174.7162, -175.4172,
         -175.3853]], device='cuda:0')), encoder_attentions=None, encoder_hidden_states=None, decoder_attentions=None, cross_attentions=None, decoder_hidden_states=None)

In [41]:
tokenizer.decode(outputs.sequences[0])

'<pad> false</s>'

In [40]:
outputs.sequences[0]

tensor([   0, 6136,    1], device='cuda:0')

In [37]:
(outputs.scores)

(tensor([[-33.8814, -18.8008, -19.5146,  ..., -48.9525, -48.9488, -49.0627]],
        device='cuda:0'),
 tensor([[-128.0669,  -41.3602,  -88.5412,  ..., -175.1797, -175.8803,
          -175.8492]], device='cuda:0'))

In [38]:
prediction_dict = {}
for e_id in tqdm(range(len(data_df))):
    source = data_df.iloc[e_id].source
    candidates = data_df.iloc[e_id].candidates
    query = data_df.iloc[e_id].query
    prediction_dict[source] = {
        'result':[],
        'raw':{}
    }
    for candidate in candidates:
        candidate_summary = summary_data[candidate]
        text = prompt(document=tokenizer.decode(tokenizer.encode(candidate_summary, add_special_tokens=False, max_length=450, truncation=True)), query=query)
        inputs = tokenizer(text, return_tensors='pt').to('cuda')
        with torch.no_grad():
            outputs = model.generate(**inputs, output_scores=True, return_dict_in_generate=True, max_new_tokens=10)
            prediction_dict[source]['raw'] = {
                'sequences': list(outputs.sequences[0].cpu().detach().numpy()),
                'scores': list([outputs.scores[0].cpu().detach().numpy(), outputs.scores[1].cpu().detach().numpy()])
            }
            decoded_output = tokenizer.decode(outputs.sequences[0])
            if 'true' in decoded_output:
                prediction_dict[source]['result'].append(candidate)

  0%|                                                                                                                        | 0/319 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (519 > 512). Running this sequence through the model will result in indexing errors
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 319/319 [54:18<00:00, 10.21s/it]


In [39]:
data_df['prediction'] = data_df['source'].apply(lambda x: prediction_dict[x]['result'])

# calculate accuracy metrics for BM25 + TF-IDF
correct = 0
n_retrived = 0
n_relevant = 0

coverages = []

for index, row in data_df.iterrows():
    source = row['source']
    target = row['target']
    preds = row['prediction']
    coverages.append(len(preds))
    n_retrived += len(preds)
    n_relevant += len(target)
    for prediction in preds:
        if prediction in target:
            correct += 1

precision = correct / n_retrived
recall = correct / n_relevant

print(f"Coverage: {np.mean(coverages)/len(file_list)}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1: {2 * precision * recall / (precision + recall)}")

Coverage: 0.08880463032844525
Precision: 0.016011138183083886
Recall: 0.6642599277978339
F1: 0.0312685869657575


In [44]:
import sys
sys.path.append('/home/s2310409/workspace/coliee-2024/modules/pygaggle')
from utils.dataset import build_dataset

test_dataset = build_dataset(mode='test')

2024-01-11 14:32:53 [INFO] env: 
Using override env var JVM_PATH (/home/s2310409/jdk/lib/server/libjvm.so) to load libjvm.
Please report your system information (os version, java
version, etc), and the path that works for you, to the
PyJNIus project, at https://github.com/kivy/pyjnius/issues.
so we can improve the automatic discovery.

2024-01-11 14:32:53 [INFO] loader: Loading faiss with AVX2 support.
2024-01-11 14:32:53 [INFO] loader: Successfully loaded faiss with AVX2 support.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1217/1217 [00:03<00:00, 320.00it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1217/1217 [00:08<00:00, 141.28it/s]


In [51]:
import jsonlines
import subprocess

def build_df(mode='test'):
    with open('dataset/all_data.json') as f:
        all_data_dict = json.load(f)

    word_tokenizer = nltk.tokenize.WordPunctTokenizer()
    # file_list = sorted(list(all_data_dict.keys()))

    file_list = [f for f in os.listdir(f'/home/s2310409/workspace/coliee-2024/dataset-2023/task1/{mode}_files') if f.endswith('.txt')]
    file_list = [f for f in file_list if f in all_data_dict.keys()]
    file_list = sorted(file_list)

    processed_file_dict = {}
    for file in [f for f in os.listdir("dataset/processed") if not f.startswith('.')]:
        processed_file = f"dataset/processed/{file}"
        with open(processed_file, 'r') as fp:
            processed_document = fp.read()
            processed_file_dict[file] = {
                'sentences': processed_document.split('\n\n'),
                'processed_document': processed_document
            }

    chunk_dict = {}
    for file in file_list:
        chunks = chunking(processed_file_dict[file]['sentences'])
        for i, chunk in enumerate(chunks):
            if len(chunk) > 0:
                chunk_dict[f"{file}_{i}"] = chunk

    if use_chunk:
        # bm25 for chunks
        corpus = []
        chunk_list = sorted(list(chunk_dict.keys()))
        for chunk in chunk_list:
            corpus.append(chunk_dict[chunk])
        tokenized_corpus = [word_tokenizer.tokenize(doc) for doc in corpus]
        bm25 = BM25Okapi(tokenized_corpus)
    else:
        # bm25 for whole document
        corpus = []
        prcessed_list = sorted(file_list)
        for file in prcessed_list:
            corpus.append(processed_file_dict[file]['processed_document'])
        tokenized_corpus = [word_tokenizer.tokenize(doc) for doc in corpus]
        bm25 = BM25Okapi(tokenized_corpus)

    stopwords = nltk.corpus.stopwords.words('english')

    docs = []
    for file in file_list:
        docs.append(processed_file_dict[file]['processed_document'])

    count_vec = CountVectorizer(stop_words=stopwords)
    word_count_vector = count_vec.fit_transform(docs)


    tfidf_transformer=TfidfTransformer(smooth_idf=True,use_idf=True)
    tfidf_transformer.fit(word_count_vector)

    features = count_vec.get_feature_names_out()

    n_keywords = 25

    def extract_query(doc):
        tf_idf_vector=tfidf_transformer.transform(count_vec.transform([doc]))
        sorted_items = sort_coo(tf_idf_vector.tocoo())
        keywords=extract_topn_from_vector(features,sorted_items, n_keywords)
        return " ".join(list(keywords.keys()))

    query_dict = {}

    for file in tqdm(file_list):
        query_words = extract_query(processed_file_dict[file]['processed_document'])
        query_dict[file] = query_words

    n_candidates = 150
    candidate_dicts = {}

    for file in tqdm(file_list):
        query = query_dict[file]
        tokenized_query = word_tokenizer.tokenize(query)
        results = bm25.get_scores(tokenized_query)
        max_ids = np.argsort(results)[-n_candidates:]
        document_candidates = [file_list[idx] for idx in max_ids]
        candidate_dicts[file] = list(set(document_candidates))

    data_df = load_data(f'dataset/json/{mode}.json')
    data_df['candidates'] = data_df['source'].apply(lambda x: candidate_dicts[x])
    data_df['query'] = data_df['source'].apply(lambda x: query_dict[x])
    return data_df

def create_bm25_indexes(segment="test"):
    tmp_dir = "tmp"
    os.makedirs(tmp_dir, exist_ok=True)
    
    indexes_dir = f'bm25/{segment}'
    os.makedirs(indexes_dir, exist_ok=True)

    
    with open('dataset/all_data.json') as f:
        all_data_dict = json.load(f)

    file_list = [f for f in os.listdir(f'/home/s2310409/workspace/coliee-2024/dataset-2023/task1/{segment}_files') if f.endswith('.txt')]
    file_list = [f for f in file_list if f in all_data_dict.keys()]
    file_list = sorted(file_list)

    processed_file_dict = {}
    for file in [f for f in os.listdir("dataset/processed") if not f.startswith('.')]:
        processed_file = f"dataset/processed/{file}"
        with open(processed_file, 'r') as fp:
            processed_document = fp.read()
            processed_file_dict[file] = {
                'sentences': processed_document.split('\n\n'),
                'processed_document': processed_document
            }
    
    # data_df = build_df(mode=segment)
    # for case in tqdm(data_df['source'].values):
    #     base_case_num = case.split(".txt")[0]
    #     candidate_cases = data_df[data_df['source'] == case]['candidates'].values[0]
    #     for cand_case in candidate_cases:
    #         cand_case_data = processed_file_dict[cand_case]['processed_document']
    #         cand_num = cand_case.split(".txt")[0]
    #         dict_ = { "id": f"{base_case_num}_candidate{cand_num}.txt_task2", "contents": cand_case_data}
    #         with jsonlines.open(f"{tmp_dir}/candidate.jsonl", mode="a") as writer:
    #             writer.write(dict_)

    for case in tqdm(file_list):
        dict_  = { "id": f"{case}", "contents": processed_file_dict[case]['processed_document']}
        with jsonlines.open(f"{tmp_dir}/candidate.jsonl", mode="a") as writer:
            writer.write(dict_)

    subprocess.run(["/home/s2310409/miniconda3/envs/coliee-24/bin/python", "-m", "pyserini.index", "-collection", "JsonCollection",
                    "-generator", "DefaultLuceneDocumentGenerator", "-threads", "1", "-input",
                    f"{tmp_dir}", "-index", f"{indexes_dir}", "-storePositions", "-storeDocvectors",
                    "-storeRaw"])
    
create_bm25_indexes(segment="test")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1217/1217 [00:01<00:00, 720.86it/s]


2024-01-11 16:08:07,576 INFO  [main] index.IndexCollection (IndexCollection.java:391) - Setting log level to INFO
2024-01-11 16:08:07,577 INFO  [main] index.IndexCollection (IndexCollection.java:394) - Starting indexer...
2024-01-11 16:08:07,577 INFO  [main] index.IndexCollection (IndexCollection.java:396) - DocumentCollection path: tmp
2024-01-11 16:08:07,577 INFO  [main] index.IndexCollection (IndexCollection.java:397) - CollectionClass: JsonCollection
2024-01-11 16:08:07,577 INFO  [main] index.IndexCollection (IndexCollection.java:398) - Generator: DefaultLuceneDocumentGenerator
2024-01-11 16:08:07,577 INFO  [main] index.IndexCollection (IndexCollection.java:399) - Threads: 1
2024-01-11 16:08:07,577 INFO  [main] index.IndexCollection (IndexCollection.java:400) - Language: en
2024-01-11 16:08:07,578 INFO  [main] index.IndexCollection (IndexCollection.java:401) - Stemmer: porter
2024-01-11 16:08:07,578 INFO  [main] index.IndexCollection (IndexCollection.java:402) - Keep stopwords? fal

In [45]:
from pyserini.search import LuceneSearcher
from collections import defaultdict

def predict_bm25(searcher, query, case):
    bm25_score = defaultdict(lambda: 0)
    hits = searcher.search(query, k=10000)
    for hit in hits:
        if hit.docid != case:
            bm25_score[hit.docid] = max(hit.score, bm25_score[hit.docid])
    return bm25_score

def predict_all_bm25(dataset_path, bm25_index_path, eval_segment="test",
                     k1=None, b=None, topk=None):
    searcher = LuceneSearcher(bm25_index_path)
    if k1 and b:
        print(f"k1: {k1}, b: {b}")
        searcher.set_bm25(k1, b)

    # dataset_path = "/home/thanhtc/mnt/datasets/COLIEE2023/Task2/data_org"
    corpus_dir, cases_dir, _ = get_task2_data(dataset_path, eval_segment)
    bm25_scores = {}
    for case in cases_dir:
        base_case_data = preprocess_case_data(corpus_dir / case / "entailed_fragment.txt")
        score = predict_bm25(searcher, base_case_data, case)
        if topk is not None:
            sorted_score = sorted(score.items(), key=lambda x: x[1], reverse=True)[:topk]
            score = {x[0]: x[1] for x in sorted_score}
        bm25_scores[case] = score
    return bm25_scores

In [60]:
searcher = LuceneSearcher("bm25/test")

base_id = data_df.iloc[0].source
base_query = data_df.iloc[0].query
cand_case = data_df.iloc[0].candidates[0]

hits = searcher.search(base_query, k=10000)
print(base_id)
[(h.docid, h.score) for h in hits]

070318.txt


[('070318.txt', 52.63560104370117),
 ('069215.txt', 14.367500305175781),
 ('067598.txt', 14.207900047302246),
 ('025789.txt', 13.736499786376953),
 ('001589.txt', 13.735300064086914),
 ('071294.txt', 12.843000411987305),
 ('039755.txt', 12.746199607849121),
 ('006254.txt', 12.728699684143066),
 ('091641.txt', 12.6225004196167),
 ('025676.txt', 12.134400367736816),
 ('030534.txt', 12.133399963378906),
 ('071237.txt', 12.130900382995605),
 ('073854.txt', 12.090800285339355),
 ('037739.txt', 11.871399879455566),
 ('054863.txt', 10.905200004577637),
 ('024385.txt', 10.859800338745117),
 ('046028.txt', 10.798100471496582),
 ('088428.txt', 10.685400009155273),
 ('098933.txt', 10.271499633789062),
 ('032432.txt', 10.114299774169922),
 ('042703.txt', 10.109800338745117),
 ('069127.txt', 9.991600036621094),
 ('091307.txt', 9.930399894714355),
 ('012462.txt', 9.81029987335205),
 ('055289.txt', 9.728799819946289),
 ('031467.txt', 9.47249984741211),
 ('027423.txt', 9.468600273132324),
 ('048464.tx