In [15]:
from datasets import load_dataset, DatasetDict
import numpy as np
from tqdm import tqdm
import os
import json
from sentence_transformers import SentenceTransformer, util
import torch
import argparse
import bm25s

In [16]:
# Chunking configuration: each passage holds up to `chunk_size`
# whitespace-separated words and `chunk_overlap` is the fraction of the window
# shared between consecutive passages.
chunk_size = 300
chunk_overlap = 0.5

# Worker processes for `Dataset.map`: leave three cores free but never drop
# below one. `os.cpu_count()` may return None, and a machine with three or
# fewer cores would otherwise produce a non-positive process count.
num_proc = max(1, (os.cpu_count() or 1) - 3)

In [17]:
dataset = "cuad"
key_field = "docid"

# Hub repository name (under the `ylkhayat` namespace) for each supported key.
_WORKSHOP_REPOS = {
    "clerc": "CLERC-generation-workshop",
    "echr": "ECHR-generation-workshop",
    "echr_qa": "ECHR_QA-generation-workshop",
    "obli_qa": "OBLI_QA-generation-workshop",
    "cuad": "CUAD-generation-workshop",
}
if dataset not in _WORKSHOP_REPOS:
    raise ValueError("Invalid dataset")

# Only CLERC additionally loads its upstream generation files.
if dataset == "clerc":
    original_dataset = load_dataset(
        "jhu-clsp/CLERC",
        data_files={"train": "generation/train.jsonl", "test": "generation/test.jsonl"},
    )
workshop_hf_name = _WORKSHOP_REPOS[dataset]
current_chosen_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir="data")

README.md:   0%|          | 0.00/592 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/15.2M [00:00<?, ?B/s]

test-00000-of-00002.parquet:   0%|          | 0.00/46.9M [00:00<?, ?B/s]

test-00001-of-00002.parquet:   0%|          | 0.00/49.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/4182 [00:00<?, ? examples/s]

In [18]:
# Display the loaded DatasetDict to check its splits, columns and row counts.
current_chosen_dataset

