<a href="https://colab.research.google.com/github/ort-eila/git_kundaje_annotations/blob/main/step_2_load_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [84]:
# Step 2 - load dataset

In [85]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [86]:
source_file = 'wikipedia_details.tsv'

# Define the destination folder in Google Drive
destination_folder = '/content/drive/MyDrive/bio-llm/'

output_folder = os.path.join(destination_folder,"output")

In [87]:
# !pip install transformers datasets faiss-cpu psutil

In [88]:
import torch
from datasets import Features, Sequence, Value, load_dataset
import os

In [89]:
BATCH_SIZE = 1 #instead of 16

In [90]:
!rm -rf $output_folder
!mkdir $output_folder

In [91]:
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional

import torch
from datasets import Features, Sequence, Value, load_dataset

import faiss
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    HfArgumentParser,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
)


logger = logging.getLogger(__name__)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"


def split_text(text: str, n=100, character=" ") -> List[str]:
    """Split the text every ``n``-th occurrence of ``character``"""
    text = text.split(character)
    return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]


def split_documents(documents: dict) -> dict:
    """Split documents into passages"""
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        if text is not None:
            for passage in split_text(text):
                titles.append(title if title is not None else "")
                texts.append(passage)
    return {"title": titles, "text": texts}


# model(input_ids).pooler_output

