This code is the implementation of methods proposed by TEAMX at COLIEE 2025 for task 3 (Statute Law retrieval). The data used can be accessed [here](https://drive.google.com/drive/folders/12XfVi-RUBEefB2avOlaU3aurK6S385pB?usp=drive_link).

# Data Preprocessing

In [None]:
from google.colab import drive
import xml.etree.ElementTree as ET
import re
import os
import pandas as pd
import numpy as np

In [None]:
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
def extract_articles(text):
    articles = []
    pattern = r'(^|\n)\s*Article (\d+(?:-\d+)?)(.*?)(?=(\n\s*Article \d+)|$)'
    matches = re.findall(pattern, text, re.DOTALL)

    for _, article_number, article_body, _ in matches:
        articles.append([article_number.strip(), article_body.strip()])
    return articles


def read_xml_train(folder):
    all_data = []

    if not os.path.exists(folder):
        print(f"Error: Folder '{folder}' not found.")
        return all_data

    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        if os.path.isfile(file_path):
            print(f"Reading file: {filename}")

            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    file_content = f.read()

                    root = ET.fromstring(file_content)
                    for pair in root.findall('pair'):
                        question_id = pair.get('id')
                        label = pair.get('label')
                        relevant_text = pair.find('t1').text
                        question_text = pair.find('t2').text.strip()
                        articles = extract_articles(relevant_text)
                        for article_number, article_body in articles:
                            all_data.append([question_id, question_text, article_number, article_body, label])
            except Exception as e:
                print(f"Error reading file '{filename}': {e}")
        else:
            print(f"Skipping directory: {filename}")

    header = ['qid', 'query', 'docno', 'text', 'label']
    return pd.DataFrame(all_data, columns=header)

In [None]:
def read_xml_test(file_path):
  all_data = []

  if os.path.isfile(file_path):
    print(f"Reading file: {file_path}")

    try:
      with open(file_path, 'r') as f:
        file_content = f.read()

        root = ET.fromstring(file_content)
        for pair in root.findall('pair'):
          question_id = pair.get('id')
          question_text = pair.find('t2').text.strip()
          all_data.append([question_id, question_text])
    except Exception as e:
      print(f"Error reading file '{file_path}': {e}")
  else:
    print(f"Skipping directory: {file_path}")

  header = ['qid', 'query']
  return pd.DataFrame(all_data, columns=header)

In [None]:
def read_txt_civil_code(file_path):
  data = []

  with open(file_path, 'r') as file:
      lines = file.readlines()

  part_pattern = re.compile(r'^Part (\w+) (.+)$')
  chapter_pattern = re.compile(r'^Chapter (\w+) (.+)$')
  section_pattern = re.compile(r'^Section (\w+) (.+)$')
  article_title_pattern = re.compile(r'^\((.+)\)$')
  article_pattern = re.compile(r'^Article (\d+(-\d+)?)\s+(.*)$')

  part_no = part_text = chapter_no = chapter_text = section_no = section_text = article_title = ''
  current_article_no = ''
  current_article_text = ''

  for line in lines:
      line = line.strip()

      if part_match := part_pattern.match(line):
          if current_article_no:
              data.append({
                  'part_no': part_no,
                  'part_text': part_text,
                  'chapter_no': chapter_no,
                  'chapter_text': chapter_text,
                  'section_no': section_no,
                  'section_text': section_text,
                  'article_title': article_title,
                  'article_no': current_article_no,
                  'article_text': current_article_text.strip()
              })
          part_no, part_text = part_match.groups()
          chapter_no = chapter_text = section_no = section_text = article_title = current_article_no = current_article_text = ''

      elif chapter_match := chapter_pattern.match(line):
          if current_article_no:
              data.append({
                  'part_no': part_no,
                  'part_text': part_text,
                  'chapter_no': chapter_no,
                  'chapter_text': chapter_text,
                  'section_no': section_no,
                  'section_text': section_text,
                  'article_title': article_title,
                  'article_no': current_article_no,
                  'article_text': current_article_text.strip()
              })
          chapter_no, chapter_text = chapter_match.groups()
          section_no = section_text = article_title = current_article_no = current_article_text = ''

      elif section_match := section_pattern.match(line):
          if current_article_no:
              data.append({
                  'part_no': part_no,
                  'part_text': part_text,
                  'chapter_no': chapter_no,
                  'chapter_text': chapter_text,
                  'section_no': section_no,
                  'section_text': section_text,
                  'article_title': article_title,
                  'article_no': current_article_no,
                  'article_text': current_article_text.strip()
              })
          section_no, section_text = section_match.groups()
          article_title = current_article_no = current_article_text = ''

      elif article_title_match := article_title_pattern.match(line):
          if current_article_no:
              data.append({
                  'part_no': part_no,
                  'part_text': part_text,
                  'chapter_no': chapter_no,
                  'chapter_text': chapter_text,
                  'section_no': section_no,
                  'section_text': section_text,
                  'article_title': article_title,
                  'article_no': current_article_no,
                  'article_text': current_article_text.strip()
              })
          article_title = article_title_match.group(1)
          current_article_no = current_article_text = ''

      elif article_match := article_pattern.match(line):
          if current_article_no:
              data.append({
                  'part_no': part_no,
                  'part_text': part_text,
                  'chapter_no': chapter_no,
                  'chapter_text': chapter_text,
                  'section_no': section_no,
                  'section_text': section_text,
                  'article_title': article_title,
                  'article_no': current_article_no,
                  'article_text': current_article_text.strip()
              })

          current_article_no, _, current_article_text = article_match.groups()

      else:
          current_article_text += ' ' + line

  if current_article_no:
      data.append({
          'part_no': part_no,
          'part_text': part_text,
          'chapter_no': chapter_no,
          'chapter_text': chapter_text,
          'section_no': section_no,
          'section_text': section_text,
          'article_title': article_title,
          'article_no': current_article_no,
          'article_text': current_article_text.strip()
      })

  df = pd.DataFrame(data)
  df['full_text'] = df.apply(lambda row: '. '.join([str(row[col]) for col in ['part_text', 'chapter_text', 'section_text', 'article_title', 'article_text'] if not pd.isna(row[col])]), axis=1)

  return df

In [None]:
def read_xml_test_2025(file_path):
    all_data = []

    if os.path.isfile(file_path):
        print(f"Reading file: {file_path}")

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                file_content = f.read()
                root = ET.fromstring(file_content)

                for pair in root.findall('pair'):
                    question_id = pair.get('id')

                    t1_text = pair.find('t1').text.strip() if pair.find('t1') is not None else ""
                    t2_text = pair.find('t2').text.strip() if pair.find('t2') is not None else ""

                    article_matches = re.findall(r'Article\s(\d+[-\d+]*)\s(.*?)(?=\nArticle\s\d+|\Z)', t1_text, re.DOTALL)

                    if article_matches:
                        for article, article_text in article_matches:
                            all_data.append([question_id, article, article_text.strip(), t2_text])
                    else:
                        all_data.append([question_id, "", "", t2_text])

        except Exception as e:
            print(f"Error reading file '{file_path}': {e}")
    else:
        print(f"Skipping directory: {file_path}")

    return all_data

In [None]:
TRAIN_PATH_SRC = '/content/drive/MyDrive/data/train'
TRAIN_PATH_DST = '/content/drive/MyDrive/data/train.csv'

CIVIL_CODE_PATH_SRC = '/content/drive/MyDrive/data/civil_code_en-1to724-2.txt'
CIVIL_CODE_PATH_DSC = '/content/drive/MyDrive/data/civil_code.csv'

In [None]:
train = read_xml_train(TRAIN_PATH_SRC)
train.to_csv(TRAIN_PATH_DST, index=False)
train.head(2)

Reading file: riteval_R02_en.xml
Reading file: riteval_H18_en.xml
Reading file: riteval_R05_en.xml
Reading file: riteval_R03_en.xml
Reading file: riteval_H24_en.xml
Reading file: riteval_R01_en.xml
Reading file: riteval_H29_en.xml
Reading file: riteval_R04_en.xml
Reading file: riteval_H21_en.xml
Reading file: riteval_H27_en.xml
Reading file: riteval_H22_en.xml
Reading file: riteval_H19_en.xml
Reading file: riteval_H25_en.xml
Reading file: riteval_H23_en.xml
Reading file: riteval_H26_en.xml
Reading file: riteval_H28_en.xml
Reading file: riteval_H20_en.xml
Reading file: riteval_H30_en.xml


Unnamed: 0,qid,query,docno,text,label
0,R02-1-A,The family court may decide to commence an ass...,15,(1) The family court may decide to commence an...,N
1,R02-1-A,The family court may decide to commence an ass...,11,The family court may decide to commence a cura...,N


In [None]:
civil_code = read_txt_civil_code(CIVIL_CODE_PATH_SRC)
civil_code.to_csv(CIVIL_CODE_PATH_DSC, index=False)
civil_code.head(2)

Unnamed: 0,part_no,part_text,chapter_no,chapter_text,section_no,section_text,article_title,article_no,article_text,full_text
0,I,General Provisions,I,Common Provisions,,,Fundamental Principles,1,(1) Private rights must be congruent with the ...,General Provisions. Common Provisions. . Funda...
1,I,General Provisions,I,Common Provisions,,,Standards for Construction,2,This Code must be construed so as to honor the...,General Provisions. Common Provisions. . Stand...


# Indexing

In [None]:
import re
import os
import pandas as pd
import numpy as np

In [None]:
!pip install --upgrade python-terrier -q
import pyterrier as pt

if not pt.started():
  pt.init()

  if not pt.started():
Java started and loaded: pyterrier.java, pyterrier.terrier.java [version=5.11 (build: craig.macdonald 2025-01-13 21:29), helper_version=0.0.8]
java is now started automatically with default settings. To force initialisation early, run:
pt.java.init() # optional, forces java initialisation
  pt.init()


In [None]:
train = pd.read_csv('/content/drive/MyDrive/data/train.csv')
civil_code = pd.read_csv('/content/drive/MyDrive/data/civil_code.csv')

In [None]:
def remove_nonalphanum(text):
  pattern = re.compile('[\W_]+')
  return pattern.sub(' ', text)

In [None]:
queries = train[['qid', 'query']].drop_duplicates()
queries['query'] = queries['query'].apply(lambda x: x.strip()).apply(remove_nonalphanum)

queries_val = queries[queries["qid"].astype(str).str.startswith("R05")]
queries_train = queries[~queries["qid"].astype(str).str.startswith("R05")]

In [None]:
qrels = train[['qid', 'query', 'docno']].drop_duplicates()
qrels['query'] = qrels['query'].apply(lambda x: x.strip()).apply(remove_nonalphanum)
qrels['docno'] = qrels['docno'].astype(str)
qrels['relevance'] = 1
qrels['label'] = 1

qrels_val = qrels[qrels["qid"].astype(str).str.startswith("R05")]
qrels_train = qrels[~qrels["qid"].astype(str).str.startswith("R05")]

In [None]:
collection = civil_code[['article_no', 'full_text']].rename(columns={'article_no': 'docno', 'full_text': 'text'}).drop_duplicates()
collection['docno'] = collection['docno'].astype(str)
collection['text'] = collection['text'].apply(remove_nonalphanum)

In [None]:
%%time

!rm -rf ./index

pd_indexer = pt.DFIndexer("./index", \
                          type = pt.index.IndexingType(1), \
                          tokeniser = pt.index.TerrierTokeniser('utf'), \
                          stemmer = pt.index.TerrierStemmer('porter'), \
                          stopwords = pt.index.TerrierStopwords('terrier'), \
                          blocks = True, \
                          verbose = True)

index_ref = pd_indexer.index(collection["text"], collection)



  0%|          | 0/776 [00:00<?, ?documents/s]

CPU times: user 5.24 s, sys: 156 ms, total: 5.39 s
Wall time: 3.76 s


In [None]:
def compute_metrics(df):

    precisions, recalls = [], []

    for qid, group in df.groupby("qid"):
        correct_retrieved = ((group["relevance"] == 1) & (group["pred"] == 1)).sum()
        total_retrieved = (group["pred"] == 1).sum()
        total_relevant = (group["relevance"] == 1).sum()

        precision = correct_retrieved / total_retrieved if total_retrieved > 0 else 0
        recall = correct_retrieved / total_relevant if total_relevant > 0 else 0

        precisions.append(precision)
        recalls.append(recall)

    avg_precision = sum(precisions) / len(precisions) if precisions else 0
    avg_recall = sum(recalls) / len(recalls) if recalls else 0

    if avg_precision + avg_recall > 0:
        f2_score = (5 * avg_precision * avg_recall) / (4 * avg_precision + avg_recall)
    else:
        f2_score = 0

    return avg_precision, avg_recall, f2_score

# XXthr

In [None]:
class ThresholdWithFallback(pt.Transformer):
  def __init__(self, threshold, fallback=1):
    self.threshold = threshold
    self.fallback = fallback

  def transform(self, res: pd.DataFrame) -> pd.DataFrame:
    filtered_dfs = []

    for qid in res['qid'].unique():
        temp = res[res['qid'] == qid]

        above_threshold = temp[temp['score'] >= self.threshold]

        if above_threshold.empty:
            fallback = temp.nlargest(self.fallback, 'score')
            filtered_dfs.append(fallback)
        else:
            filtered_dfs.append(above_threshold)

    final_df = pd.concat(filtered_dfs, ignore_index=True)
    return final_df

In [None]:
bm25 = pt.BatchRetrieve(index_ref, wmodel="BM25") % 50

  bm25 = pt.BatchRetrieve(index_ref, wmodel="BM25") % 50


In [None]:
best_threshold = 25
best_f2 = 0
threshold_scores = []

for thr in [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60]:
    bm25_train = bm25 >> ThresholdWithFallback(threshold=thr, fallback=1)
    serp = bm25_train.transform(queries_train)[['qid', 'docno', 'rank', 'score']].sort_values(['qid', 'rank']).drop_duplicates()
    to_measure = pd.merge(qrels_train, serp, on=['qid', 'docno'], how='outer', suffixes=('_qrels', '_serp'))

    to_measure['relevance'] = to_measure['relevance'].apply(lambda x: 1 if x == 1 else 0)
    to_measure['pred'] = to_measure['score'].apply(lambda x: 1 if x > 0 else 0)

    precision, recall, f2 = compute_metrics(to_measure)

    threshold_scores.append((thr, precision, recall, f2))

    print(f"Threshold: {thr}, Precision: {precision:.4f}, Recall: {recall:.4f}, F2-score: {f2:.4f}")

    if f2 > best_f2:
        best_f2 = f2
        best_threshold = thr

print(f"\nBest Threshold: {best_threshold} with F2-score: {best_f2:.4f}")

In [None]:
bm25_threshold = bm25 >> ThresholdWithFallback(threshold=best_threshold, fallback=1)
uithr = bm25_threshold.transform(queries_val)[['qid', 'docno', 'rank', 'score']].sort_values(['qid', 'rank']).drop_duplicates()
to_measure = pd.merge(qrels_val, uithr, on=['qid', 'docno'], how='outer', suffixes=('_qrels', '_serp'))
to_measure['relevance'] = to_measure['relevance'].apply(lambda x: 1 if x == 1 else 0)
to_measure['pred'] = to_measure['score'].apply(lambda x: 1 if x > 0 else 0)

precision, recall, f2 = compute_metrics(to_measure)
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F2-score: {f2:.4f}")

Precision: 0.4387, Recall: 0.6468, F2-score: 0.5907


# XXwa

In [None]:
bm25 = pt.BatchRetrieve(index_ref, wmodel="BM25") % 20
serp = bm25.transform(queries_train)[['qid', 'docno', 'rank', 'score']].sort_values(['qid', 'rank']).drop_duplicates()

  bm25 = pt.BatchRetrieve(index_ref, wmodel="BM25") % 20


In [None]:
merged_serp = pd.merge(qrels_train, serp, on=['qid', 'docno'], how='outer', suffixes=('_qrels', '_serp'))
merged_serp = pd.merge(merged_serp, qrels[['qid', 'query']], left_on='qid', right_on='qid', how='left')
merged_serp = pd.merge(merged_serp, civil_code[['article_no', 'full_text']], left_on=['docno'], right_on=['article_no'], how='left')
merged_serp = merged_serp[['qid', 'query_y', 'docno', 'full_text', 'score', 'rank', 'relevance']].drop_duplicates().rename(columns={'full_text': 'passage','query_y': 'query'}).sort_values(['qid', 'rank'])

merged_serp['relevance'] = merged_serp['relevance'].apply(lambda x: 1 if x == 1 else 0)
merged_serp['pred'] = merged_serp['score'].apply(lambda x: 1 if x > 0 else 0)
merged_serp.head(2)

Unnamed: 0,qid,query,docno,passage,score,rank,relevance,pred
13,H18-1-1,A special provision that releases warranty can...,572,Claims. Contracts. Sale. Special Agreement Dis...,27.201336,0.0,1,1
0,H18-1-1,A special provision that releases warranty can...,261,Real Rights. Ownership. Co-Ownership. Co-Owner...,17.141129,1.0,0,1


In [None]:

!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
# !pip install --no-deps unsloth==2025.2.15

Collecting xformers==0.0.29
  Using cached xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Using cached xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl (15.3 MB)
Installing collected packages: xformers
  Attempting uninstall: xformers
    Found existing installation: xformers 0.0.29.post3
    Uninstalling xformers-0.0.29.post3:
      Successfully uninstalled xformers-0.0.29.post3
Successfully installed xformers-0.0.29


In [None]:
# !pip install unsloth
# !pip uninstall unsloth -y && pip install --upgrade --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git

In [None]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048
dtype = None
load_in_4bit = True

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.5.1+cu121 with CUDA 1201 (you have 2.6.0+cu124)
    Python  3.11.11 (you have 3.11.11)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


🦥 Unsloth Zoo will now patch everything to make training faster!


In [None]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

In [None]:
prompt_train  = """Below is a pair of legal query and document. Determine if the pair has semantic relevance, if no then 0, if yes then 1

### Query : {}

### Document : {}

### Relevance : {}
"""

prompt_test  = """Below is a pair of legal query and document. Determine if the pair has semantic relevance, if no then 0, if yes then 1.

### Query : {}

### Document : {}

### Relevance :
"""

merged_serp['prompt_train'] = merged_serp.apply(lambda row: prompt_train.format(row['query'], row['passage'], row['relevance']) + tokenizer.eos_token, axis=1)

In [None]:
merged_serp.head(1)

Unnamed: 0,qid,query,docno,passage,score,rank,relevance,pred,prompt_train
13,H18-1-1,A special provision that releases warranty can...,572,Claims. Contracts. Sale. Special Agreement Dis...,27.201336,0.0,1,1,Below is a pair of legal query and document. D...


In [None]:
merged_serp_balanced = merged_serp[(merged_serp['relevance'] == 1) | (merged_serp['rank'] <= 3)]

from datasets import Dataset
dataset = Dataset.from_pandas(merged_serp_balanced[['prompt_train']].rename(columns={'prompt_train': 'text'}))

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field = "text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        max_steps=1,
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        report_to="none",
    ),
)