DatasetDict({
    train: Dataset({
        features: ['docid', 'previous_text', 'gold_text', 'citations', 'oracle_documents_passages'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['docid', 'previous_text', 'gold_text', 'citations', 'oracle_documents_passages'],
        num_rows: 4182
    })
})

In [10]:
# def chunk_document_into_passages_overlap_split(document_id: str, text: str, max_len: int, overlap: float):
#     if max_len <= 0:
#         raise ValueError("max_len must be a positive integer")
#     if not (0 <= overlap < 1):
#         raise ValueError("overlap must be between 0 (inclusive) and 1 (exclusive)")

#     words = text.split()
#     chunks = []
#     step = int(max_len * (1 - overlap))
#     last_index = 0
#     for index, i in enumerate(range(0, len(words), step)):
#         current_splits = [document_id]
#         if i == 0:
#             second_split = words[i:i+max_len]
#             last_index = i+max_len
#             splits = ["", " ".join(second_split)]
#         else:
#             first_split = words[last_index-step:last_index]
#             second_split = words[last_index:last_index+step]
#             last_index = last_index+step
#             splits = [" ".join(first_split), " ".join(second_split)]
#         if splits:
#             current_splits.extend(splits)
#             chunks.append(current_splits)
#     return chunks

# def chunk_citations_overlap_split(record):
#     chunks = []
#     for citation_id, citation in record['citations']:
#         chunks.extend(chunk_document_into_passages_overlap_split(citation_id, citation, chunk_size, chunk_overlap))
#     return chunks

# if 'oracle_documents_passages_overlap_split' not in current_chosen_dataset.column_names['train'] or 'oracle_documents_passages_overlap_split' not in current_chosen_dataset.column_names['test']:
#     print(f"[*] adding oracle_documents_passages_overlap_split to {workshop_hf_name}")
#     if "citations" not in current_chosen_dataset.column_names['train'] or "citations" not in current_chosen_dataset.column_names['test']:
#         current_chosen_dataset = DatasetDict({
#             'train': current_chosen_dataset['train'].add_column(name="citations", column=original_dataset['train']['citations']),
#             'test': current_chosen_dataset['test'].add_column(name="citations", column=original_dataset['test']['citations'])
#         })
#     current_chosen_dataset = current_chosen_dataset.map(lambda record: {'oracle_documents_passages_overlap_split': chunk_citations_overlap_split(record)})
#     current_chosen_dataset = current_chosen_dataset.select_columns(['docid', 'previous_text', 'gold_text', 'citations' ,'oracle_documents_passages', 'oracle_documents_passages_overlap_split'])
#     current_chosen_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}")
# else:
#     print(f"[!] oracle_documents_passages_overlap_split already exists in {workshop_hf_name}")

In [9]:
def chunk_document_into_passages(document_id: str, text: str, max_len: int, overlap: float):
    """Split ``text`` into overlapping windows of whitespace-separated words.

    Args:
        document_id: Identifier stored as the first element of every passage.
        text: Document text; it is split on whitespace with ``str.split()``.
        max_len: Maximum number of words per passage (must be positive).
        overlap: Fraction of ``max_len`` shared by consecutive windows,
            in the range ``[0, 1)``.

    Returns:
        A list of ``[document_id, passage_text]`` pairs, one per window, in
        document order. Empty or whitespace-only text yields ``[]``.

    Raises:
        ValueError: If ``max_len`` is not positive or ``overlap`` is outside
            ``[0, 1)``.
    """
    if max_len <= 0:
        raise ValueError("max_len must be a positive integer")
    if not (0 <= overlap < 1):
        raise ValueError("overlap must be between 0 (inclusive) and 1 (exclusive)")

    words = text.split()
    # int() truncates, so a small window with a large overlap (e.g. max_len=1,
    # overlap=0.5) would give a step of 0 and make range() raise; always
    # advance by at least one word.
    step = max(1, int(max_len * (1 - overlap)))
    # Every start index is < len(words), so each window has at least one word.
    # NOTE: trailing windows can be fully contained in the previous window;
    # they are kept so the produced passages match the existing behaviour.
    return [
        [document_id, ' '.join(words[start:start + max_len])]
        for start in range(0, len(words), step)
    ]

def chunk_citations(record):
    """Chunk every cited document of ``record`` into passages.

    ``record['citations']`` is iterated as ``(citation_id, citation_text)``
    pairs. Each text is chunked with the notebook-level ``chunk_size`` and
    ``chunk_overlap`` settings, and the passages of all citations are returned
    as one flat list, in citation order.
    """
    return [
        passage
        for citation_id, citation in record['citations']
        for passage in chunk_document_into_passages(citation_id, citation, chunk_size, chunk_overlap)
    ]

# Build the `oracle_documents_passages` column (chunked citations) unless both
# splits already provide it.
if 'oracle_documents_passages' not in current_chosen_dataset.column_names['train'] or 'oracle_documents_passages' not in current_chosen_dataset.column_names['test']:
    print(f"[*] adding oracle_documents_passages to {workshop_hf_name}")
    current_chosen_dataset = current_chosen_dataset.map(lambda record: {'oracle_documents_passages': chunk_citations(record)}, num_proc=num_proc)
else:
    print(f"[!] oracle_documents_passages already exists in {workshop_hf_name}")

# Keep only the workshop columns. Not every dataset ships
# `oracle_documents_oracle_passages`, and `select_columns` raises on a name
# that is missing, so restrict the selection to columns present in every split.
wanted_columns = [key_field, 'previous_text', 'gold_text', 'citations', 'oracle_documents_passages', 'oracle_documents_oracle_passages']
available_columns = [
    column
    for column in wanted_columns
    if all(column in split_columns for split_columns in current_chosen_dataset.column_names.values())
]
current_chosen_dataset = current_chosen_dataset.select_columns(available_columns)
# Use the fully qualified repo id, as every other Hub call in this notebook
# does, so the target does not depend on which account is logged in.
current_chosen_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}")

[*] adding oracle_documents_passages to ECHR_QA-generation-workshop


Map (num_proc=45):   0%|          | 0/116 [00:00<?, ? examples/s]

Map (num_proc=45):   0%|          | 0/1000 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/3 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/615 [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/datasets/ylkhayat/ECHR_QA-generation-workshop/commit/b7be662d71666090f46b06c362be4597fbb7306f', commit_message='Upload dataset', commit_description='', oid='b7be662d71666090f46b06c362be4597fbb7306f', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/ylkhayat/ECHR_QA-generation-workshop', endpoint='https://huggingface.co', repo_type='dataset', repo_id='ylkhayat/ECHR_QA-generation-workshop'), pr_revision=None, pr_num=None)

# Create the BM25 Oracle Setup

In [19]:
import bm25s
import Stemmer 
# English stemmer handed to `bm25s.tokenize` by the BM25 retrieval functions in this notebook.
stemmer = Stemmer.Stemmer("english")

