In [1]:
from datasets import load_dataset, DatasetDict
import numpy as np
from tqdm import tqdm
import os
import json
from sentence_transformers import SentenceTransformer, util
import torch
import argparse
import bm25s

In [2]:
# Sliding-window chunking of cited documents: window length in words, and the
# fraction of each window shared with the next one.
chunk_size = 300
chunk_overlap = 0.5

# Worker processes for `Dataset.map`: leave 3 cores free, but never go below 1.
# os.cpu_count() may return None, and on machines with <= 3 cores the plain
# subtraction would give 0 or a negative process count.
num_proc = max(1, (os.cpu_count() or 1) - 3)

In [5]:
dataset = "echr_qa"

key_field = "docid"
# Hub repository (under the `ylkhayat` namespace) holding the workshop version of each corpus.
workshop_hf_names = {
    "clerc": "CLERC-generation-workshop",
    "echr": "ECHR-generation-workshop",
    "echr_qa": "ECHR_QA-generation-workshop",
}
if dataset not in workshop_hf_names:
    raise ValueError("Invalid dataset")
if dataset == "clerc":
    # Only CLERC also loads its original generation splits from the source repository.
    original_dataset = load_dataset("jhu-clsp/CLERC", data_files={"train": "generation/train.jsonl", "test": "generation/test.jsonl"})
workshop_hf_name = workshop_hf_names[dataset]
current_chosen_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}")

Downloading readme:   0%|          | 0.00/494 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/211M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/1116 [00:00<?, ? examples/s]

In [6]:
# Peek at the first test example's citations. Slicing with `[:1]` materialises a single row
# as a {column: [value]} dict, so `resources['citations'][0]` is the same first-example value
# as before, without `.to_dict()` copying the whole split into Python objects.
resources = current_chosen_dataset['test'][:1]

resources['citations'][0]

[['001-45580',
  '. The following is an outline of the case, as submitted to the European Commission of Human Rights, and of the procedure before the Commission. A. The Application\n. The applicants are Irish citizens, members of the same family, and reside in Belfast, Northern Ireland. The first applicant was born in 1938 and she is a housewife. The second applicant, born in 1935, is her husband. The third applicant, born in 1964, is her son. The fourth and fifth applicants are her eldest twin daughters, born in 1967. The sixth applicant is her youngest daughter, born in 1970. The applicants were represented before the Commission by Messrs. Madden and Finucane, Solicitors, Belfast.\n. The application is directed against the United Kingdom. The respondent Government were represented by their Agents, Mrs. A. Glover and Mr. H. Llewellyn, both of the Foreign and Commonwealth Office.\n. The case concerns the entry into the applicants\' home by an army team early one morning in 1982, the su

In [10]:
# def chunk_document_into_passages_overlap_split(document_id: str, text: str, max_len: int, overlap: float):
#     if max_len <= 0:
#         raise ValueError("max_len must be a positive integer")
#     if not (0 <= overlap < 1):
#         raise ValueError("overlap must be between 0 (inclusive) and 1 (exclusive)")

#     words = text.split()
#     chunks = []
#     step = int(max_len * (1 - overlap))
#     last_index = 0
#     for index, i in enumerate(range(0, len(words), step)):
#         current_splits = [document_id]
#         if i == 0:
#             second_split = words[i:i+max_len]
#             last_index = i+max_len
#             splits = ["", " ".join(second_split)]
#         else:
#             first_split = words[last_index-step:last_index]
#             second_split = words[last_index:last_index+step]
#             last_index = last_index+step
#             splits = [" ".join(first_split), " ".join(second_split)]
#         if splits:
#             current_splits.extend(splits)
#             chunks.append(current_splits)
#     return chunks

# def chunk_citations_overlap_split(record):
#     chunks = []
#     for citation_id, citation in record['citations']:
#         chunks.extend(chunk_document_into_passages_overlap_split(citation_id, citation, chunk_size, chunk_overlap))
#     return chunks

# if 'oracle_documents_passages_overlap_split' not in current_chosen_dataset.column_names['train'] or 'oracle_documents_passages_overlap_split' not in current_chosen_dataset.column_names['test']:
#     print(f"[*] adding oracle_documents_passages_overlap_split to {workshop_hf_name}")
#     if "citations" not in current_chosen_dataset.column_names['train'] or "citations" not in current_chosen_dataset.column_names['test']:
#         current_chosen_dataset = DatasetDict({
#             'train': current_chosen_dataset['train'].add_column(name="citations", column=original_dataset['train']['citations']),
#             'test': current_chosen_dataset['test'].add_column(name="citations", column=original_dataset['test']['citations'])
#         })
#     current_chosen_dataset = current_chosen_dataset.map(lambda record: {'oracle_documents_passages_overlap_split': chunk_citations_overlap_split(record)})
#     current_chosen_dataset = current_chosen_dataset.select_columns(['docid', 'previous_text', 'gold_text', 'citations' ,'oracle_documents_passages', 'oracle_documents_passages_overlap_split'])
#     current_chosen_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}")
# else:
#     print(f"[!] oracle_documents_passages_overlap_split already exists in {workshop_hf_name}")

In [11]:
def chunk_document_into_passages(document_id: str, text: str, max_len: int, overlap: float):
    """Split ``text`` into overlapping windows of whitespace-separated words.

    Args:
        document_id: identifier stored alongside every passage.
        text: document text; split on whitespace.
        max_len: maximum number of words per passage (must be positive).
        overlap: fraction of ``max_len`` shared by consecutive passages, in [0, 1).

    Returns:
        A list of ``[document_id, passage_text]`` pairs in document order. Empty text
        yields an empty list. The last passage always ends at the last word.

    Raises:
        ValueError: if ``max_len`` or ``overlap`` is out of range.
    """
    if max_len <= 0:
        raise ValueError("max_len must be a positive integer")
    if not (0 <= overlap < 1):
        raise ValueError("overlap must be between 0 (inclusive) and 1 (exclusive)")

    words = text.split()
    # Clamp the stride to >= 1: e.g. max_len=1, overlap=0.5 would otherwise give step 0,
    # which makes range() raise.
    step = max(1, int(max_len * (1 - overlap)))
    chunks = []
    for start in range(0, len(words), step):
        chunks.append([document_id, ' '.join(words[start:start + max_len])])
        # Once a window reaches the end of the document, every later window would be a
        # strict suffix of it; stop instead of emitting redundant passages.
        if start + max_len >= len(words):
            break
    return chunks

def chunk_citations(record):
    """Chunk every cited document of ``record`` into overlapping passages.

    ``record['citations']`` is iterated as ``(citation_id, citation_text)`` pairs. Each
    text is split by ``chunk_document_into_passages`` with the module-level
    ``chunk_size`` / ``chunk_overlap``. The result is one flat list of
    ``[citation_id, passage_text]`` pairs, keeping the citations' original order.
    """
    return [
        passage
        for citation_id, citation_text in record['citations']
        for passage in chunk_document_into_passages(citation_id, citation_text, chunk_size, chunk_overlap)
    ]

# Inspect every split that actually exists instead of hard-coding `train`/`test`:
# the ECHR_QA load above only generated a `test` split.
if any('oracle_documents_passages' not in split_columns for split_columns in current_chosen_dataset.column_names.values()):
    print(f"[*] adding oracle_documents_passages to {workshop_hf_name}")
    current_chosen_dataset = current_chosen_dataset.map(lambda record: {'oracle_documents_passages': chunk_citations(record)}, num_proc=num_proc)
else:
    print(f"[!] oracle_documents_passages already exists in {workshop_hf_name}")

required_columns = [key_field, 'previous_text', 'gold_text', 'citations', 'oracle_documents_passages']
# `oracle_documents_oracle_passages` is not produced by this notebook. Keep it only when
# every split already has it; otherwise select_columns would raise.
optional_columns = [
    column
    for column in ['oracle_documents_oracle_passages']
    if all(column in split_columns for split_columns in current_chosen_dataset.column_names.values())
]
current_chosen_dataset = current_chosen_dataset.select_columns(required_columns + optional_columns)
# Push to the same explicit namespace that every other cell reads from and writes to.
current_chosen_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}")

[*] adding oracle_documents_passages to ECHR-generation-workshop


Map (num_proc=45):   0%|          | 0/5000 [00:00<?, ? examples/s]

Map (num_proc=45):   0%|          | 0/1000 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/6 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/2 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/521 [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/datasets/ylkhayat/ECHR-generation-workshop/commit/a3b8b935570af329bda488250b54272ff082a76d', commit_message='Upload dataset', commit_description='', oid='a3b8b935570af329bda488250b54272ff082a76d', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/ylkhayat/ECHR-generation-workshop', endpoint='https://huggingface.co', repo_type='dataset', repo_id='ylkhayat/ECHR-generation-workshop'), pr_revision=None, pr_num=None)

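# Create the BM25 Oracle Setup

For every example, BM25 ranks the passages of that example's own oracle documents (`oracle_documents_passages`). It uses the example's `gold_text` as the query and keeps the top `top_k` passages.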

In [12]:
import bm25s
import Stemmer 
stemmer_language = "english"
# Stemmer handed to every `bm25s.tokenize` call in the BM25 cells below.
stemmer = Stemmer.Stemmer(stemmer_language)

In [13]:
from datasets import DatasetDict

# Retrieval depth and the Hub sub-directory where this setup is stored.
top_k, data_dir = 10, "bm25_oracle_passages_oracle_documents"
create = False  # flipped to True below when the stored setup cannot be loaded
def retrieve_top_passages(entry):
    """Rank one example's oracle-document passages against its gold text with BM25.

    Builds a fresh BM25 index over ``entry['oracle_documents_passages']``. Each passage is
    indexed as its document id, a newline, then its text. The index is queried with
    ``entry['gold_text']``.

    Returns:
        ``{"top_<top_k>_passages": ...}`` holding at most ``top_k`` passage strings, best
        first. Examples with fewer than ``top_k`` passages return all of them ranked;
        examples without passages return an empty list.
    """
    query = entry['gold_text']
    all_passages = entry['oracle_documents_passages']
    all_passages_text = [f"{passage_arr[0]}\n{passage_arr[1]}" for passage_arr in all_passages]
    # bm25s rejects k larger than the number of indexed passages, so clamp it;
    # an empty corpus cannot be indexed at all.
    k = min(top_k, len(all_passages_text))
    if k == 0:
        return {f"top_{top_k}_passages": []}
    corpus_tokens = bm25s.tokenize(all_passages_text, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)
    query_tokens = bm25s.tokenize(query, stemmer=stemmer)
    results, _ = retriever.retrieve(query_tokens, corpus=all_passages_text, k=k)
    # retrieve() returns one row per query; drop the single-query axis.
    return {f"top_{top_k}_passages": results.squeeze(0)}
top_passages_column = f"top_{top_k}_passages"
try:
    new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
except Exception:  # not a bare except: KeyboardInterrupt / SystemExit still propagate
    print(f"[!] {workshop_hf_name} not found in {data_dir}")
    create = True

# Rebuild when nothing could be loaded or when a loaded split lacks the retrieval column.
# The column only ever exists in the dataset stored under `data_dir` (new_dataset);
# current_chosen_dataset never contains it, so checking that one always forced a rebuild.
if create or any(top_passages_column not in split_columns for split_columns in new_dataset.column_names.values()):
    print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
    new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
    new_dataset = new_dataset.map(retrieve_top_passages, num_proc=num_proc)
    new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
else:
    print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# Preview from `train` when it exists, otherwise from the first available split
# (the ECHR_QA load above only generated a `test` split).
preview_split = new_dataset['train'] if 'train' in new_dataset else next(iter(new_dataset.values()))
print(json.dumps(preview_split[0][top_passages_column][0], indent=4))

Downloading readme:   0%|          | 0.00/595 [00:00<?, ?B/s]

Using the latest cached version of the dataset since ylkhayat/ECHR-generation-workshop couldn't be found on the Hugging Face Hub


[!] ECHR-generation-workshop not found in bm25_oracle_passages_oracle_documents
[*] adding top_10_passages to ECHR-generation-workshop


Map (num_proc=45):   0%|          | 0/5000 [00:00<?, ? examples/s]

Map (num_proc=45):   0%|          | 0/1000 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/6 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/2 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/595 [00:00<?, ?B/s]

"74969/01\ncontact with all the persons concerned. It follows from these considerations that the Court's task is not to substitute itself for the domestic authorities in the exercise of their responsibilities regarding custody and access issues, but rather to review, in the light of the Convention, the decisions taken by those authorities in the exercise of their power of appreciation (see Sahin and Sommerfeld v. Germany [GC], nos. 30943/96 and 31871/96, \u00a7 64 and \u00a7 62 respectively, ECHR 2003- VIII, and T.P. and K.M. v. the United Kingdom [GC], no. 28945/95, \u00a7 71, ECHR 2001-V ). 42. The margin of appreciation to be accorded to the competent national authorities will vary in accordance with the nature of the issues and the importance of the interests at stake. In particular when deciding on custody, the Court has recognised that the authorities enjoy a wide margin of appreciation. However, a stricter scrutiny is called for as regards any further limitations, such as restri

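# Create the BM25 Setup

Same as the oracle setup above, except that the query is the example's `previous_text` instead of its `gold_text`. Retrieval therefore never sees the gold text.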

In [14]:
# Retrieval depth and the Hub sub-directory where this setup is stored.
top_k, data_dir = 10, "bm25_relevant_passages_oracle_documents"
create = False  # flipped to True below when the stored setup cannot be loaded
def retrieve_top_passages(entry):
    """Rank one example's oracle-document passages against its preceding text with BM25.

    Redefines the oracle cell's function of the same name. The only difference is the
    query field: ``entry['previous_text']`` instead of ``entry['gold_text']``. Each passage
    of ``entry['oracle_documents_passages']`` is indexed as its document id, a newline,
    then its text.

    Returns:
        ``{"top_<top_k>_passages": ...}`` holding at most ``top_k`` passage strings, best
        first. Examples with fewer than ``top_k`` passages return all of them ranked;
        examples without passages return an empty list.
    """
    query = entry['previous_text']
    all_passages = entry['oracle_documents_passages']
    all_passages_text = [f"{passage_arr[0]}\n{passage_arr[1]}" for passage_arr in all_passages]
    # bm25s rejects k larger than the number of indexed passages, so clamp it;
    # an empty corpus cannot be indexed at all.
    k = min(top_k, len(all_passages_text))
    if k == 0:
        return {f"top_{top_k}_passages": []}
    corpus_tokens = bm25s.tokenize(all_passages_text, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)
    query_tokens = bm25s.tokenize(query, stemmer=stemmer)
    results, _ = retriever.retrieve(query_tokens, corpus=all_passages_text, k=k)
    # retrieve() returns one row per query; drop the single-query axis.
    return {f"top_{top_k}_passages": results.squeeze(0)}

top_passages_column = f"top_{top_k}_passages"
try:
    new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
except Exception:  # not a bare except: KeyboardInterrupt / SystemExit still propagate
    print(f"[!] {workshop_hf_name} not found in {data_dir}")
    create = True

# Rebuild when nothing could be loaded or when a loaded split lacks the retrieval column.
# The column only ever exists in the dataset stored under `data_dir` (new_dataset);
# current_chosen_dataset never contains it, so checking that one always forced a rebuild.
if create or any(top_passages_column not in split_columns for split_columns in new_dataset.column_names.values()):
    print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
    new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
    new_dataset = new_dataset.map(retrieve_top_passages, num_proc=num_proc)
    new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
else:
    print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# Preview from `train` when it exists, otherwise from the first available split
# (the ECHR_QA load above only generated a `test` split).
preview_split = new_dataset['train'] if 'train' in new_dataset else next(iter(new_dataset.values()))
print(json.dumps(preview_split[0][top_passages_column][0], indent=4))

Downloading readme:   0%|          | 0.00/708 [00:00<?, ?B/s]

Using the latest cached version of the dataset since ylkhayat/ECHR-generation-workshop couldn't be found on the Hugging Face Hub


[!] ECHR-generation-workshop not found in bm25_relevant_passages_oracle_documents
[*] adding top_10_passages to ECHR-generation-workshop


Map (num_proc=45):   0%|          | 0/5000 [00:00<?, ? examples/s]

Map (num_proc=45):   0%|          | 0/1000 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/6 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/2 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/708 [00:00<?, ?B/s]

"74969/01\nTHIRD SECTION CASE OF G\u00d6RG\u00dcL\u00dc v. GERMANY (Application no. 74969/01) FINAL 26/05/2004 JUDGMENT This version was rectified in accordance with Rule 81 of the Rules of Court on 24 May 2005 STRASBOURG 26 February 2004 This judgment will become final in the circumstances set out in Article 44 \u00a7 2 of the Convention. It may be subject to editorial revision. In the case of G\u00f6rg\u00fcl\u00fc v. Germany, The European Court of Human Rights (Third Section), sitting as a Chamber composed of: Mr L. Caflisch, President, Mr G. Ress, Mr P. K\u016bris, Mr B. Zupan\u010di\u010d, Mr J. Hedigan, Mrs M. Tsatsa-Nikolovska, Mr K. Traja, judges, and Mr V. Berger, Section Registrar, Having deliberated in private on 20 March 2003 and 5 February 2004, Delivers the following judgment, which was adopted on the last \u2011 mentioned date: PROCEDURE 1. The case originated in an application (no. 74969/01) against the Federal Republic of Germany lodged with the Court under Article 34 

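# Create the Dense Oracle Setup

Same as the BM25 oracle setup, but passages are ranked by the inner product of L2-normalised, mean-pooled dense embeddings (i.e. cosine similarity). `gold_text` is the query.

# Create the Dense Setup

**TODO:** the non-oracle dense setup (`previous_text` as the query) is not implemented yet. The next code cell builds the dense *oracle* setup described above.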

In [None]:
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset, DatasetDict
import torch
import json
import faiss
import numpy as np

top_k = 10

# Run on the first GPU when one is available.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Dense encoder shared by queries and passages.
encoder_name = "jhu-clsp/LegalBERT-DPR-CLERC-ft"
tokenizer = AutoTokenizer.from_pretrained(encoder_name)
model = AutoModel.from_pretrained(encoder_name)
model.to(device)

# Results are stored per encoder: "<setup dir>/<encoder name with '/' replaced by '_'>".
clean_encoder_name = encoder_name.replace("/", "_")
data_dir = "/".join(["dense_oracle_passages_oracle_documents", clean_encoder_name])

def normalize_embeddings(embeddings):
    """L2-normalise ``embeddings`` along dim 1 (each row of a ``(n, dim)`` matrix).

    The norm is clamped to a tiny epsilon before dividing (``F.normalize`` default
    ``eps=1e-12``). An all-zero row therefore stays zero instead of becoming NaN,
    which would poison every inner-product score computed from it.
    """
    return torch.nn.functional.normalize(embeddings, p=2, dim=1)


def embed_texts(texts, tokenizer, model, batch_size: int = 64):
    """Encode ``texts`` into L2-normalised, mean-pooled embeddings (one row per text).

    Args:
        texts: list of strings to encode (a single string is treated as a one-item list).
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: encoder whose ``last_hidden_state`` is mean-pooled. Inputs are moved to
            the module-level ``device``.
        batch_size: texts per forward pass. Encoding every passage of a batch in one
            pass can exhaust GPU memory when examples have many passages.

    Returns:
        A tensor on ``device`` with one unit-norm row per input text. It is empty
        (0 rows, ``model.config.hidden_size`` columns) for empty input.
    """
    if isinstance(texts, str):
        texts = [texts]
    if not texts:
        return torch.empty((0, model.config.hidden_size), device=device)
    pooled_batches = []
    for start in range(0, len(texts), batch_size):
        inputs = tokenizer(texts[start:start + batch_size], return_tensors="pt", padding=True, truncation=True, max_length=model.config.max_position_embeddings).to(device)
        attention_mask = inputs['attention_mask']
        with torch.no_grad():
            token_embeddings = model(**inputs).last_hidden_state
            # Mean pooling over real tokens only: padding positions get weight 0.
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
            sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
            pooled_batches.append(sum_embeddings / sum_mask)
    return normalize_embeddings(torch.cat(pooled_batches, dim=0))


def retrieve_top_passages_batch(entries):
    """Dense retrieval over each example's own oracle-document passages (batched ``map``).

    Each example's ``gold_text`` is compared only against that example's
    ``oracle_documents_passages``. Passages are never pooled across the examples of a
    batch, so one query cannot return another example's passages and results do not
    depend on ``batch_size``. ``embed_texts`` returns unit-norm rows, so the inner product
    used for ranking is cosine similarity.

    Returns:
        ``{"top_<top_k>_passages": [[passage_text, ...], ...]}``: one list per example,
        best first, with at most ``top_k`` passages (fewer when the example has fewer).
    """
    queries = entries['gold_text']
    # Passage texts per example; the document id (passage[0]) is not embedded.
    passages_per_entry = [[passage[1] for passage in passages] for passages in entries['oracle_documents_passages']]
    flat_passages = [text for texts in passages_per_entry for text in texts]

    query_embeddings = embed_texts(queries, tokenizer, model)
    # Encode the whole batch's passages in one call, then slice each example's rows back out.
    passage_embeddings = embed_texts(flat_passages, tokenizer, model) if flat_passages else None

    top_passages_batch = []
    offset = 0
    for query_embedding, texts in zip(query_embeddings, passages_per_entry):
        start, offset = offset, offset + len(texts)
        # Clamp k: asking faiss for more neighbours than passages padded the result with
        # index -1, which silently selected the last passage.
        k = min(top_k, len(texts))
        if k == 0:
            top_passages_batch.append([])
            continue
        scores = passage_embeddings[start:offset] @ query_embedding
        top_indices = torch.topk(scores, k=k).indices.tolist()
        top_passages_batch.append([texts[i] for i in top_indices])
    return {f"top_{top_k}_passages": top_passages_batch}

# Reset explicitly: a `create = True` left over from an earlier cell would otherwise force a rebuild.
create = False
top_passages_column = f"top_{top_k}_passages"
try:
    new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
except Exception:  # not a bare except: KeyboardInterrupt / SystemExit still propagate
    print(f"[!] {workshop_hf_name} not found in {data_dir}")
    create = True

# Rebuild when nothing could be loaded or when a loaded split lacks the retrieval column.
# The column only ever exists in the dataset stored under `data_dir` (new_dataset);
# current_chosen_dataset never contains it.
if create or any(top_passages_column not in split_columns for split_columns in new_dataset.column_names.values()):
    print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
    new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
    new_dataset = new_dataset.map(
        retrieve_top_passages_batch,
        batched=True,
        batch_size=4
    )
    new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
else:
    print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# Preview from `train` when it exists, otherwise from the first available split
# (the ECHR_QA load above only generated a `test` split).
preview_split = new_dataset['train'] if 'train' in new_dataset else next(iter(new_dataset.values()))
print(json.dumps(preview_split[0][top_passages_column][0], indent=4))

In [None]:

def _preview_first_example(retrieval_dataset, query_field, passages_field="top_10_passages"):
    """Print a separator, the query, and the best-ranked passage of the first example.

    Uses the `train` split when present, otherwise the first available split
    (the ECHR_QA load above only generated a `test` split).
    """
    split = retrieval_dataset['train'] if 'train' in retrieval_dataset else next(iter(retrieval_dataset.values()))
    print("===================================")
    print(split[0][query_field])
    print(split[0][passages_field][0])


dataset_bm25_oracle_passages_oracle_documents = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir="bm25_oracle_passages_oracle_documents")
_preview_first_example(dataset_bm25_oracle_passages_oracle_documents, 'gold_text')

dataset_bm25_relevant_passages_oracle_documents = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir="bm25_relevant_passages_oracle_documents")
_preview_first_example(dataset_bm25_relevant_passages_oracle_documents, 'previous_text')

# The dense setup is pushed under a per-encoder sub-directory, so load exactly that one
# rather than the parent directory.
dataset_dense_oracle_passages_oracle_documents = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=f"dense_oracle_passages_oracle_documents/{clean_encoder_name}")
_preview_first_example(dataset_dense_oracle_passages_oracle_documents, 'gold_text')


In [None]:
# from transformers import AutoTokenizer, AutoModel
# import torch
# import faiss
# import numpy as np

# encoder_name = "jhu-clsp/LegalBERT-DPR-CLERC-ft"
# tokenizer = AutoTokenizer.from_pretrained(encoder_name)
# model = AutoModel.from_pretrained(encoder_name)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model.to(device)

# top_k = 10
# data_dir = "dense_oracle_passages_oracle_documents"

# def embed_passages(passages, tokenizer, model):
#     inputs = tokenizer(passages, return_tensors="pt", padding=True, truncation=True, max_length=model.config.max_position_embeddings).to(device)
#     with torch.no_grad():
#         embeddings = model(**inputs).pooler_output
#     return embeddings.cpu().numpy()

# def retrieve_top_passages(entry):
#     query = entry['gold_text']
#     all_passages = entry['oracle_documents_passages']
#     all_passages_text = [f"{passage_arr[0]}\n{passage_arr[1]}" for passage_arr in all_passages]
    
#     passage_embeddings = embed_passages(all_passages_text, tokenizer, model)
#     query_embedding = embed_passages([query], tokenizer, model).squeeze(0)

#     dim = passage_embeddings.shape[1]
#     index = faiss.IndexFlatL2(dim)
#     index.add(passage_embeddings)
    
#     # Retrieve top-k similar passages
#     distances, indices = index.search(np.expand_dims(query_embedding, axis=0), top_k)
#     top_passages = [all_passages_text[idx] for idx in indices[0]]
    
#     return {f"top_{top_k}_passages": top_passages}

# try:
#     workshop_hf_name = f"CLERC-generation-workshop"
#     new_dataset = load_dataset(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
# except:
#     print(f"[!] {workshop_hf_name} not found in {data_dir}")
#     create = True
    
# if create or f"top_{top_k}_passages" not in current_chosen_dataset.column_names['train'] or f"top_{top_k}_passages" not in current_chosen_dataset.column_names['test']:
#     print(f"[*] adding top_{top_k}_passages to {workshop_hf_name}")
#     new_dataset = DatasetDict({split: current_chosen_dataset[split] for split in current_chosen_dataset.keys()})
#     new_dataset = new_dataset.map(retrieve_top_passages, num_proc=num_proc)
#     new_dataset.push_to_hub(f"ylkhayat/{workshop_hf_name}", data_dir=data_dir)
# else:
#     print(f"[!] top_{top_k}_passages already exists in {workshop_hf_name}")
# print(json.dumps(new_dataset['train'][0][f"top_{top_k}_passages"][0], indent=4)) 