def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Compute the DPR embeddings of document passages"""
    print("documents is ",documents)
    input_ids = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )["input_ids"]
    embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
    return {"embeddings": embeddings.detach().cpu().numpy()}




@dataclass
class RagExampleArguments:
    csv_path: str = field(
        default=str(""),
        metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
    )
    question: Optional[str] = field(
        default=None,
        metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
    )
    rag_model_name: str = field(
        default="facebook/rag-sequence-nq",
        metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
    )
    dpr_ctx_encoder_model_name: str = field(
        default= "facebook/dpr-ctx_encoder-single-nq-base", #"facebook/dpr-ctx_encoder-multiset-base",
        metadata={
            "help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
        },
    )
    output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
    )


@dataclass
class ProcessingArguments:
    num_proc: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use to split the documents into passages. Default is single process."
        },
    )
    batch_size: int = field(
        default=BATCH_SIZE,
        metadata={
            "help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
        },
    )


@dataclass
class IndexHnswArguments:
    d: int = field(
        default=768,
        metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
    )
    m: int = field(
        default=128,
        metadata={
            "help": "The number of bi-directional links created for every new element during the HNSW index construction."
        },
    )




In [92]:
# source_file = 'wikipedia_details.csv'

# # Define the destination folder in Google Drive
# destination_folder = '/content/drive/MyDrive/bio-llm/'

In [93]:
#if __name__ == "__main__":

# logging.basicConfig(level=logging.WARNING)
# logger.setLevel(logging.INFO)

# In place of parser, just create arguments objects.
#parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
#rag_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()

rag_args = RagExampleArguments()
processing_args = ProcessingArguments()
index_hnsw_args = IndexHnswArguments()

rag_args.csv_path = os.path.join(destination_folder,source_file)
# "./wikipedia_details.csv"

rag_args.output_dir = output_folder

rag_args.question = "what is BRCA1"

# Probably don't need this...
#with TemporaryDirectory() as tmp_dir:
#    rag_args.output_dir = rag_args.output_dir or tmp_dir


In [94]:
rag_args.csv_path

'/content/drive/MyDrive/bio-llm/wikipedia_details.tsv'

In [95]:
# DPR, which stands for Dense Passage Retrieval, is a method for representing passages of text in a
# dense vector space, such that similar passages are close to each other in this space.
# DPR is commonly used in information retrieval and question-answering systems to efficiently
# retrieve relevant passages given a query.

In [96]:
def print_dataset(dataset):
    if len(dataset) > 0:
        for i in range(min(5, len(dataset))):  # Print the first 5 rows or less if dataset is smaller
            print(f"Title: {dataset['title'][i]}")
            print(f"Text: {dataset['text'][i]}")
            print()
    else:
        print("The dataset is empty.")

In [97]:
!ls drive/MyDrive

 3129_vector_image_row_1023_col301.jpg	'pza-ydds-vkk - Aug 24, 2023.gjam'
 bio-llm				'scRNA-seq results kb count'
'Colab Notebooks'			 Seminar
 ENCODE_annotation			'Untitled document.gdoc'
 HW2.ipynb


In [98]:
# from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

# tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
# model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
# input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"]
# embeddings = model(input_ids).pooler_output
# embeddings

In [99]:

######################################
logger.info("Step 1 - Create the dataset")
######################################

# The dataset needed for RAG must have three columns:
# - title (string): title of the document
# - text (string): text of a passage of the document
# - embeddings (array of dimension d): DPR representation of the passage

# Let's say you have documents in tab-separated csv files with columns "title" and "text"
assert os.path.isfile(rag_args.csv_path), "Please provide a valid path to a csv file"

# You can load a Dataset object this way
dataset = load_dataset(
    "csv", data_files=[rag_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
)

# print("csv dataset load")
# print_dataset(dataset)

# More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files

# Then split the documents into passages of 100 words
dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)

# And compute the embeddings
# tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-multiset-base") #ctx_tokenizer
# model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base") #ctx_encoder
# input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"]
# embeddings = model(input_ids).pooler_output
# embeddings

# And compute the embeddings
ctx_encoder = DPRContextEncoder.from_pretrained(rag_args.dpr_ctx_encoder_model_name).to(device=device)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_args.dpr_ctx_encoder_model_name)
new_features = Features(
    {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
)  # optional, save as float32 instead of float64 to save space
dataset = dataset.map(
    partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
    batched=True,
    batch_size=processing_args.batch_size,
    features=new_features,
)


print("after map")
print_dataset(dataset)

# And finally save your dataset
passages_path = os.path.join(rag_args.output_dir, "my_knowledge_dataset")
dataset.save_to_disk(passages_path)


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/15 [00:00<?, ? examples/s]

Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.weight', 'ctx_encoder.bert_model.pooler.dense.bias']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokeniz

Map:   0%|          | 0/294 [00:00<?, ? examples/s]

documents is  {'title': ['title'], 'text': ['text']}
documents is  {'title': ['BRCA_mutation'], 'text': ['A BRCA mutation is a mutation in either of the BRCA1 and BRCA2 genes, which are tumour suppressor genes. Hundreds of different types of mutations in these genes have been identified, some of which have been determined to be harmful, while others have no proven impact. Harmful mutations in these genes may produce a hereditary breast–ovarian cancer syndrome in affected persons. Only 5–10% of breast cancer cases in women are attributed to BRCA1 and BRCA2 mutations (with BRCA1 mutations being slightly more common than BRCA2 mutations), but the impact on women with the gene mutation is more profound.[2] Women with']}
documents is  {'title': ['BRCA_mutation'], 'text': ['harmful mutations in either BRCA1 or BRCA2 have a risk of breast cancer that is about five times the normal risk, and a risk of ovarian cancer that is about ten to thirty times normal.[3] The risk of breast and ovarian ca

Saving the dataset (0/1 shards):   0%|          | 0/294 [00:00<?, ? examples/s]

In [100]:
# #debug
# from datasets import load_from_disk
# dataset = load_from_disk(passages_path)  # to reload the dataset
# print_dataset(dataset)


In [103]:
######################################
logger.info("Step 2 - Index the dataset")
######################################

# Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
dataset.add_faiss_index("embeddings", custom_index=index)

# And save the index
faiss_index_name = "wiki_index.faiss"
index_path = os.path.join(rag_args.output_dir, faiss_index_name)
dataset.get_index("embeddings").save(index_path)
# dataset.load_faiss_index("embeddings", index_path)  # to reload the index

  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
######################################
logger.info("Step 3 - Load RAG")
######################################

# Easy way to load the model
retriever = RagRetriever.from_pretrained(
    rag_args.rag_model_name, index_name="custom", indexed_dataset=dataset
)
model = RagSequenceForGeneration.from_pretrained(rag_args.rag_model_name, retriever=retriever)
tokenizer = RagTokenizer.from_pretrained(rag_args.rag_model_name)

# For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
# retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

Downloading pytorch_model.bin:   0%|          | 0.00/2.06G [00:00<?, ?B/s]

In [None]:
######################################
logger.info("Step 4 - Have fun")
######################################

question = rag_example_args.question or "what is BRCA1 ?"
input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
generated = model.generate(input_ids)
generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
logger.info("Q: " + question)
logger.info("A: " + generated_string)