In [None]:
trainer_stats = trainer.train()

In [None]:
FastLanguageModel.for_inference(model)
def infer(txt):
  inputs = tokenizer([txt], return_tensors = "pt").to("cuda")
  outputs = model.generate(**inputs, max_new_tokens = 64)
  print(tokenizer.batch_decode(outputs)[0])
  mtch = re.search(r"### Relevance :\n(\d+)", tokenizer.batch_decode(outputs)[0])
  return int(mtch.group(1)) if mtch else None  # Return None if no match

In [None]:
from sklearn.preprocessing import MinMaxScaler

In [None]:
queries_train_sample = queries_train.sample(n=60, random_state=42)

bm25_5 = pt.BatchRetrieve(index_ref, wmodel="BM25") % 5
serp_sample = bm25_5.transform(queries_train_sample)[['qid', 'docno', 'rank', 'score']].sort_values(['qid', 'rank']).drop_duplicates()

serp_sample = pd.merge(serp_sample, qrels[['qid', 'query']], left_on='qid', right_on='qid', how='left')
serp_sample = pd.merge(serp_sample, civil_code[['article_no', 'full_text']], left_on=['docno'], right_on=['article_no'], how='left')
serp_sample = serp_sample[['qid', 'query', 'docno', 'full_text', 'score', 'rank']].drop_duplicates().rename(columns={'full_text': 'passage',}).sort_values(['qid', 'rank'])
serp_sample['prompt_test'] = serp_sample.apply(lambda row: prompt_test.format(row['query'], row['passage']) + tokenizer.eos_token, axis=1)

