In [1]:
from datasets import load_dataset, DatasetDict
import numpy as np
from tqdm import tqdm
import os
import json
from sentence_transformers import SentenceTransformer, util
import torch
import argparse
import bm25s

In [2]:
# Chunking configuration shared by the passage-building cells.
chunk_size = 300      # passage length, forwarded as `max_len` to the chunking helper
chunk_overlap = 0.0   # fraction of a passage shared with the next one, in [0, 1)

# Worker count for `Dataset.map`: leave three cores free, but never drop below
# one worker. `os.cpu_count()` may return None, and small machines would
# otherwise yield zero or a negative count.
num_proc = max(1, (os.cpu_count() or 1) - 3)

In [3]:
dataset = "oal_qa"

key_field = "docid"

# Hub repository name (inside the ylkhayat namespace) for each supported dataset key.
_workshop_names = {
    "clerc": "CLERC-generation-workshop",
    "echr": "ECHR-generation-workshop",
    "echr_qa": "ECHR_QA-generation-workshop",
    "obli_qa": "OBLI_QA-generation-workshop",
    "cuad": "CUAD-generation-workshop",
    "oal_qa": "OAL_QA-generation-workshop",
}
if dataset not in _workshop_names:
    raise ValueError("Invalid dataset")
if dataset == "clerc":
    # Only CLERC also loads the upstream generation files.
    original_dataset = load_dataset(
        "jhu-clsp/CLERC",
        data_files={"train": "generation/train.jsonl", "test": "generation/test.jsonl"},
    )
workshop_hf_name = _workshop_names[dataset]
current_chosen_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir="data")

README.md:   0%|          | 0.00/534 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/40.0k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/35.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2024 [00:00<?, ? examples/s]

In [4]:
# Display the loaded DatasetDict (splits, feature names and row counts).
current_chosen_dataset

DatasetDict({
    train: Dataset({
        features: ['docid', 'previous_text', 'gold_text', 'citations'],
        num_rows: 100
    })
    test: Dataset({
        features: ['docid', 'previous_text', 'gold_text', 'citations'],
        num_rows: 2024
    })
})

In [5]:
# def chunk_document_into_passages_overlap_split(document_id: str, text: str, max_len: int, overlap: float):
#     if max_len <= 0:
#         raise ValueError("max_len must be a positive integer")
#     if not (0 <= overlap < 1):
#         raise ValueError("overlap must be between 0 (inclusive) and 1 (exclusive)")

#     words = text.split()
#     chunks = []
#     step = int(max_len * (1 - overlap))
#     last_index = 0
#     for index, i in enumerate(range(0, len(words), step)):
#         current_splits = [document_id]
#         if i == 0:
#             second_split = words[i:i+max_len]
#             last_index = i+max_len
#             splits = ["", " ".join(second_split)]
#         else:
#             first_split = words[last_index-step:last_index]
#             second_split = words[last_index:last_index+step]
#             last_index = last_index+step
#             splits = [" ".join(first_split), " ".join(second_split)]
#         if splits:
#             current_splits.extend(splits)
#             chunks.append(current_splits)
#     return chunks

# def chunk_citations_overlap_split(record):
#     chunks = []
#     for citation_id, citation in record['citations']:
#         chunks.extend(chunk_document_into_passages_overlap_split(citation_id, citation, chunk_size, chunk_overlap))
#     return chunks

# if 'oracle_documents_passages_overlap_split' not in current_chosen_dataset.column_names['train'] or 'oracle_documents_passages_overlap_split' not in current_chosen_dataset.column_names['test']:
#     print(f"[*] adding oracle_documents_passages_overlap_split to {workshop_hf_name}")
#     if "citations" not in current_chosen_dataset.column_names['train'] or "citations" not in current_chosen_dataset.column_names['test']:
#         current_chosen_dataset = DatasetDict({
#             'train': current_chosen_dataset['train'].add_column(name="citations", column=original_dataset['train']['citations']),
#             'test': current_chosen_dataset['test'].add_column(name="citations", column=original_dataset['test']['citations'])
#         })
#     current_chosen_dataset = current_chosen_dataset.map(lambda record: {'oracle_documents_passages_overlap_split': chunk_citations_overlap_split(record)})
#     current_chosen_dataset = current_chosen_dataset.select_columns(['docid', 'previous_text', 'gold_text', 'citations' ,'oracle_documents_passages', 'oracle_documents_passages_overlap_split'])
#     current_chosen_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}")
# else:
#     print(f"[!] oracle_documents_passages_overlap_split already exists in {workshop_hf_name}")

