In [1]:
from datasets import load_dataset, DatasetDict
import numpy as np
from tqdm import tqdm
import os
import json
from sentence_transformers import SentenceTransformer, util
import torch
import argparse
import bm25s

In [2]:
chunk_size = 300
chunk_overlap = 0.5

# Leave a few cores free, but never go below one worker: os.cpu_count() may return
# None, and Dataset.map expects a positive num_proc.
num_proc = max(1, (os.cpu_count() or 1) - 3)

In [7]:
dataset = "echr_qa"

key_field = "docid"
# Supported dataset keys -> name of the corresponding workshop dataset on the Hub.
workshop_names = {
    "clerc": "CLERC-generation-workshop",
    "echr": "ECHR-generation-workshop",
    "echr_qa": "ECHR_QA-generation-workshop",
}
if dataset not in workshop_names:
    raise ValueError("Invalid dataset")
if dataset == "clerc":
    # CLERC additionally loads its upstream generation splits as `original_dataset`.
    original_dataset = load_dataset(
        "jhu-clsp/CLERC",
        data_files={"train": "generation/train.jsonl", "test": "generation/test.jsonl"},
    )
workshop_hf_name = workshop_names[dataset]
current_chosen_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}")

Downloading readme:   0%|          | 0.00/615 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/27.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/200M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/116 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [8]:
# NOTE(review): to_dict() materializes the whole test split in memory just to peek at one record.
test_split = current_chosen_dataset['test']
resources = test_split.to_dict()
# Citations of the first test record (each element appears to be an [id, text] pair).
resources['citations'][0]

[['001-113118',
 ['001-61886',
 ['001-155662',
 ['001-114082',

In [10]:
# def chunk_document_into_passages_overlap_split(document_id: str, text: str, max_len: int, overlap: float):
#     if max_len <= 0:
#         raise ValueError("max_len must be a positive integer")
#     if not (0 <= overlap < 1):
#         raise ValueError("overlap must be between 0 (inclusive) and 1 (exclusive)")

#     words = text.split()
#     chunks = []
#     step = int(max_len * (1 - overlap))
#     last_index = 0
#     for index, i in enumerate(range(0, len(words), step)):
#         current_splits = [document_id]
#         if i == 0:
#             second_split = words[i:i+max_len]
#             last_index = i+max_len
#             splits = ["", " ".join(second_split)]
#         else:
#             first_split = words[last_index-step:last_index]
#             second_split = words[last_index:last_index+step]
#             last_index = last_index+step
#             splits = [" ".join(first_split), " ".join(second_split)]
#         if splits:
#             current_splits.extend(splits)
#             chunks.append(current_splits)
#     return chunks

# def chunk_citations_overlap_split(record):
#     chunks = []
#     for citation_id, citation in record['citations']:
#         chunks.extend(chunk_document_into_passages_overlap_split(citation_id, citation, chunk_size, chunk_overlap))
#     return chunks

# if 'oracle_documents_passages_overlap_split' not in current_chosen_dataset.column_names['train'] or 'oracle_documents_passages_overlap_split' not in current_chosen_dataset.column_names['test']:
#     print(f"[*] adding oracle_documents_passages_overlap_split to {workshop_hf_name}")
#     if "citations" not in current_chosen_dataset.column_names['train'] or "citations" not in current_chosen_dataset.column_names['test']:
#         current_chosen_dataset = DatasetDict({
#             'train': current_chosen_dataset['train'].add_column(name="citations", column=original_dataset['train']['citations']),
#             'test': current_chosen_dataset['test'].add_column(name="citations", column=original_dataset['test']['citations'])
#         })
#     current_chosen_dataset = current_chosen_dataset.map(lambda record: {'oracle_documents_passages_overlap_split': chunk_citations_overlap_split(record)})
#     current_chosen_dataset = current_chosen_dataset.select_columns(['docid', 'previous_text', 'gold_text', 'citations' ,'oracle_documents_passages', 'oracle_documents_passages_overlap_split'])
#     current_chosen_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}")
# else:
#     print(f"[!] oracle_documents_passages_overlap_split already exists in {workshop_hf_name}")

In [9]:
def chunk_document_into_passages(document_id: str, text: str, max_len: int, overlap: float):
    """Split a document into overlapping windows of whitespace-separated words.

    Args:
        document_id: Identifier copied into every produced passage.
        text: Document text; split on whitespace into words.
        max_len: Maximum number of words per passage (must be > 0).
        overlap: Fraction of ``max_len`` shared by consecutive passages, in [0, 1).

    Returns:
        A list of ``[document_id, passage_text]`` pairs in document order. Empty or
        whitespace-only ``text`` yields an empty list.

    Raises:
        ValueError: If ``max_len`` is not positive or ``overlap`` is outside [0, 1).
    """
    if max_len <= 0:
        raise ValueError("max_len must be a positive integer")
    if not (0 <= overlap < 1):
        raise ValueError("overlap must be between 0 (inclusive) and 1 (exclusive)")

    words = text.split()
    # Window stride. Clamp to 1 so a small max_len with a large overlap
    # (e.g. max_len=1, overlap=0.5 -> int(0.5) == 0) cannot give range() a zero step.
    step = max(1, int(max_len * (1 - overlap)))
    chunks = []
    for start in range(0, len(words), step):
        end = start + max_len
        chunks.append([document_id, ' '.join(words[start:end])])
        # Once a window reaches the end of the document, every later window would be a
        # strict suffix of this one, i.e. a duplicated passage, so stop here.
        if end >= len(words):
            break
    return chunks

def chunk_citations(record):
    """Chunk every cited document of a record into overlapping passages.

    Each element of ``record['citations']`` is unpacked as ``(citation_id, citation_text)``.
    Passages use the module-level ``chunk_size`` / ``chunk_overlap`` settings and are
    returned as one flat list of ``[citation_id, passage_text]`` pairs, in citation order.
    """
    return [
        passage
        for citation_id, citation_text in record['citations']
        for passage in chunk_document_into_passages(citation_id, citation_text, chunk_size, chunk_overlap)
    ]

passages_column = 'oracle_documents_passages'
# Chunk the citations only if some split is still missing the passages column.
if any(passages_column not in split_columns for split_columns in current_chosen_dataset.column_names.values()):
    print(f"[*] adding oracle_documents_passages to {workshop_hf_name}")
    current_chosen_dataset = current_chosen_dataset.map(lambda record: {'oracle_documents_passages': chunk_citations(record)}, num_proc=num_proc)
else:
    print(f"[!] oracle_documents_passages already exists in {workshop_hf_name}")

# 'oracle_documents_oracle_passages' is not produced by this notebook. Keep it only when every
# split already has it; otherwise select_columns would fail on a fresh dataset.
columns_in_every_split = set.intersection(*(set(names) for names in current_chosen_dataset.column_names.values()))
columns_to_keep = [key_field, 'previous_text', 'gold_text', 'citations', passages_column]
columns_to_keep += [name for name in ['oracle_documents_oracle_passages'] if name in columns_in_every_split]
current_chosen_dataset = current_chosen_dataset.select_columns(columns_to_keep)
# Push to the same namespaced repo the notebook loads from, not the caller's default namespace.
current_chosen_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}")

[*] adding oracle_documents_passages to ECHR_QA-generation-workshop


Map (num_proc=45):   0%|          | 0/116 [00:00<?, ? examples/s]

Map (num_proc=45):   0%|          | 0/1000 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/3 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/615 [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/datasets/ylkhayat/ECHR_QA-generation-workshop/commit/b7be662d71666090f46b06c362be4597fbb7306f', commit_message='Upload dataset', commit_description='', oid='b7be662d71666090f46b06c362be4597fbb7306f', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/ylkhayat/ECHR_QA-generation-workshop', endpoint='https://huggingface.co', repo_type='dataset', repo_id='ylkhayat/ECHR_QA-generation-workshop'), pr_revision=None, pr_num=None)

# Create the BM25 Oracle Setup

In [10]:
import bm25s
import Stemmer

# A single English stemmer shared by the BM25 tokenization of passages and queries below.
stemmer = Stemmer.Stemmer("english")

In [11]:
from datasets import DatasetDict

# BM25 "oracle" setup: passages are ranked against `gold_text`; results are pushed under data_dir.
top_k = 10
data_dir = "bm25_oracle_passages_oracle_documents"
create = False  # set to True below when data_dir cannot be loaded from the Hub
def retrieve_top_passages(entry, query_field="gold_text"):
    """Rank one record's oracle-document passages with BM25 and keep up to ``top_k``.

    A fresh BM25 index is built per record over its own ``oracle_documents_passages``.
    Each passage is rendered as its document id and text joined by a newline. The index
    is queried with ``entry[query_field]`` (``gold_text`` by default, i.e. the oracle setup).

    Returns:
        ``{f"top_{top_k}_passages": [...]}`` holding up to ``top_k`` passage strings, in the
        order returned by bm25s. A record with no passages gets an empty list.
    """
    query = entry[query_field]
    all_passages_text = [f"{passage_arr[0]}\n{passage_arr[1]}" for passage_arr in entry['oracle_documents_passages']]
    if not all_passages_text:
        # Nothing to rank for this record.
        return {f"top_{top_k}_passages": []}
    corpus_tokens = bm25s.tokenize(all_passages_text, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)
    query_tokens = bm25s.tokenize(query, stemmer=stemmer)
    # bm25s rejects a k larger than the corpus, so cap it for records with few passages.
    k = min(top_k, len(all_passages_text))
    results, _ = retriever.retrieve(query_tokens, corpus=all_passages_text, k=k)
    # One row per query; there is exactly one query here.
    return {f"top_{top_k}_passages": results[0].tolist()}
passages_column_name = f"top_{top_k}_passages"
try:
    new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
except Exception:  # e.g. data_dir has not been pushed to the Hub yet
    print(f"[!] {workshop_hf_name} not found in {data_dir}")
    create = True

# Recompute only when nothing was loaded or the loaded copy lacks the column in some split.
# (The previous check looked at current_chosen_dataset, which never has this column.)
if create or any(passages_column_name not in split_columns for split_columns in new_dataset.column_names.values()):
    print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
    new_dataset = current_chosen_dataset.map(retrieve_top_passages, num_proc=num_proc)
    new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
else:
    print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
print(json.dumps(new_dataset['train'][0][passages_column_name][0], indent=4))

Downloading readme:   0%|          | 0.00/672 [00:00<?, ?B/s]

Using the latest cached version of the dataset since ylkhayat/ECHR_QA-generation-workshop couldn't be found on the Hugging Face Hub


[!] ECHR_QA-generation-workshop not found in bm25_oracle_passages_oracle_documents
[*] adding top_10_passages to ECHR_QA-generation-workshop


Map (num_proc=45):   0%|          | 0/116 [00:00<?, ? examples/s]

Map (num_proc=45):   0%|          | 0/1000 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/3 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/672 [00:00<?, ?B/s]

"001-86233\nit will respect the legislature \u2019 s judgment as to what is in the general interest unless that judgment is manifestly without reasonable foundation ... It may be noted however that this was in the context of Article 1 of Protocol No. 1, not Article 8 which concerns rights of central importance to the individual \u2019 s identity, self-determination, physical and moral integrity, maintenance of relationships with others and a settled and secure place in the community ... Where general social and economic policy considerations have arisen in the context of Article 8 itself, the scope of the margin of appreciation depends on the context of the case, with particular significance attaching to the extent of the intrusion into the personal sphere of the applicant .... 83. The procedural safeguards available to the individual will be especially material in determining whether the respondent State has, when fixing the regulatory framework, remained within its margin of apprecia

# Create the BM25 Setup

In [12]:
# BM25 "relevant" setup: passages are ranked against `previous_text`; results are pushed under data_dir.
top_k, data_dir = 10, "bm25_relevant_passages_oracle_documents"
create = False
def retrieve_top_passages(entry, query_field="previous_text"):
    """Rank one record's oracle-document passages with BM25 and keep up to ``top_k``.

    A fresh BM25 index is built per record over its own ``oracle_documents_passages``.
    Each passage is rendered as its document id and text joined by a newline. The index
    is queried with ``entry[query_field]`` (``previous_text`` by default).

    Returns:
        ``{f"top_{top_k}_passages": [...]}`` holding up to ``top_k`` passage strings, in the
        order returned by bm25s. A record with no passages gets an empty list.
    """
    query = entry[query_field]
    all_passages_text = [f"{passage_arr[0]}\n{passage_arr[1]}" for passage_arr in entry['oracle_documents_passages']]
    if not all_passages_text:
        # Nothing to rank for this record.
        return {f"top_{top_k}_passages": []}
    corpus_tokens = bm25s.tokenize(all_passages_text, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)
    query_tokens = bm25s.tokenize(query, stemmer=stemmer)
    # bm25s rejects a k larger than the corpus, so cap it for records with few passages.
    k = min(top_k, len(all_passages_text))
    results, _ = retriever.retrieve(query_tokens, corpus=all_passages_text, k=k)
    # One row per query; there is exactly one query here.
    return {f"top_{top_k}_passages": results[0].tolist()}

passages_column_name = f"top_{top_k}_passages"
try:
    new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
except Exception:  # e.g. data_dir has not been pushed to the Hub yet
    print(f"[!] {workshop_hf_name} not found in {data_dir}")
    create = True

# Recompute only when nothing was loaded or the loaded copy lacks the column in some split.
# (The previous check looked at current_chosen_dataset, which never has this column.)
if create or any(passages_column_name not in split_columns for split_columns in new_dataset.column_names.values()):
    print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
    new_dataset = current_chosen_dataset.map(retrieve_top_passages, num_proc=num_proc)
    new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
else:
    print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
print(json.dumps(new_dataset['train'][0][passages_column_name][0], indent=4))

Downloading readme:   0%|          | 0.00/785 [00:00<?, ?B/s]

Using the latest cached version of the dataset since ylkhayat/ECHR_QA-generation-workshop couldn't be found on the Hugging Face Hub


[!] ECHR_QA-generation-workshop not found in bm25_relevant_passages_oracle_documents
[*] adding top_10_passages to ECHR_QA-generation-workshop


Map (num_proc=45):   0%|          | 0/116 [00:00<?, ? examples/s]

Map (num_proc=45):   0%|          | 0/1000 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/3 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/785 [00:00<?, ?B/s]

"001-216764\nenjoyment of fundamental or \u201cintimate\u201d rights. This is the case in particular for Article 8 rights, which are rights of central importance to the individual\u2019s identity, self-determination, physical and moral integrity, maintenance of relationships with others and a settled and secure place in the community. (iii) The procedural safeguards available to the individual will be especially material in determining whether the respondent State has remained within its margin of appreciation. In particular, the Court must examine whether the decision-making process leading to measures of interference was fair and such as to afford due respect to the interests safeguarded to the individual by Article 8. The \u201cnecessary in a democratic society\u201d requirement under Article 8 \u00a7 2 raises a question of procedure as well of substance. (iv) Since the loss of one\u2019s home is a most extreme form of interference with the right under Article 8 to respect for one\u

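# Create the Dense Oracle Setup

Dense retrieval over each record's oracle-document passages, using the gold text (`gold_text`) as the query.

# Create the Dense Setup

**TODO:** the non-oracle dense setup (querying with `previous_text`) is not implemented yet. The code cell below is the dense *oracle* setup described in the previous section.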

In [None]:
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset, DatasetDict
import torch
import json
import faiss
import numpy as np

top_k = 10

# Dense encoder, placed on the first CUDA device when one is available.
encoder_name = "jhu-clsp/LegalBERT-DPR-CLERC-ft"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(encoder_name)
model = AutoModel.from_pretrained(encoder_name)
# model = torch.compile(model)
model.to(device)

# Results are pushed under a per-encoder sub-directory ("/" is not kept in the directory name).
clean_encoder_name = encoder_name.replace("/", "_")
data_dir = f"dense_oracle_passages_oracle_documents/{clean_encoder_name}"

def normalize_embeddings(embeddings):
    """L2-normalize ``embeddings`` along dim 1 (each row of an ``(n, dim)`` matrix).

    ``torch.nn.functional.normalize`` clamps the norm to a small epsilon, so an all-zero
    row stays zero instead of turning into NaN (0 / 0).
    """
    return torch.nn.functional.normalize(embeddings, p=2, dim=1)


def embed_texts(texts, tokenizer, model, batch_size=64):
    """Encode texts by mean-pooling token embeddings over non-padding tokens, then L2-normalize.

    Args:
        texts: A string or a list of strings.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Encoder whose output exposes ``last_hidden_state``.
        batch_size: Texts per forward pass, bounding peak memory when a record has many
            passages. ``None`` encodes everything in a single pass.

    Returns:
        Tensor with one unit-norm row per text, on the module-level ``device``.
    """
    if isinstance(texts, str):
        texts = [texts]
    if len(texts) == 0:
        return torch.empty((0, model.config.hidden_size), device=device)
    # Never truncate beyond what either the model's position table or the tokenizer supports.
    max_length = min(model.config.max_position_embeddings, tokenizer.model_max_length)
    step = batch_size or len(texts)
    pooled_batches = []
    for start in range(0, len(texts), step):
        inputs = tokenizer(
            texts[start:start + step],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(device)
        with torch.no_grad():
            token_embeddings = model(**inputs).last_hidden_state
        # (batch, seq, 1) mask so padding tokens contribute nothing to the mean.
        mask = inputs['attention_mask'].unsqueeze(-1).float()
        summed = torch.sum(token_embeddings * mask, dim=1)
        token_counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        pooled_batches.append(summed / token_counts)
    return normalize_embeddings(torch.cat(pooled_batches, dim=0))


def retrieve_top_passages_batch(entries):
    """For each entry of a batch, rank that entry's own oracle passages against its gold text.

    Passage texts are embedded with ``embed_texts``. Because the embeddings are unit-normalized,
    the inner-product FAISS index scores by cosine similarity.

    Returns:
        ``{f"top_{top_k}_passages": [...]}`` with one list per entry, holding up to ``top_k``
        passages formatted as document id and passage text joined by a newline (the same
        format as the BM25 setups). Entries without passages get an empty list.
    """
    query_embeddings = embed_texts(entries['gold_text'], tokenizer, model).cpu().numpy()
    top_passages_batch = []
    # Index each entry separately so a query can only retrieve passages from its own
    # oracle documents, never from another entry in the same batch.
    for query_embedding, passages in zip(query_embeddings, entries['oracle_documents_passages']):
        if not passages:
            top_passages_batch.append([])
            continue
        passage_embeddings = embed_texts([passage[1] for passage in passages], tokenizer, model).cpu().numpy()
        index = faiss.IndexFlatIP(passage_embeddings.shape[1])
        index.add(passage_embeddings)  # FAISS requires NumPy arrays here
        # Asking for more neighbours than exist makes FAISS pad with -1 ids, which would
        # silently select the last passage, so cap k at the number of passages.
        k = min(top_k, len(passages))
        _, neighbour_ids = index.search(query_embedding[None, :], k)
        top_passages_batch.append([f"{passages[i][0]}\n{passages[i][1]}" for i in neighbour_ids[0]])
    return {f"top_{top_k}_passages": top_passages_batch}

passages_column_name = f"top_{top_k}_passages"
# Reset explicitly: otherwise a True left over from an earlier cell forces a recompute.
create = False
try:
    new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
except Exception:  # e.g. data_dir has not been pushed to the Hub yet
    print(f"[!] {workshop_hf_name} not found in {data_dir}")
    create = True

# Recompute only when nothing was loaded or the loaded copy lacks the column in some split.
if create or any(passages_column_name not in split_columns for split_columns in new_dataset.column_names.values()):
    print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
    # No num_proc here: the encoder lives in this process (on `device`).
    new_dataset = current_chosen_dataset.map(
        retrieve_top_passages_batch,
        batched=True,
        batch_size=4,
    )
    new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
else:
    print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
print(json.dumps(new_dataset['train'][0][passages_column_name][0], indent=4))

In [None]:

def show_first_retrieval(setup_data_dir, query_field, passages_column="top_10_passages"):
    """Load one retrieval setup from the Hub and print its first training query and top passage.

    Returns the loaded DatasetDict so it can be inspected further.
    """
    setup_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=setup_data_dir)
    first_record = setup_dataset['train'][0]
    print("===================================")
    print(first_record[query_field])
    print(first_record[passages_column][0])
    return setup_dataset


dataset_bm25_oracle_passages_oracle_documents = show_first_retrieval("bm25_oracle_passages_oracle_documents", "gold_text")
dataset_bm25_relevant_passages_oracle_documents = show_first_retrieval("bm25_relevant_passages_oracle_documents", "previous_text")
# The dense setup is pushed under a per-encoder sub-directory, so load exactly that one.
dataset_dense_oracle_passages_oracle_documents = show_first_retrieval(
    f"dense_oracle_passages_oracle_documents/{clean_encoder_name}", "gold_text"
)


In [None]:
# from transformers import AutoTokenizer, AutoModel
# import torch
# import faiss
# import numpy as np

# encoder_name = "jhu-clsp/LegalBERT-DPR-CLERC-ft"
# tokenizer = AutoTokenizer.from_pretrained(encoder_name)
# model = AutoModel.from_pretrained(encoder_name)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model.to(device)

# top_k = 10
# data_dir = "dense_oracle_passages_oracle_documents"

# def embed_passages(passages, tokenizer, model):
#     inputs = tokenizer(passages, return_tensors="pt", padding=True, truncation=True, max_length=model.config.max_position_embeddings).to(device)
#     with torch.no_grad():
#         embeddings = model(**inputs).pooler_output
#     return embeddings.cpu().numpy()

# def retrieve_top_passages(entry):
#     query = entry['gold_text']
#     all_passages = entry['oracle_documents_passages']
#     all_passages_text = [f"{passage_arr[0]}\n{passage_arr[1]}" for passage_arr in all_passages]
    
#     passage_embeddings = embed_passages(all_passages_text, tokenizer, model)
#     query_embedding = embed_passages([query], tokenizer, model).squeeze(0)

#     dim = passage_embeddings.shape[1]
#     index = faiss.IndexFlatL2(dim)
#     index.add(passage_embeddings)
    
#     # Retrieve top-k similar passages
#     distances, indices = index.search(np.expand_dims(query_embedding, axis=0), top_k)
#     top_passages = [all_passages_text[idx] for idx in indices[0]]
    
#     return {f"top_{top_k}_passages": top_passages}

# try:
#     workshop_hf_name = f"CLERC-generation-workshop"
#     new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
# except:
#     print(f"[!] {workshop_hf_name} not found in {data_dir}")
#     create = True
    
# if create or f"top_{top_k}_passages" not in current_chosen_dataset.column_names['train'] or f"top_{top_k}_passages" not in current_chosen_dataset.column_names['test']:
#     print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
#     new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
#     new_dataset = new_dataset.map(retrieve_top_passages, num_proc=num_proc)
#     new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
# else:
#     print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# print(json.dumps(new_dataset['train'][0][f"top_{top_k}_passages"][0], indent=4)) 