serp_sample['llm'] = serp_sample['prompt_test'].apply(infer)
scaler = MinMaxScaler()
serp_sample["bm25"] = scaler.fit_transform(serp_sample[["score"]])

  bm25_5 = pt.BatchRetrieve(index_ref, wmodel="BM25") % 5


In [None]:
best_alpha = 0
best_f2 = 0
alpha_scores = []

for alpha in [0.1, 0.5, 1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95]:
    to_measure = serp_sample
    to_measure['average'] = (1 - alpha / 100) * to_measure['bm25'] + (alpha / 100) * to_measure['llm']
    to_measure['uiwa'] = (to_measure['average'] >= 0.5).astype(int)

    for qid, group in to_measure.groupby('qid'):
        if group['uiwa'].sum() == 0:
            max_index = group['average'].idxmax()
            to_measure.loc[max_index, 'uiwa'] = 1

    to_measure = pd.merge(qrels_train[qrels_train['qid'].isin(to_measure['qid'])], to_measure, on=['qid', 'docno'], how='outer', suffixes=('_qrels', '_serp'))
    to_measure['relevance'] = to_measure['relevance'].apply(lambda x: 1 if x == 1 else 0)
    to_measure['uiwa'] = to_measure['uiwa'].apply(lambda x: x if x > 0 else 0)

    precision, recall, f2 = compute_metrics(to_measure[['qid', 'docno', 'relevance', 'uiwa']].rename(columns={'uiwa': 'pred'}))
    alpha_scores.append((alpha, precision, recall, f2))

    if f2 > best_f2:
        best_f2 = f2
        best_alpha = alpha