In [13]:
# Maximum number of passages retrieved per entry; also part of the output column name.
top_k = 10
# data_dir = "bm25_noisy_oracle_passages_oracle_documents" only for echr_qa
# Sub-directory of the Hub repository that this setup is loaded from and pushed to.
data_dir = "bm25_oracle_passages_oracle_documents"

from datasets import DatasetDict


# Switched to True further down when the setup cannot be loaded from the Hub.
create = False
def retrieve_top_passages(entry):
    """Retrieve an entry's oracle passages with BM25, using its gold text as the query.

    Every passage in ``entry['oracle_documents_passages']`` is rendered as its
    first element, a newline, then its second element. The rendered passages
    form the BM25 corpus and ``entry['gold_text']`` is the query.

    Returns:
        A dict with the single key ``top_{top_k}_passages`` holding the
        rendered passages returned by the retriever (at most ``top_k``, fewer
        when the entry has fewer passages).
    """
    passage_texts = [
        f"{passage_arr[0]}\n{passage_arr[1]}"
        for passage_arr in entry['oracle_documents_passages']
    ]
    bm25 = bm25s.BM25()
    bm25.index(bm25s.tokenize(passage_texts, stopwords="en", stemmer=stemmer))
    # Never request more hits than there are passages in the corpus.
    hit_count = min(top_k, len(passage_texts))
    ranked, _ = bm25.retrieve(
        bm25s.tokenize(entry['gold_text'], stemmer=stemmer),
        corpus=passage_texts,
        k=hit_count,
    )
    # retrieve() answers a batch of queries; drop the leading single-query axis.
    return {f"top_{top_k}_passages": ranked.squeeze(0)}
# Try to reuse a previously pushed copy of this setup from the Hub.
try:
    new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
except Exception as load_error:
    # Catch Exception (not a bare except) so Ctrl-C still interrupts the cell,
    # and report which error made the stored copy unusable.
    print(f"[!] {workshop_hf_name} not found in {data_dir}")
    print(f"    ({type(load_error).__name__})")
    create = True

top_k_column = f"top_{top_k}_passages"
# The retrieved column only ever exists in the per-setup copy (`new_dataset`),
# never in `current_chosen_dataset`, so the cache check must inspect the copy;
# checking the base dataset made the `else` branch unreachable and forced a
# rebuild and re-upload on every run. `create` short-circuits the check when
# the load above failed and `new_dataset` is therefore undefined.
if create or any(top_k_column not in columns for columns in new_dataset.column_names.values()):
    print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
    new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
    new_dataset = new_dataset.map(retrieve_top_passages, num_proc=num_proc)
    new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
else:
    print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# Show the best passage of the first test entry as a quick sanity check.
print(json.dumps(new_dataset['test'][0][top_k_column][0], indent=4))

train-00000-of-00001.parquet:   0%|          | 0.00/15.5M [00:00<?, ?B/s]

test-00000-of-00002.parquet:   0%|          | 0.00/46.1M [00:00<?, ?B/s]

test-00001-of-00002.parquet:   0%|          | 0.00/48.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Failed to read file '/srv/elkhyo/.cache/huggingface/hub/datasets--ylkhayat--CUAD-generation-workshop/snapshots/f4e6f4a4d980ff57bd1a73ae039b502f31636985/bm25_oracle_passages_oracle_documents/train-00000-of-00001.parquet' with error <class 'datasets.table.CastError'>: Couldn't cast
id: string
title: string
context: string
question: string
answers: struct<text: list<element: string>, answer_start: list<element: int32>>
  child 0, text: list<element: string>
      child 0, element: string
  child 1, answer_start: list<element: int32>
      child 0, element: int32
docid: string
previous_text: string
gold_text: string
citations: list<element: list<element: string>>
  child 0, element: list<element: string>
      child 0, element: string
oracle_documents_passages: list<element: list<element: string>>
  child 0, element: list<element: string>
      child 0, element: string
top_10_passages: list<element: string>
  child 0, element: string
-- schema metadata --
huggingface: '{"info": {"features"

[!] CUAD-generation-workshop not found in bm25_oracle_passages_oracle_documents
[*] adding top_10_passages to CUAD-generation-workshop


Map (num_proc=45):   0%|          | 0/1000 [00:00<?, ? examples/s]

Map (num_proc=45):   0%|          | 0/4182 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/2 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