In [6]:
# def chunk_document_into_passages(document_id: str, text: str, max_len: int, overlap: float):
#     if max_len <= 0:
#         raise ValueError("max_len must be a positive integer")
#     if not (0 <= overlap < 1):
#         raise ValueError("overlap must be between 0 (inclusive) and 1 (exclusive)")

#     words = text.split()
#     chunks = []
#     step = int(max_len * (1 - overlap))
#     for i in range(0, len(words), step):
#         chunk_words = words[i:i+max_len]
#         if chunk_words:
#             chunk_text = ' '.join(chunk_words)
#             chunks.append([document_id, chunk_text])
#     return chunks

# def chunk_citations(record):
#     chunks = []
#     for citation_id, citation in record['citations']:
#         chunks.extend(chunk_document_into_passages(citation_id, citation, chunk_size, chunk_overlap))
#     return chunks

# if 'oracle_documents_passages' not in current_chosen_dataset.column_names['train'] or 'oracle_documents_passages' not in current_chosen_dataset.column_names['test']:
#     print(f"[*] adding oracle_documents_passages to {workshop_hf_name}")
#     current_chosen_dataset = current_chosen_dataset.map(lambda record: {'oracle_documents_passages': chunk_citations(record)}, num_proc=num_proc)
# else:
#     print(f"[!] oracle_documents_passages already exists in {workshop_hf_name}")
# # current_chosen_dataset = current_chosen_dataset.select_columns([key_field, 'previous_text', 'gold_text', 'citations' , 'oracle_documents_passages'])
# current_chosen_dataset = current_chosen_dataset.select_columns([key_field, 'previous_text', 'gold_text', 'citations' , 'oracle_documents_passages', 'oracle_documents_oracle_passages'])
# current_chosen_dataset.push_to_hub(workshop_hf_name)

In [7]:
# import spacy

# nlp = spacy.load("en_core_web_sm")

# def chunk_document_into_passages(document_id: str, text: str, max_len: int = 300, overlap: float = 0.0):
#     if max_len <= 0:
#         raise ValueError("max_len must be a positive integer")
#     if not (0 <= overlap < 1):
#         raise ValueError("overlap must be between 0 (inclusive) and 1 (exclusive)")

#     doc = nlp(text)
#     sentences = [sent.text.strip() for sent in doc.sents if sent.text.strip()]

#     chunks = []
#     chunk_text = ""
#     overlap_len = int(overlap * max_len)

#     for sentence in sentences:
#         to_add = sentence
#         while to_add:
#             available = max_len - len(chunk_text)
#             if len(to_add) <= available:
#                 if chunk_text:
#                     chunk_text += " " + to_add
#                 else:
#                     chunk_text = to_add
#                 to_add = ""
#             else:
#                 part = to_add[:available]
#                 remainder = to_add[available:].strip()
#                 if chunk_text:
#                     chunk_text += " " + part
#                 else:
#                     chunk_text = part
#                 chunks.append([document_id, chunk_text])
#                 if overlap_len > 0:
#                     overlap_str = chunk_text[-overlap_len:].strip()
#                     chunk_text = overlap_str
#                 else:
#                     chunk_text = ""
#                 to_add = remainder
#     if chunk_text:
#         chunks.append([document_id, chunk_text])
#     return chunks

# def chunk_citations(record):
#     chunks = []
#     for citation_id, citation in record['citations']:
#         chunks.extend(chunk_document_into_passages(citation_id, citation, chunk_size, chunk_overlap))
#     return chunks