to_measure = serp_sample
serp_sample['average'] = (1 - alpha / 100) * serp_sample['bm25'] + (alpha / 100) * serp_sample['llm']
serp_sample['uiwa'] = (serp_sample['average'] >= 0.5).astype(int)

for qid, group in serp_sample.groupby('qid'):
    if group['uiwa'].sum() == 0:
        max_index = group['average'].idxmax()
        serp_sample.loc[max_index, 'uiwa'] = 1

In [None]:
serp_val = bm25_5.transform(queries_val)[['qid', 'docno', 'rank', 'score']].sort_values(['qid', 'rank']).drop_duplicates()

serp_val = pd.merge(serp_val, qrels[['qid', 'query']], left_on='qid', right_on='qid', how='left')
serp_val = pd.merge(serp_val, civil_code[['article_no', 'full_text']], left_on=['docno'], right_on=['article_no'], how='left')
serp_val = serp_val[['qid', 'query', 'docno', 'full_text', 'score', 'rank']].drop_duplicates().rename(columns={'full_text': 'passage',}).sort_values(['qid', 'rank'])
serp_val['prompt_test'] = serp_val.apply(lambda row: prompt_test.format(row['query'], row['passage']) + tokenizer.eos_token, axis=1)

serp_val['llm'] = serp_val['prompt_test'].apply(infer)
serp_val["bm25"] = scaler.transform(serp_val[["score"]])
serp_val['average'] = (1 - best_alpha / 100) * serp_val['bm25'] + (best_alpha / 100) * serp_val['llm']
serp_val['uiwa'] = (serp_val['average'] > 0.5).astype(int)