"REFERENCE\nConflicts between these two languages arising there from, if any, shall be subject to Chinese version. One copy for the Sellers, two copies for the Buyers. The Contract becomes effective after signed by both parties. THE BUYER: THE SELLER: SIGNATURE: SIGNATURE: 6\n\nSource: LOHA CO. LTD., F-1, 12/9/2019"


# Create the BM25 Setup

In [None]:
# Maximum number of passages retrieved per entry; also part of the output column name.
top_k = 10
# Sub-directory of the Hub repository that this setup is loaded from and pushed to.
data_dir = "bm25_relevant_passages_oracle_documents"

# Switched to True further down when the setup cannot be loaded from the Hub.
create = False
def retrieve_top_passages(entry):
    """Retrieve an entry's oracle passages with BM25, using its previous text as the query.

    Every passage in ``entry['oracle_documents_passages']`` is rendered as its
    first element, a newline, then its second element. The rendered passages
    form the BM25 corpus and ``entry['previous_text']`` is the query.

    Returns:
        A dict with the single key ``top_{top_k}_passages`` holding the
        rendered passages returned by the retriever (at most ``top_k``, fewer
        when the entry has fewer passages).
    """
    rendered = []
    for passage_arr in entry['oracle_documents_passages']:
        rendered.append(f"{passage_arr[0]}\n{passage_arr[1]}")

    bm25 = bm25s.BM25()
    bm25.index(bm25s.tokenize(rendered, stopwords="en", stemmer=stemmer))

    query_tokens = bm25s.tokenize(entry['previous_text'], stemmer=stemmer)
    # Cap the request at the corpus size.
    hit_count = min(top_k, len(rendered))
    ranked, _ = bm25.retrieve(query_tokens, corpus=rendered, k=hit_count)
    # retrieve() answers a batch of queries; drop the leading single-query axis.
    return {f"top_{top_k}_passages": ranked.squeeze(0)}

# Try to reuse a previously pushed copy of this setup from the Hub.
try:
    new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
except Exception as load_error:
    # Catch Exception (not a bare except) so Ctrl-C still interrupts the cell,
    # and report which error made the stored copy unusable.
    print(f"[!] {workshop_hf_name} not found in {data_dir}")
    print(f"    ({type(load_error).__name__})")
    create = True

top_k_column = f"top_{top_k}_passages"
# The retrieved column only ever exists in the per-setup copy (`new_dataset`),
# never in `current_chosen_dataset`, so the cache check must inspect the copy;
# checking the base dataset made the `else` branch unreachable and forced a
# rebuild and re-upload on every run. `create` short-circuits the check when
# the load above failed and `new_dataset` is therefore undefined.
if create or any(top_k_column not in columns for columns in new_dataset.column_names.values()):
    print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
    new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
    new_dataset = new_dataset.map(retrieve_top_passages, num_proc=num_proc)
    new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
else:
    print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# Show the best passage of the first test entry as a quick sanity check.
print(json.dumps(new_dataset['test'][0][top_k_column][0], indent=4))

train-00000-of-00001.parquet:   0%|          | 0.00/16.5M [00:00<?, ?B/s]

test-00000-of-00002.parquet:   0%|          | 0.00/46.3M [00:00<?, ?B/s]

test-00001-of-00002.parquet:   0%|          | 0.00/48.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Failed to read file '/srv/elkhyo/.cache/huggingface/hub/datasets--ylkhayat--CUAD-generation-workshop/snapshots/39c4486a7fb5e318f3d9097dcbf18fdd385794ff/bm25_relevant_passages_oracle_documents/train-00000-of-00001.parquet' with error <class 'datasets.table.CastError'>: Couldn't cast
docid: string
previous_text: string
gold_text: string
citations: list<element: list<element: string>>
  child 0, element: list<element: string>
      child 0, element: string
oracle_documents_passages: list<element: list<element: string>>
  child 0, element: list<element: string>
      child 0, element: string
top_10_passages: list<element: string>
  child 0, element: string