# if 'oracle_documents_passages' not in current_chosen_dataset.column_names['train'] or 'oracle_documents_passages' not in current_chosen_dataset.column_names['test']:
#     print(f"[*] adding oracle_documents_passages to {workshop_hf_name}")
#     current_chosen_dataset = current_chosen_dataset.map(lambda record: {'oracle_documents_passages': chunk_citations(record)}, num_proc=num_proc)
# else:
#     print(f"[!] oracle_documents_passages already exists in {workshop_hf_name}")
# # current_chosen_dataset = current_chosen_dataset.select_columns([key_field, 'previous_text', 'gold_text', 'citations' , 'oracle_documents_passages'])
# current_chosen_dataset = current_chosen_dataset.select_columns([key_field, 'previous_text', 'gold_text', 'citations' , 'oracle_documents_passages'])
# current_chosen_dataset.push_to_hub(workshop_hf_name, data_dir="data")

In [5]:
from more_itertools import windowed
import spacy

# nlp = spacy.load("en_core_web_sm")

# def chunk_document_into_passages(document_id: str, text: str, max_len: int = 300, overlap: float = 0.0):
#     if max_len <= 0:
#         raise ValueError("max_len must be a positive integer")
#     if not (0 <= overlap < 1):
#         raise ValueError("overlap must be between 0 (inclusive) and 1 (exclusive)")
        
#     doc = nlp(text)
#     words = [token.text for token in doc if not token.is_space]

#     chunks = []
#     chunk_words = []
#     current_word_count = 0  # word count of the current chunk
#     overlap_count = int(overlap * max_len)

#     for word in words:
#         if current_word_count + 1 > max_len:
#             chunk_text = " ".join(chunk_words).strip()
#             if chunk_text:
#                 chunks.append([document_id, chunk_text])
#             if overlap_count > 0:
#                 overlap_words = chunk_words[-overlap_count:]
#                 chunk_words = overlap_words
#                 current_word_count = len(chunk_words)
#             else:
#                 chunk_words = []
#                 current_word_count = 0
            
#             chunk_words.append(word)
#             current_word_count += 1
#         else:
#             chunk_words.append(word)
#             current_word_count += 1

#     if chunk_words:
#         chunk_text = " ".join(chunk_words).strip()
#         if chunk_text:
#             chunks.append([document_id, chunk_text])
#     return chunks

def chunk_document_into_passages(document_id: str, text: str, max_len: int = 300, overlap: float = 0.0):
    """Split ``text`` into passages of at most ``max_len`` space-separated words.

    Words are obtained with ``text.split(" ")``, so newlines and tabs stay
    inside the words and are preserved in the passages. Consecutive passages
    start ``int(max_len * (1 - overlap))`` words apart (at least one word), and
    the last passage is the first one that reaches the end of the text.

    Args:
        document_id: Identifier stored alongside every passage.
        text: Document text to split.
        max_len: Maximum number of words per passage; must be positive.
        overlap: Fraction of ``max_len`` shared by consecutive passages,
            in ``[0, 1)``.

    Returns:
        A list of ``[document_id, passage_text]`` pairs. Passages that are
        empty after stripping whitespace are omitted, so an empty ``text``
        yields ``[]``.

    Raises:
        ValueError: If ``max_len`` is not positive or ``overlap`` is outside
            ``[0, 1)``.
    """
    if max_len <= 0:
        raise ValueError("max_len must be a positive integer")
    if not (0 <= overlap < 1):
        raise ValueError("overlap must be between 0 (inclusive) and 1 (exclusive)")

    words = text.split(" ")
    # A large overlap combined with a small max_len can round the stride down
    # to 0; clamp it so the loop always advances.
    step = max(1, int(max_len * (1 - overlap)))

    chunks = []
    start = 0
    while True:
        chunk_text = " ".join(words[start:start + max_len]).strip()
        if chunk_text:  # skip whitespace-only windows instead of emitting empty passages
            chunks.append([document_id, chunk_text])
        if start + max_len >= len(words):
            # This window already covers the end of the document.
            break
        start += step
    return chunks