for qid, group in serp_val.groupby('qid'):
    if group['uiwa'].sum() == 0:
        max_index = group['average'].idxmax()
        serp_val.loc[max_index, 'uiwa'] = 1

serp_val = pd.merge(qrels_val[qrels_val['qid'].isin(serp_val['qid'])], serp_val, on=['qid', 'docno'], how='outer', suffixes=('_qrels', '_serp'))
serp_val['relevance'] = serp_val['relevance'].apply(lambda x: 1 if x == 1 else 0)
serp_val['uiwa'] = serp_val['uiwa'].apply(lambda x: x if x > 0 else 0)

precision, recall, f2 = compute_metrics(serp_val[['qid', 'docno', 'relevance', 'uiwa']].rename(columns={'uiwa': 'pred'}))
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F2-score: {f2:.4f}")

Precision: 0.4820, Recall: 0.6263, F2-score: 0.5909


# XXmeta

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

In [None]:
serp_sample['article_length'] = serp_sample['passage'].apply(lambda x: len(x.split()))
serp_sample['question_length'] = serp_sample['query'].apply(lambda x: len(x.split()))

serp_val['article_length'] = serp_val['passage'].apply(lambda x: len(x.split()))
serp_val['question_length'] = serp_val['query'].apply(lambda x: len(x.split()))

