<a href="https://colab.research.google.com/github/shannn1/goodRAG/blob/main/embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch
import os
from torch.cuda.amp import autocast
from huggingface_hub import login
from datasets import DatasetDict

# Device used for the encoder and its inputs: the GPU when PyTorch reports
# CUDA as available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

def split_text(text: str, n=512, sep=" "):
    """Cut ``text`` into passages of at most ``n`` ``sep``-separated pieces.

    Args:
        text: The string to split.
        n: Maximum number of pieces (as produced by ``text.split(sep)``)
            per passage.
        sep: Separator used both to split ``text`` and to re-join each
            passage.

    Returns:
        A list of passages in original order, each stripped of leading and
        trailing whitespace. Empty pieces produced by consecutive separators
        are kept, so a passage may be an empty string.
    """
    pieces = text.split(sep)
    chunks = []
    for offset in range(0, len(pieces), n):
        window = pieces[offset : offset + n]
        chunks.append(sep.join(window).strip())
    return chunks

def split_documents(example, chunk_size=512):
    """Expand a batch of documents into one row per passage.

    Args:
        example: Batch mapping with parallel ``"title"`` and ``"document"``
            lists.
        chunk_size: Passed to ``split_text`` as ``n``, the maximum number of
            space-separated pieces per passage.

    Returns:
        A dict with parallel ``"title"`` and ``"document"`` lists holding one
        entry per passage; each passage carries the title of the document it
        came from. Documents that are ``None`` are skipped.
    """
    rows = [
        (heading, passage)
        for heading, body in zip(example["title"], example["document"])
        if body is not None
        for passage in split_text(body, n=chunk_size)
    ]
    return {
        "title": [heading for heading, _ in rows],
        "document": [passage for _, passage in rows],
    }

dataset = load_dataset("lighteval/natural_questions_clean")

# Only these columns are carried into the chunking step.
_KEPT_COLUMNS = ["id", "title", "document"]
train_data = dataset["train"].select_columns(_KEPT_COLUMNS)
test_data = dataset["validation"].select_columns(_KEPT_COLUMNS)

# Replace each split with its passage-level version. All original columns are
# removed, so only the "title" and "document" lists returned by
# split_documents remain.
# NOTE(review): split_documents defaults to 512 space-separated words per
# passage while compute_embeddings truncates at max_length=512 tokens, so long
# passages may be cut at encoding time - confirm this is intended.
train_data = train_data.map(
    split_documents,
    batched=True,
    remove_columns=train_data.column_names,
)
test_data = test_data.map(
    split_documents,
    batched=True,
    remove_columns=test_data.column_names,
)

# The encoder and its tokenizer are loaded from the same checkpoint; the
# encoder is moved to the selected device.
_DPR_CHECKPOINT = "facebook/dpr-ctx_encoder-single-nq-base"
ctx_encoder = DPRContextEncoder.from_pretrained(_DPR_CHECKPOINT)
ctx_encoder = ctx_encoder.to(device)
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(_DPR_CHECKPOINT)

def compute_embeddings(documents, ctx_encoder, ctx_tokenizer):
    """Encode a batch of (title, passage) pairs into pooled embeddings.

    Args:
        documents: Batch mapping with parallel ``"title"`` and ``"document"``
            lists.
        ctx_encoder: Encoder called as ``ctx_encoder(input_ids,
            return_dict=True)``; the result must expose ``pooler_output``.
            Its ``device`` attribute decides where the inputs are placed.
        ctx_tokenizer: Tokenizer called with the titles and documents as a
            text pair; must return a mapping containing ``"input_ids"``.

    Returns:
        A NumPy array on the host holding the encoder's ``pooler_output`` for
        the batch.
    """
    # Follow the encoder's own device rather than a module-level global, so
    # the inputs always land where the weights are.
    target = torch.device(ctx_encoder.device)
    on_cuda = target.type == "cuda"

    inputs = ctx_tokenizer(
        documents["title"],
        documents["document"],
        truncation=True,
        max_length=512,
        padding="longest",
        return_tensors="pt",
    )
    # NOTE(review): only input_ids are forwarded; the tokenizer's padding mask
    # is not passed on - confirm the encoder derives it from the pad token.
    input_ids = inputs["input_ids"].to(target)

    # Mixed precision is only enabled when actually running on CUDA.
    # NOTE(review): under CUDA autocast the result is presumably float16 -
    # confirm downstream consumers accept that.
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=on_cuda):
        embeddings = ctx_encoder(input_ids, return_dict=True).pooler_output

    result = embeddings.detach().cpu().numpy()

    # Drop every device-side reference before asking the allocator to release
    # its cached blocks.
    del input_ids, inputs, embeddings
    if on_cuda:
        torch.cuda.empty_cache()
    return result