def chunk_citations(record, chunk_size=300, chunk_overlap=0.0):
    """Chunk every cited document of ``record`` into passages.

    ``record['citations']`` is iterated as ``(citation_id, citation_text)``
    pairs; each text is split with ``chunk_document_into_passages`` and the
    resulting passages are concatenated in citation order.

    Args:
        record: Mapping with a ``'citations'`` entry.
        chunk_size: Forwarded as ``max_len`` to the chunking helper.
        chunk_overlap: Forwarded as ``overlap`` to the chunking helper.

    Returns:
        A flat list with the passages of all citations.
    """
    return [
        passage
        for citation_id, citation in record['citations']
        for passage in chunk_document_into_passages(
            citation_id, citation, max_len=chunk_size, overlap=chunk_overlap
        )
    ]

In [6]:
# Sanity check: chunk the citations of the first test record with the default
# chunk_size / chunk_overlap and pretty-print the resulting passages as JSON.
print(json.dumps(chunk_citations(current_chosen_dataset['test'][0]), indent=4))

[
    [
        "nsw_caselaw:549fc6183004262463bb648a",
        "New South Wales\nSupreme Court\n  CITATION :         Nasr v NRMA Insurance [2006] NSWSC 1018\n  HEARING DATE(S) :  27 September 2006\n  JUDGMENT DATE :    29 September 2006\n  JURISDICTION :     Common Law Division\n  JUDGMENT OF :      Associate Justice Harrison\n  DECISION :         (1) The appeal is dismissed; (2) The orders of the Magistrate Lulham dated 4 October 2005 are affirmed; (3) The summons filed 8 June 2006 is dismissed; (4) The plaintiff is to pay the defendant's costs as agreed or assessed.\n  CATCHWORDS :         Appeal decision of Local Court Magistrate - non-appearance - Statement of Claim struck out\n  LEGISLATION CITED :  Local Courts Act 1982 - s 73\n                       Allan v Kerr & Anor (1995) Aust Torts Reports 81-354\n                       Azzopardi v Tasman UEB Industries Ltd (1985) 4 NSWLR 139\n                       Carr v Neill [1999] NSWSC 1263\n                       Devries v Australia

In [13]:
# Set to True to recompute the passages even when the column already exists.
create = False
passages_column = 'oracle_documents_passages'
missing_passages = any(
    passages_column not in current_chosen_dataset.column_names[split]
    for split in ('train', 'test')
)
if create or missing_passages:
    print(f"[*] adding oracle_documents_passages to {workshop_hf_name}")
    # Use parallel processing with num_proc if dataset is large. Increase num_proc as needed.
    # The notebook-level chunk_size / chunk_overlap are passed explicitly so that
    # changing the configuration actually changes the generated passages.
    current_chosen_dataset = current_chosen_dataset.map(
        lambda record: {
            'oracle_documents_passages': chunk_citations(
                record, chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
        },
        num_proc=num_proc
    )
else:
    print(f"[!] oracle_documents_passages already exists in {workshop_hf_name}")

# Keep only the columns used downstream, then publish the base configuration.
current_chosen_dataset = current_chosen_dataset.select_columns([key_field, 'previous_text', 'gold_text', 'citations', 'oracle_documents_passages'])
# Push to the same fully qualified repository the dataset was loaded from,
# matching the other push_to_hub calls in this notebook.
current_chosen_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir="data")

[*] adding oracle_documents_passages to OBLI_QA-generation-workshop


Map (num_proc=45):   0%|          | 0/1000 [00:00<?, ? examples/s]

: 

In [8]:
import bm25s
import Stemmer 
# English stemmer passed to bm25s.tokenize for both corpus and query text.
stemmer = Stemmer.Stemmer("english")

# Create the BM25 Oracle Setup

In [9]:
# Number of passages to retrieve per entry; also part of the output column name.
top_k = 10
# Alternative used only for echr_qa:
# data_dir = "bm25_noisy_oracle_passages_oracle_documents"
# Hub sub-directory this configuration is loaded from and pushed to.
data_dir = "bm25_oracle_passages_oracle_documents"

from datasets import DatasetDict


# When True, the retrieval column is rebuilt and pushed without trying to load an existing copy.
create = True
def retrieve_top_passages(entry):
    """Rank an entry's oracle passages against its gold text with BM25.

    Each passage is rendered as ``"<document id>\\n<passage text>"``. A fresh
    BM25 index is built over the passages of this entry only, and the entry's
    ``gold_text`` is used as the query.

    Args:
        entry: Record with ``'gold_text'`` and ``'oracle_documents_passages'``
            (a sequence of ``[document_id, passage_text]`` pairs).

    Returns:
        ``{"top_<top_k>_passages": passages}`` holding at most ``top_k``
        rendered passages as returned by the retriever; an empty list when the
        entry has no passages.
    """
    column_name = f"top_{top_k}_passages"
    all_passages_text = [
        f"{passage_arr[0]}\n{passage_arr[1]}"
        for passage_arr in entry['oracle_documents_passages']
    ]
    if not all_passages_text:
        # Nothing to rank: avoid building an index over an empty corpus.
        return {column_name: []}

    corpus_tokens = bm25s.tokenize(all_passages_text, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)
    query_tokens = bm25s.tokenize(entry['gold_text'], stemmer=stemmer)
    # Never ask for more passages than the corpus contains.
    local_top_k = min(top_k, len(all_passages_text))
    results, _ = retriever.retrieve(query_tokens, corpus=all_passages_text, k=local_top_k)
    # A single query was issued, so drop the leading query axis.
    return {column_name: results.squeeze(0)}
# Try to reuse a previously pushed configuration unless a rebuild is forced.
new_dataset = None
if not create:
    try:
        new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
    except Exception:
        # Catch Exception (not a bare except) so KeyboardInterrupt still stops the cell.
        print(f"[!] {workshop_hf_name} not found in {data_dir}")
        create = True

# The column check must look at the dataset loaded from `data_dir`; the base
# dataset never carries the retrieval column, so checking it would always rebuild.
if create or f"top_{top_k}_passages" not in new_dataset.column_names['test']:
    print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
    new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
    new_dataset = new_dataset.map(retrieve_top_passages, num_proc=num_proc)
    new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
else:
    print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# Show the best-ranked passage of the first test entry.
print(json.dumps(new_dataset['test'][0][f"top_{top_k}_passages"][0], indent=4))



[*] adding top_10_passages to OBLI_QA-generation-workshop


Map (num_proc=45):   0%|          | 0/1000 [00:00<?, ? examples/s]

Map (num_proc=45):   0%|          | 0/2786 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/3 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/6 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

"1\n14. SUSPICIOUS ACTIVITY/TRANSACTION REPORTS\n14.1 Application and definitions\nIn this Chapter \"money laundering\" and \u201cterrorist financing\" means the criminal offences defined in Federal AML Legislation.\n14.2 Internal reporting requirements\n14.2.1 A Relevant Person must establish and maintain policies, procedures, systems and controls in order to monitor and detect suspicious activity or Transactions in relation to potential money laundering or terrorist financing.\n14.2.2 A Relevant Person must have policies, procedures, systems and controls to ensure that whenever any Employee, acting in the ordinary course of his employment, either:\n(a)\tknows;\n(b)\tsuspects; or\n(c)\thas reasonable grounds for knowing or suspecting,\nthat a Person is engaged in or attempting money laundering or terrorist financing, that Employee promptly notifies the Relevant Person's MLRO and provides the MLRO with all relevant details.\n14.2.3 A Relevant Person must have policies and procedures to

# Create the BM25 Relevant Setup

In [10]:
# Number of passages to retrieve per entry; also part of the output column name.
top_k = 10
# Hub sub-directory this configuration is loaded from and pushed to.
data_dir = "bm25_relevant_passages_oracle_documents"

# When True, the retrieval column is rebuilt and pushed without trying to load an existing copy.
create = True
def retrieve_top_passages(entry):
    """Rank an entry's oracle passages against its previous text with BM25.

    Each passage is rendered as ``"<document id>\\n<passage text>"``. A fresh
    BM25 index is built over the passages of this entry only, and the entry's
    ``previous_text`` is used as the query.

    Args:
        entry: Record with ``'previous_text'`` and
            ``'oracle_documents_passages'`` (a sequence of
            ``[document_id, passage_text]`` pairs).

    Returns:
        ``{"top_<top_k>_passages": passages}`` holding at most ``top_k``
        rendered passages as returned by the retriever; an empty list when the
        entry has no passages.
    """
    column_name = f"top_{top_k}_passages"
    all_passages_text = [
        f"{passage_arr[0]}\n{passage_arr[1]}"
        for passage_arr in entry['oracle_documents_passages']
    ]
    if not all_passages_text:
        # Nothing to rank: avoid building an index over an empty corpus.
        return {column_name: []}

    corpus_tokens = bm25s.tokenize(all_passages_text, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)
    query_tokens = bm25s.tokenize(entry['previous_text'], stemmer=stemmer)
    # Never ask for more passages than the corpus contains.
    local_top_k = min(top_k, len(all_passages_text))
    results, _ = retriever.retrieve(query_tokens, corpus=all_passages_text, k=local_top_k)
    # A single query was issued, so drop the leading query axis.
    return {column_name: results.squeeze(0)}

# Try to reuse a previously pushed configuration unless a rebuild is forced.
new_dataset = None
if not create:
    try:
        new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
    except Exception:
        # Catch Exception (not a bare except) so KeyboardInterrupt still stops the cell.
        print(f"[!] {workshop_hf_name} not found in {data_dir}")
        create = True

# The column check must look at the dataset loaded from `data_dir`; the base
# dataset never carries the retrieval column, so checking it would always rebuild.
if create or f"top_{top_k}_passages" not in new_dataset.column_names['train'] or f"top_{top_k}_passages" not in new_dataset.column_names['test']:
    print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
    new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
    new_dataset = new_dataset.map(retrieve_top_passages, num_proc=num_proc)
    new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
else:
    print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# Show the best-ranked passage of the first test entry.
print(json.dumps(new_dataset['test'][0][f"top_{top_k}_passages"][0], indent=4))

[*] adding top_10_passages to OBLI_QA-generation-workshop


Map (num_proc=45):   0%|          | 0/1000 [00:00<?, ? examples/s]

Map (num_proc=45):   0%|          | 0/2786 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/3 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/6 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/709 [00:00<?, ?B/s]

"1\n9. AML/TFS COMPLIANCE AND THIRD PARTIES\n9.1 Reliance on a third party\n9.1.1\n9.1.1.(1) A Relevant Person may rely on the following third parties to conduct one or more of the elements of CDD on its behalf:\n(a)\tan Authorised Person or Recognised Body;\n(b)\ta law firm, notary, or other independent legal business, accounting firm, audit firm or insolvency practitioner or an equivalent Person in another jurisdiction;\n(c)\ta Financial Institution;\n(d)\ta member of the Relevant Person's Group; or\n(e)\tother specialised utilities for the provision of outsourced AML/TFS services.\n9.1.1.(2) In (1), a Relevant Person may rely on the information previously obtained by a third party which covers one or more elements of CDD.\n9.1.1.(3) Where a Relevant Person seeks to rely on a Person in (1) it may only do so if and to the extent that:\n(a)\tit immediately obtains the necessary CDD information from the third party in (1);\n(b)\tit takes adequate steps to satisfy itself that certified c

# Create the Dense Oracle Setup

# Create the Dense Setup

In [None]:
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset, DatasetDict
import torch
import json
import faiss
import numpy as np

# Number of passages to keep per query; also part of the output column name.
top_k = 10

# Encoder used to embed both queries and passages, moved to GPU when available.
encoder_name = "jhu-clsp/LegalBERT-DPR-CLERC-ft"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(encoder_name)
model = AutoModel.from_pretrained(encoder_name).to(device)
# model = torch.compile(model)

# Hub sub-directory for this configuration: the base name plus the encoder
# name with "/" replaced by "_".
clean_encoder_name = "_".join(encoder_name.split("/"))
data_dir = "dense_oracle_passages_oracle_documents"
data_dir = f"{data_dir}/{clean_encoder_name}"

def normalize_embeddings(embeddings, eps=1e-12):
    """Scale every row of ``embeddings`` to unit L2 norm.

    Args:
        embeddings: 2-D tensor; the norm is taken along ``dim=1``.
        eps: Lower bound applied to the row norms so that all-zero rows stay
            zero instead of turning into NaN through a division by zero.

    Returns:
        A tensor of the same shape with each row divided by its (clamped) norm.
    """
    norms = torch.norm(embeddings, p=2, dim=1, keepdim=True)
    return embeddings / norms.clamp_min(eps)


def embed_texts(texts, tokenizer, model):
    """Embed ``texts`` with attention-masked mean pooling and L2 normalisation.

    The texts are tokenized in one padded batch, truncated to
    ``model.config.max_position_embeddings`` tokens, and moved to the
    notebook-level ``device``. The last hidden states are averaged over the
    positions selected by the attention mask, then normalised with
    ``normalize_embeddings``.

    Args:
        texts: Texts to embed, passed to ``tokenizer`` as-is.
        tokenizer: Callable producing ``return_tensors="pt"`` encodings.
        model: Encoder whose output exposes ``last_hidden_state``.

    Returns:
        A tensor with one normalised embedding per input text.
    """
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=model.config.max_position_embeddings,
    ).to(device)
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
        # Broadcast the mask over the hidden dimension so padded positions contribute nothing.
        mask = encoded['attention_mask'].unsqueeze(-1).expand(hidden_states.size()).float()
        # Clamp the denominator so an all-padding row cannot divide by zero.
        token_counts = mask.sum(dim=1).clamp(min=1e-9)
        pooled = (hidden_states * mask).sum(dim=1) / token_counts
    return normalize_embeddings(pooled)


def retrieve_top_passages_batch(entries):
    """Dense retrieval of the top passages for every entry of a batch.

    Each entry's ``gold_text`` is scored only against that entry's own
    ``oracle_documents_passages`` (inner product of the normalised
    embeddings), so passages of other entries in the same batch can never be
    returned. At most ``top_k`` passages are kept per entry; entries with
    fewer passages return all of them, and entries without passages return an
    empty list.

    Args:
        entries: Batched records (dict of lists) with ``'gold_text'`` and
            ``'oracle_documents_passages'``; every passage is a
            ``[document_id, passage_text]`` pair.

    Returns:
        ``{"top_<top_k>_passages": [...]}`` with one list of passage texts per
        entry, ordered by decreasing similarity.
    """
    queries = entries['gold_text']
    query_embeddings = embed_texts(queries, tokenizer, model)  # one row per query

    top_passages_batch = []
    for query_embedding, passages in zip(query_embeddings, entries['oracle_documents_passages']):
        passages_text = [passage[1] for passage in passages]
        if not passages_text:
            top_passages_batch.append([])
            continue
        passage_embeddings = embed_texts(passages_text, tokenizer, model)  # one row per passage
        # Exact inner-product search restricted to this entry's passages.
        scores = passage_embeddings @ query_embedding
        # Never request more results than there are passages.
        local_top_k = min(top_k, len(passages_text))
        best = torch.topk(scores, local_top_k).indices.tolist()
        top_passages_batch.append([passages_text[position] for position in best])
    return {f"top_{top_k}_passages": top_passages_batch}

# Keep a `create` flag set by an earlier cell, but default to False so this
# cell also runs on a fresh kernel instead of failing with a NameError.
create = bool(globals().get("create", False))

# Try to reuse a previously pushed configuration unless a rebuild is forced.
new_dataset = None
if not create:
    try:
        new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
    except Exception:
        # Catch Exception (not a bare except) so KeyboardInterrupt still stops the cell.
        print(f"[!] {workshop_hf_name} not found in {data_dir}")
        create = True

# The column check must look at the dataset loaded from `data_dir`; the base
# dataset never carries the retrieval column, so checking it would always rebuild.
if create or f"top_{top_k}_passages" not in new_dataset.column_names['train'] or f"top_{top_k}_passages" not in new_dataset.column_names['test']:
    print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
    new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
    new_dataset = new_dataset.map(
        retrieve_top_passages_batch,
        batched=True,
        batch_size=4
    )
    new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
else:
    print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# Show the best-ranked passage of the first training entry.
print(json.dumps(new_dataset['train'][0][f"top_{top_k}_passages"][0], indent=4))

In [None]:

def _preview_first_train_example(dataset_dict, query_field):
    """Print a separator, the query text and the top-ranked passage of the first training example."""
    print("===================================")
    print(dataset_dict['train'][0][query_field])
    print(dataset_dict['train'][0]['top_10_passages'][0])


# Reload each pushed retrieval configuration and eyeball its first training example.
dataset_bm25_oracle_passages_oracle_documents = load_dataset(
    f"ylkhayat/{workshop_hf_name}", data_dir="bm25_oracle_passages_oracle_documents"
)
_preview_first_train_example(dataset_bm25_oracle_passages_oracle_documents, 'gold_text')

dataset_bm25_relevant_passages_oracle_documents = load_dataset(
    f"ylkhayat/{workshop_hf_name}", data_dir="bm25_relevant_passages_oracle_documents"
)
_preview_first_train_example(dataset_bm25_relevant_passages_oracle_documents, 'previous_text')

dataset_dense_oracle_passages_oracle_documents = load_dataset(
    f"ylkhayat/{workshop_hf_name}", data_dir="dense_oracle_passages_oracle_documents"
)
_preview_first_train_example(dataset_dense_oracle_passages_oracle_documents, 'gold_text')


In [None]:
# from transformers import AutoTokenizer, AutoModel
# import torch
# import faiss
# import numpy as np

# encoder_name = "jhu-clsp/LegalBERT-DPR-CLERC-ft"
# tokenizer = AutoTokenizer.from_pretrained(encoder_name)
# model = AutoModel.from_pretrained(encoder_name)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model.to(device)

# top_k = 10
# data_dir = "dense_oracle_passages_oracle_documents"

# def embed_passages(passages, tokenizer, model):
#     inputs = tokenizer(passages, return_tensors="pt", padding=True, truncation=True, max_length=model.config.max_position_embeddings).to(device)
#     with torch.no_grad():
#         embeddings = model(**inputs).pooler_output
#     return embeddings.cpu().numpy()

# def retrieve_top_passages(entry):
#     query = entry['gold_text']
#     all_passages = entry['oracle_documents_passages']
#     all_passages_text = [f"{passage_arr[0]}\n{passage_arr[1]}" for passage_arr in all_passages]
    
#     passage_embeddings = embed_passages(all_passages_text, tokenizer, model)
#     query_embedding = embed_passages([query], tokenizer, model).squeeze(0)

#     dim = passage_embeddings.shape[1]
#     index = faiss.IndexFlatL2(dim)
#     index.add(passage_embeddings)
    
#     # Retrieve top-k similar passages
#     distances, indices = index.search(np.expand_dims(query_embedding, axis=0), top_k)
#     top_passages = [all_passages_text[idx] for idx in indices[0]]
    
#     return {f"top_{top_k}_passages": top_passages}

# try:
#     workshop_hf_name = f"CLERC-generation-workshop"
#     new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
# except:
#     print(f"[!] {workshop_hf_name} not found in {data_dir}")
#     create = True
    
# if create or f"top_{top_k}_passages" not in current_chosen_dataset.column_names['train'] or f"top_{top_k}_passages" not in current_chosen_dataset.column_names['test']:
#     print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
#     new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
#     new_dataset = new_dataset.map(retrieve_top_passages, num_proc=num_proc)
#     new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
# else:
#     print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# print(json.dumps(new_dataset['train'][0][f"top_{top_k}_passages"][0], indent=4)) 