In [None]:
def get_embedding(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state
    cls_embedding = embeddings[:, 0, :]

    return cls_embedding

In [None]:
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
model = AutoModel.from_pretrained("nlpaueb/legal-bert-base-uncased")

serp_sample['query_embedding'] = serp_sample['query'].apply(lambda x: get_embedding(x)).apply(lambda x: np.array(x))
serp_sample['passage_embedding'] = serp_sample['passage'].apply(lambda x: get_embedding(x)).apply(lambda x: np.array(x))
serp_sample['bert'] = serp_sample.apply(lambda row: cosine_similarity(
    row["query_embedding"].reshape(1, -1),
    row["passage_embedding"].reshape(1, -1)
)[0][0], axis=1)

serp_val['query_embedding'] = serp_val['query'].apply(lambda x: get_embedding(x)).apply(lambda x: np.array(x))
serp_val['passage_embedding'] = serp_val['passage'].apply(lambda x: get_embedding(x)).apply(lambda x: np.array(x))
serp_val['bert'] = serp_val.apply(lambda row: cosine_similarity(
    row["query_embedding"].reshape(1, -1),
    row["passage_embedding"].reshape(1, -1)
)[0][0], axis=1)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/222k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

  serp_sample['query_embedding'] = serp_sample['query'].apply(lambda x: get_embedding(x)).apply(lambda x: np.array(x))
  serp_sample['passage_embedding'] = serp_sample['passage'].apply(lambda x: get_embedding(x)).apply(lambda x: np.array(x))
  serp_val['query_embedding'] = serp_val['query'].apply(lambda x: get_embedding(x)).apply(lambda x: np.array(x))
  serp_val['passage_embedding'] = serp_val['passage'].apply(lambda x: get_embedding(x)).apply(lambda x: np.array(x))


In [None]:
X = serp_sample[['bm25', 'pred', 'uiwa', 'bert', 'article_length', 'question_length']]
y = serp_sample['relevance']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42)