def process_and_save(data, output_dir, batch_size, split_name):
    """Embed ``data`` shard by shard and save every shard under ``output_dir``.

    Each shard covers ``batch_size`` consecutive rows (the last one may be
    shorter), gains an ``"embeddings"`` column computed by
    ``compute_embeddings`` and is written to
    ``<output_dir>/<split_name>_batch_<shard_index>.dataset``.

    Args:
        data: Dataset supporting ``len()``, ``select()`` and ``map()``.
        output_dir: Directory for the shards; created if missing.
        batch_size: Rows per saved shard, and the initial number of rows
            encoded per call.
        split_name: Prefix used in shard file names and log messages.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        RuntimeError: Any non-OOM runtime error, or a CUDA out-of-memory
            error that persists with an encode batch size of 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    os.makedirs(output_dir, exist_ok=True)

    # Shard boundaries stay pinned to the requested size. Only the number of
    # rows encoded per call shrinks after an OOM; reusing one variable for
    # both would shift the offsets of later shards, duplicating some rows and
    # never reaching the tail of the dataset.
    shard_size = batch_size
    encode_batch_size = batch_size
    total_rows = len(data)
    num_shards = (total_rows + shard_size - 1) // shard_size

    # torch.cuda.memory_allocated rejects non-CUDA devices, so the memory
    # report is skipped when running on the CPU.
    report_gpu_memory = torch.device(device).type == "cuda"

    for shard_index in range(num_shards):
        start = shard_index * shard_size
        end = min(start + shard_size, total_rows)
        batch = data.select(range(start, end))

        while True:
            try:
                print(f"Processing shard {shard_index + 1}/{num_shards} for {split_name}.")
                if report_gpu_memory:
                    print(f"Currently allocated GPU memory: {torch.cuda.memory_allocated(device) / 1e6:.2f} MB")

                batch_with_embeddings = batch.map(
                    lambda b: {"embeddings": compute_embeddings(b, ctx_encoder, ctx_tokenizer).tolist()},
                    batched=True,
                    batch_size=encode_batch_size,
                )
                batch_with_embeddings.save_to_disk(os.path.join(output_dir, f"{split_name}_batch_{shard_index}.dataset"))
                del batch_with_embeddings
                break
            except RuntimeError as e:
                if "CUDA out of memory" not in str(e):
                    raise
                if encode_batch_size == 1:
                    # Nothing left to shrink: retrying would loop forever.
                    raise
                # The smaller size is kept for the remaining shards as well.
                encode_batch_size = max(encode_batch_size // 2, 1)
                print(f"Reduced batch size to {encode_batch_size} due to OOM.")
            except Exception as e:
                print(f"Unexpected error while processing batch: {e}")
                raise

def upload_batches_to_hub(batches, dataset_name, token, batch_merge_size=10):
    """Push ``batches`` to the ``"train"`` split of a Hub dataset in groups.

    The batches are concatenated ``batch_merge_size`` at a time. Every group
    is appended to the rows accumulated so far (starting from the existing
    Hub dataset's ``"train"`` split when one can be loaded) and the whole
    accumulated dataset is pushed again after each group.

    Args:
        batches: Sequence of datasets to upload, in upload order.
        dataset_name: Hub repository name to load from and push to.
        token: Hugging Face access token passed to ``login``.
        batch_merge_size: Number of batches merged per push.

    Raises:
        ValueError: If ``batch_merge_size`` is smaller than 1.
        Exception: Any error raised while pushing a group is logged and
            re-raised unchanged.
    """
    # Validate before any network side effect; a non-positive size would
    # otherwise fail inside range() or silently upload nothing.
    if batch_merge_size < 1:
        raise ValueError(f"batch_merge_size must be a positive integer, got {batch_merge_size}")

    login(token=token)

    try:
        existing_dataset = load_dataset(dataset_name)
        print(f"Loaded existing dataset '{dataset_name}' from Hugging Face Hub.")
    except FileNotFoundError:
        existing_dataset = None
        print(f"Dataset '{dataset_name}' not found on Hugging Face Hub. Creating a new one.")

    for i in range(0, len(batches), batch_merge_size):
        group = batches[i:i + batch_merge_size]
        # Index of the last batch actually in this group; the final group may
        # hold fewer than batch_merge_size batches.
        last = i + len(group) - 1
        merged_batches = concatenate_datasets(group)

        try:
            if existing_dataset is None:
                dataset_dict = DatasetDict({"train": merged_batches})
                dataset_dict.push_to_hub(dataset_name)
                print(f"Uploaded initial batches {i}-{last} to Hugging Face Hub as '{dataset_name}'.")
                existing_dataset = dataset_dict
            else:
                combined_dataset = DatasetDict({
                    "train": concatenate_datasets([existing_dataset["train"], merged_batches])
                })
                combined_dataset.push_to_hub(dataset_name)
                print(f"Appended batches {i}-{last} to '{dataset_name}' on Hugging Face Hub.")
                existing_dataset = combined_dataset
        except Exception as e:
            print(f"Error while uploading batches {i}-{last} to Hugging Face Hub: {e}")
            raise

train_batch_size = 1024
test_batch_size = 1024

process_and_save(train_data, output_dir="./train_batches", batch_size=train_batch_size, split_name="train")
process_and_save(test_data, output_dir="./test_batches", batch_size=test_batch_size, split_name="test")


def _shard_sort_key(filename):
    """Order shard files by the numeric index that ends their stem."""
    stem = filename[: -len(".dataset")]
    suffix = stem.rsplit("_", 1)[-1]
    # Names without a numeric suffix sort after the numbered shards, by name.
    if suffix.isdigit():
        return (0, int(suffix), filename)
    return (1, 0, filename)


def _load_saved_shards(directory):
    """Load every ``*.dataset`` entry in ``directory`` in shard-index order."""
    # os.listdir gives no ordering guarantee, and a plain string sort would
    # put "batch_10" before "batch_2", so sort on the numeric index instead.
    names = sorted(
        (f for f in os.listdir(directory) if f.endswith(".dataset")),
        key=_shard_sort_key,
    )
    return [Dataset.load_from_disk(os.path.join(directory, f)) for f in names]


train_batches = _load_saved_shards("./train_batches")
test_batches = _load_saved_shards("./test_batches")

# Read the access token from the environment rather than hard-coding it; the
# original blank placeholder is kept as the fallback when HF_TOKEN is unset.
# Train and test shards are uploaded together: upload_batches_to_hub places
# everything in a single "train" split.
upload_batches_to_hub(
    train_batches + test_batches,
    "knowledge_base_genai",
    token=os.environ.get("HF_TOKEN", " "),
)