-- schema metadata --
huggingface: '{"info": {"features": {"docid": {"dtype": "string", "_type"' + 469
to
{'docid': Value(dtype='string', id=None), 'previous_text': Value(dtype='string', id=None), 'gold_text': Value(dtype='string', id=None), 'citations': Sequence(feature=Sequence(feature=Value(dtype='string', id=None), length=-1, id=None

[!] CUAD-generation-workshop not found in bm25_relevant_passages_oracle_documents
[*] adding top_10_passages to CUAD-generation-workshop


Map (num_proc=45):   0%|          | 0/1000 [00:00<?, ? examples/s]

Map (num_proc=45):   0%|          | 0/4182 [00:00<?, ? examples/s]

# Create the Dense Oracle Setup

# Create the Dense Setup

In [None]:
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset, DatasetDict
import torch
import json
import faiss
import numpy as np

# Maximum number of passages retrieved per entry; also part of the output column name.
top_k = 10

encoder_name = "jhu-clsp/LegalBERT-DPR-CLERC-ft"
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(encoder_name)
model = AutoModel.from_pretrained(encoder_name).to(device)
# model = torch.compile(model)

# Results live in a per-encoder sub-directory; the "/" of the model id is
# replaced so the id forms a single path component.
clean_encoder_name = "_".join(encoder_name.split("/"))
data_dir = "/".join(["dense_oracle_passages_oracle_documents", clean_encoder_name])

def normalize_embeddings(embeddings):
    """Scale every row of a 2-D tensor to unit L2 norm.

    Args:
        embeddings: Tensor whose rows (dimension 1) are the vectors to normalise.

    Returns:
        A tensor of the same shape with each row divided by its L2 norm. The
        norm is clamped to a tiny epsilon, so an all-zero row stays all-zero
        instead of turning into NaN through a division by zero.
    """
    return torch.nn.functional.normalize(embeddings, p=2, dim=1)


def embed_texts(texts, tokenizer, model, batch_size=32):
    """Embed texts with attention-masked mean pooling followed by L2 normalisation.

    The texts are encoded in mini-batches so memory use is bounded by
    ``batch_size`` rather than by the total number of texts.

    Args:
        texts: A string or a sequence of strings to embed.
        tokenizer: Callable tokenizer; invoked with ``return_tensors="pt"``,
            padding and truncation to ``model.config.max_position_embeddings``.
            Its result must support ``.to(device)`` and keyword unpacking.
        model: Encoder whose output exposes ``last_hidden_state``.
        batch_size: Number of texts per forward pass (must be positive).

    Returns:
        A ``(len(texts), hidden_dim)`` tensor of normalised embeddings, in
        input order, on the notebook-level ``device``.

    Raises:
        ValueError: If ``batch_size`` is not positive or ``texts`` is empty.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    # A bare string is one text, not a sequence of characters.
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        raise ValueError("texts must contain at least one string")

    pooled_batches = []
    for start in range(0, len(texts), batch_size):
        inputs = tokenizer(
            texts[start:start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=model.config.max_position_embeddings,
        ).to(device)
        attention_mask = inputs['attention_mask']
        with torch.no_grad():
            token_embeddings = model(**inputs).last_hidden_state
            # Broadcast the mask over the hidden dimension so tokens with a
            # zero mask contribute nothing to the sum.
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
            # Guard the division for rows whose mask is entirely zero.
            sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
            pooled_batches.append(sum_embeddings / sum_mask)
    return normalize_embeddings(torch.cat(pooled_batches, dim=0))


def retrieve_top_passages_batch(entries):
    """Dense retrieval of each entry's own oracle passages, using its gold text as the query.

    Args:
        entries: A batch in column format, as passed by ``map(batched=True)``;
            reads the ``gold_text`` and ``oracle_documents_passages`` columns.

    Returns:
        A dict with the single key ``top_{top_k}_passages`` holding one list
        per entry: the texts (second element) of that entry's passages ranked
        by inner product with the query embedding, at most ``top_k`` long.
        An entry without passages gets an empty list.
    """
    query_embeddings = embed_texts(entries['gold_text'], tokenizer, model).cpu().numpy()

    top_passages_batch = []
    # Rank every query against its OWN passages only. Pooling the passages of
    # the whole batch into one index lets a query retrieve passages that
    # belong to another entry's documents.
    for row, passages in enumerate(entries['oracle_documents_passages']):
        passage_texts = [passage[1] for passage in passages]
        if not passage_texts:
            top_passages_batch.append([])
            continue
        passage_embeddings = embed_texts(passage_texts, tokenizer, model).cpu().numpy()  # FAISS requires NumPy arrays
        index = faiss.IndexFlatIP(passage_embeddings.shape[1])
        index.add(passage_embeddings)
        # FAISS pads missing results with label -1, which would silently pick
        # the last passage when used as a Python index; cap k at the pool size.
        k = min(top_k, len(passage_texts))
        _, indices = index.search(query_embeddings[row:row + 1], k)
        top_passages_batch.append([passage_texts[passage_index] for passage_index in indices[0]])
    return {f"top_{top_k}_passages": top_passages_batch}

# Initialise the flag here so the cell does not depend on a value left behind
# by an earlier cell (or fail with NameError on a fresh kernel).
create = False
# Try to reuse a previously pushed copy of this setup from the Hub.
try:
    new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
except Exception as load_error:
    # Catch Exception (not a bare except) so Ctrl-C still interrupts the cell,
    # and report which error made the stored copy unusable.
    print(f"[!] {workshop_hf_name} not found in {data_dir}")
    print(f"    ({type(load_error).__name__})")
    create = True

top_k_column = f"top_{top_k}_passages"
# The retrieved column only ever exists in the per-setup copy (`new_dataset`),
# never in `current_chosen_dataset`, so the cache check must inspect the copy.
# `create` short-circuits the check when the load above failed and
# `new_dataset` is therefore undefined.
if create or any(top_k_column not in columns for columns in new_dataset.column_names.values()):
    print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
    new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
    new_dataset = new_dataset.map(
        retrieve_top_passages_batch,
        batched=True,
        batch_size=4
    )
    new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
else:
    print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# Show the best passage of the first train entry as a quick sanity check.
print(json.dumps(new_dataset['train'][0][top_k_column][0], indent=4))

In [None]:

def _print_first_train_example(setup_dataset, query_field):
    """Print a separator, then the query text and first retrieved passage of the first train row."""
    print("===================================")
    first_row = setup_dataset['train'][0]
    print(first_row[query_field])
    print(first_row['top_10_passages'][0])


# BM25 over oracle passages (queried with the gold text).
dataset_bm25_oracle_passages_oracle_documents = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir="bm25_oracle_passages_oracle_documents")
_print_first_train_example(dataset_bm25_oracle_passages_oracle_documents, 'gold_text')

# BM25 over oracle passages (queried with the previous text).
dataset_bm25_relevant_passages_oracle_documents = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir="bm25_relevant_passages_oracle_documents")
_print_first_train_example(dataset_bm25_relevant_passages_oracle_documents, 'previous_text')

# Dense retrieval over oracle passages (queried with the gold text).
dataset_dense_oracle_passages_oracle_documents = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir="dense_oracle_passages_oracle_documents")
_print_first_train_example(dataset_dense_oracle_passages_oracle_documents, 'gold_text')


In [None]:
# from transformers import AutoTokenizer, AutoModel
# import torch
# import faiss
# import numpy as np

# encoder_name = "jhu-clsp/LegalBERT-DPR-CLERC-ft"
# tokenizer = AutoTokenizer.from_pretrained(encoder_name)
# model = AutoModel.from_pretrained(encoder_name)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model.to(device)

# top_k = 10
# data_dir = "dense_oracle_passages_oracle_documents"

# def embed_passages(passages, tokenizer, model):
#     inputs = tokenizer(passages, return_tensors="pt", padding=True, truncation=True, max_length=model.config.max_position_embeddings).to(device)
#     with torch.no_grad():
#         embeddings = model(**inputs).pooler_output
#     return embeddings.cpu().numpy()

# def retrieve_top_passages(entry):
#     query = entry['gold_text']
#     all_passages = entry['oracle_documents_passages']
#     all_passages_text = [f"{passage_arr[0]}\n{passage_arr[1]}" for passage_arr in all_passages]
    
#     passage_embeddings = embed_passages(all_passages_text, tokenizer, model)
#     query_embedding = embed_passages([query], tokenizer, model).squeeze(0)

#     dim = passage_embeddings.shape[1]
#     index = faiss.IndexFlatL2(dim)
#     index.add(passage_embeddings)
    
#     # Retrieve top-k similar passages
#     distances, indices = index.search(np.expand_dims(query_embedding, axis=0), top_k)
#     top_passages = [all_passages_text[idx] for idx in indices[0]]
    
#     return {f"top_{top_k}_passages": top_passages}

# try:
#     workshop_hf_name = f"CLERC-generation-workshop"
#     new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
# except:
#     print(f"[!] {workshop_hf_name} not found in {data_dir}")
#     create = True
    
# if create or f"top_{top_k}_passages" not in current_chosen_dataset.column_names['train'] or f"top_{top_k}_passages" not in current_chosen_dataset.column_names['test']:
#     print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
#     new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
#     new_dataset = new_dataset.map(retrieve_top_passages, num_proc=num_proc)
#     new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
# else:
#     print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# print(json.dumps(new_dataset['train'][0][f"top_{top_k}_passages"][0], indent=4)) 