model = LogisticRegression()
model.fit(X_train, y_train)
X_test['lr'] = model.predict_proba(X_test)[:, 1]

X_test[['qid', 'docno', 'relevance']] = serp_sample[['qid', 'docno', 'relevance']].loc[X_test.index]


In [None]:
best_proba = 0
best_f2 = 0
proba_scores = []

for proba in [1, 2, 5, 10, 15, 20, 25, 30, 35, 40, 45]:

    X_test['uimeta'] = X_test['lr'].apply(lambda x: 1 if x > proba/100 else 0)
    for qid, group in X_test.groupby('qid'):
      if group['uimeta'].sum() == 0:
          max_index = group['lr'].idxmax()
          X_test.loc[max_index, 'uimeta'] = 1

    X_test['uimeta'] = X_test['uiwa'].apply(lambda x: x if x > 0 else 0)

    precision, recall, f2 = compute_metrics(X_test[['qid', 'docno', 'relevance', 'uimeta']].rename(columns={'uimeta': 'pred'}))
    proba_scores.append((proba, precision, recall, f2))


    if f2 > best_f2:
        best_f2 = f2
        best_proba = proba

In [None]:
serp_val['lr'] = model.predict_proba(serp_val[['bm25', 'pred', 'uiwa', 'bert', 'article_length', 'question_length']])[:, 1]
serp_val['uimeta'] = serp_val['lr'].apply(lambda x: 1 if x > best_proba/100 else 0)
for qid, group in serp_val.groupby('qid'):
  if group['uimeta'].sum() == 0:
      max_index = group['lr'].idxmax()
      serp_val.loc[max_index, 'uimeta'] = 1

serp_val['uimeta'] = serp_val['uiwa'].apply(lambda x: x if x > 0 else 0)
precision, recall, f2 = compute_metrics(serp_val[['qid', 'docno', 'relevance', 'uimeta']].rename(columns={'uimeta': 'pred'}))

print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F2-score: {f2:.4f}")

Precision: 0.4927, Recall: 0.6350, F2-score: 0.6003
