<a href="https://colab.research.google.com/github/ort-eila/git_kundaje_annotations/blob/main/step_2_load_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Step 2 - load dataset

In [33]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [34]:
source_file = 'wikipedia_details.csv'

# Define the destination folder in Google Drive
destination_folder = '/content/drive/MyDrive/bio-llm/'

output_folder = os.path.join(destination_folder,"output")

In [35]:
# !pip install transformers datasets faiss-cpu psutil

In [36]:
import torch
from datasets import Features, Sequence, Value, load_dataset
import os

In [37]:
BATCH_SIZE = 1 #instead of 16

In [38]:
!rm -rf $output_folder
!mkdir $output_folder

In [39]:
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional

import torch
from datasets import Features, Sequence, Value, load_dataset

import faiss
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    HfArgumentParser,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
)


logger = logging.getLogger(__name__)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"


def split_text(text: str, n=100, character=" ") -> List[str]:
    """Split the text every ``n``-th occurrence of ``character``"""
    text = text.split(character)
    return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]


def split_documents(documents: dict) -> dict:
    """Split documents into passages"""
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        if text is not None:
            for passage in split_text(text):
                titles.append(title if title is not None else "")
                texts.append(passage)
    return {"title": titles, "text": texts}


def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Compute the DPR embeddings of document passages"""
    input_ids = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )["input_ids"]
    embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
    return {"embeddings": embeddings.detach().cpu().numpy()}




@dataclass
class RagExampleArguments:
    csv_path: str = field(
        default=str(""),
        metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
    )
    question: Optional[str] = field(
        default=None,
        metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
    )
    rag_model_name: str = field(
        default="facebook/rag-sequence-nq",
        metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
    )
    dpr_ctx_encoder_model_name: str = field(
        default="facebook/dpr-ctx_encoder-multiset-base",
        metadata={
            "help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
        },
    )
    output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
    )


@dataclass
class ProcessingArguments:
    num_proc: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use to split the documents into passages. Default is single process."
        },
    )
    batch_size: int = field(
        default=BATCH_SIZE,
        metadata={
            "help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
        },
    )


@dataclass
class IndexHnswArguments:
    d: int = field(
        default=768,
        metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
    )
    m: int = field(
        default=128,
        metadata={
            "help": "The number of bi-directional links created for every new element during the HNSW index construction."
        },
    )




In [40]:
# source_file = 'wikipedia_details.csv'

# # Define the destination folder in Google Drive
# destination_folder = '/content/drive/MyDrive/bio-llm/'

In [46]:
#if __name__ == "__main__":

# logging.basicConfig(level=logging.WARNING)
# logger.setLevel(logging.INFO)

# In place of parser, just create arguments objects.
#parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
#rag_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()

rag_args = RagExampleArguments()
processing_args = ProcessingArguments()
index_hnsw_args = IndexHnswArguments()

rag_args.csv_path = os.path.join(destination_folder,source_file)
# "./wikipedia_details.csv"

rag_args.output_dir = output_folder

rag_args.question = "what is BRCA1"

# Probably don't need this...
#with TemporaryDirectory() as tmp_dir:
#    rag_args.output_dir = rag_args.output_dir or tmp_dir


In [47]:
rag_args.csv_path

'/content/drive/MyDrive/bio-llm/wikipedia_details.csv'

In [48]:
# DPR, which stands for Dense Passage Retrieval, is a method for representing passages of text in a
# dense vector space, such that similar passages are close to each other in this space.
# DPR is commonly used in information retrieval and question-answering systems to efficiently
# retrieve relevant passages given a query.

In [49]:
def print_dataset(dataset):
    if len(dataset) > 0:
        for i in range(min(5, len(dataset))):  # Print the first 5 rows or less if dataset is smaller
            print(f"Title: {dataset['title'][i]}")
            print(f"Text: {dataset['text'][i]}")
            print()
    else:
        print("The dataset is empty.")

In [50]:
!ls drive/MyDrive

 3129_vector_image_row_1023_col301.jpg	'pza-ydds-vkk - Aug 24, 2023.gjam'
 bio-llm				'scRNA-seq results kb count'
'Colab Notebooks'			 Seminar
 ENCODE_annotation			'Untitled document.gdoc'
 HW2.ipynb


In [51]:

######################################
logger.info("Step 1 - Create the dataset")
######################################

# The dataset needed for RAG must have three columns:
# - title (string): title of the document
# - text (string): text of a passage of the document
# - embeddings (array of dimension d): DPR representation of the passage

# Let's say you have documents in tab-separated csv files with columns "title" and "text"
assert os.path.isfile(rag_args.csv_path), "Please provide a valid path to a csv file"

# You can load a Dataset object this way
dataset = load_dataset(
    "csv", data_files=[rag_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
)

# print("csv dataset load")
# print_dataset(dataset)

# More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files

# Then split the documents into passages of 100 words
dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)

# And compute the embeddings
ctx_encoder = DPRContextEncoder.from_pretrained(rag_args.dpr_ctx_encoder_model_name).to(device=device)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_args.dpr_ctx_encoder_model_name)
new_features = Features(
    {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
)  # optional, save as float32 instead of float64 to save space
dataset = dataset.map(
    partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
    batched=True,
    batch_size=processing_args.batch_size,
    features=new_features,
)

print("after map")
print_dataset(dataset)

# And finally save your dataset
passages_path = os.path.join(rag_args.output_dir, "my_knowledge_dataset")
dataset.save_to_disk(passages_path)


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/4 [00:00<?, ? examples/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/492 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/dpr-ctx_encoder-multiset-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.weight', 'ctx_encoder.bert_model.pooler.dense.bias']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizerFast'.


after map
The dataset is empty.


Saving the dataset (0/1 shards): 0 examples [00:00, ? examples/s]

In [52]:
# Value("float32")

In [54]:
# Sequence("A BRCA mutation is a mutation in either of the BRCA1 and BRCA2 genes, which are tumour suppressor genes. Hundreds of different types of mutations in these genes have been identified, some of which have been determined to be harmful, while others have no proven impact. Harmful mutations in these genes may produce a hereditary breast–ovarian cancer syndrome in affected persons. Only 5–10% of breast cancer cases in women are attributed to BRCA1 and BRCA2 mutations (with BRCA1 mutations being slightly more common than BRCA2 mutations), but the impact on women with the gene mutation is more profound.[2] Women with harmful mutations in either BRCA1 or BRCA2 have a risk of breast cancer that is about five times the normal risk, and a risk of ovarian cancer that is about ten to thirty times normal.[3] The risk of breast and ovarian cancer is higher for women with a high-risk BRCA1 mutation than with a BRCA2 mutation. Having a high-risk mutation does not guarantee that the woman will develop any type of cancer, or imply that any cancer that appears was actually caused by the mutation, rather than some other factor.High-risk mutations, which disable an important error-free DNA repair process (homology directed repair), significantly increase the person's risk of developing breast cancer, ovarian cancer and certain other cancers. Why BRCA1 and BRCA2 mutations lead preferentially to cancers of the breast and ovary is not known, but lack of BRCA1 function seems to lead to non-functional X-chromosome inactivation. Not all mutations are high-risk; some appear to be harmless variations.  The cancer risk associated with any given mutation varies significantly and depends on the exact type and location of the mutation and possibly other individual factors.Mutations can be inherited from either parent and may be passed on to both sons and daughters.  Each child of a genetic carrier, regardless of sex, has a 50% chance of inheriting the mutated gene from the parent who carries the mutation. As a result, half of the people with BRCA gene mutations are male, who would then pass the mutation on to 50% of their offspring, male or female.  The risk of BRCA-related breast cancers for men with the mutation is higher than for other men, but still low.[4] However, BRCA mutations can increase the risk of other cancers, such as colon cancer, pancreatic cancer, and prostate cancer.Methods to diagnose the likelihood of a patient with mutations in BRCA1 and BRCA2 getting cancer were covered by patents ")

In [55]:
# Sequence(Value("float32"))

In [56]:
#debug
from datasets import load_from_disk
dataset = load_from_disk(passages_path)  # to reload the dataset
print_dataset(dataset)


The dataset is empty.


In [None]:
######################################
logger.info("Step 2 - Index the dataset")
######################################

# Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
dataset.add_faiss_index("embeddings", custom_index=index)

# And save the index
index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
dataset.get_index("embeddings").save(index_path)
# dataset.load_faiss_index("embeddings", index_path)  # to reload the index

ValueError: Columns ['embeddings'] not in the dataset. Current columns in the dataset: ['title', 'text']