<a href="https://colab.research.google.com/github/ort-eila/git_kundaje_annotations/blob/main/step_2_load_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Step 2 - load dataset

In [3]:
!pip install transformers datasets faiss-cpu psutil

Collecting transformers
  Downloading transformers-4.32.1-py3-none-any.whl (7.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.5/7.5 MB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.14.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.3/519.3 kB[0m [31m42.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting faiss-cpu
  Downloading faiss_cpu-1.7.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.6/17.6 MB[0m [31m61.6 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
source_file = 'wikipedia_details.tsv'

# Define the destination folder in Google Drive
destination_folder = '/content/drive/MyDrive/bio-llm/'

output_folder = os.path.join(destination_folder,"output")

In [None]:
import torch
from datasets import Features, Sequence, Value, load_dataset
import os

In [None]:
BATCH_SIZE = 16 #instead of 16
NUM_PROCESSES = 8
WIKI_DATASET_NAME = "wiki_dataset"

In [None]:
!rm -rf $output_folder
!mkdir $output_folder

In [None]:
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional

import torch
from datasets import Features, Sequence, Value, load_dataset

import faiss
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    HfArgumentParser,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
)


logger = logging.getLogger(__name__)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"


def split_text(text: str, n=100, character=" ") -> List[str]:
    """Split the text every ``n``-th occurrence of ``character``"""
    text = text.split(character)
    return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]


def split_documents(documents: dict) -> dict:
    """Split documents into passages"""
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        if text is not None:
            for passage in split_text(text):
                titles.append(title if title is not None else "")
                texts.append(passage)
    return {"title": titles, "text": texts}


# model(input_ids).pooler_output

def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Compute the DPR embeddings of document passages"""
    print("documents is ",documents)
    input_ids = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )["input_ids"]
    print("input_ids is ",input_ids)
    embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
    print("embeddings is ",embeddings)
    return {"embeddings": embeddings.detach().cpu().numpy()}




@dataclass
class RagExampleArguments:
    csv_path: str = field(
        default=str(""),
        metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
    )
    question: Optional[str] = field(
        default="What is BRCA1?",
        metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
    )
    rag_model_name: str = field(
        default="facebook/rag-sequence-nq",
        metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
    )
    dpr_ctx_encoder_model_name: str = field(
        default= "facebook/dpr-ctx_encoder-multiset-base",
        metadata={
            "help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
        },
    )
    output_dir: Optional[str] = field(
        default=destination_folder,
        metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
    )


@dataclass
class ProcessingArguments:
    num_proc: Optional[int] = field(
        default=NUM_PROCESSES,
        metadata={
            "help": "The number of processes to use to split the documents into passages. Default is single process."
        },
    )
    batch_size: int = field(
        default=BATCH_SIZE,
        metadata={
            "help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
        },
    )


@dataclass
class IndexHnswArguments:
    d: int = field(
        default=768,
        metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
    )
    m: int = field(
        default=128,
        metadata={
            "help": "The number of bi-directional links created for every new element during the HNSW index construction."
        },
    )




In [None]:
#if __name__ == "__main__":

# logging.basicConfig(level=logging.WARNING)
# logger.setLevel(logging.INFO)

# In place of parser, just create arguments objects.
#parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
#rag_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()

rag_args = RagExampleArguments()
processing_args = ProcessingArguments()
index_hnsw_args = IndexHnswArguments()

rag_args.csv_path = os.path.join(destination_folder,source_file)
# "./wikipedia_details.csv"

rag_args.output_dir = output_folder

rag_args.question = "what is BRCA1?"

# Probably don't need this...
#with TemporaryDirectory() as tmp_dir:
#    rag_args.output_dir = rag_args.output_dir or tmp_dir


In [None]:
rag_args.csv_path

In [None]:
# DPR, which stands for Dense Passage Retrieval, is a method for representing passages of text in a
# dense vector space, such that similar passages are close to each other in this space.
# DPR is commonly used in information retrieval and question-answering systems to efficiently
# retrieve relevant passages given a query.

In [None]:
def print_dataset(dataset):
    if len(dataset) > 0:
        for i in range(min(5, len(dataset))):  # Print the first 5 rows or less if dataset is smaller
            print(f"Title: {dataset['title'][i]}")
            print(f"Text: {dataset['text'][i]}")
            print()
    else:
        print("The dataset is empty.")

In [None]:
!ls drive/MyDrive

In [None]:
!ls

In [None]:

######################################
logger.info("Step 1 - Create the dataset")
######################################

# The dataset needed for RAG must have three columns:
# - title (string): title of the document
# - text (string): text of a passage of the document
# - embeddings (array of dimension d): DPR representation of the passage

# Let's say you have documents in tab-separated csv files with columns "title" and "text"
assert os.path.isfile(rag_args.csv_path), "Please provide a valid path to a csv file"

# You can load a Dataset object this way
dataset = load_dataset(
    "csv", data_files=[rag_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
)

# print("csv dataset load")
# print_dataset(dataset)

# More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files

# Then split the documents into passages of 100 words
dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)


# And compute the embeddings
ctx_encoder = DPRContextEncoder.from_pretrained(rag_args.dpr_ctx_encoder_model_name).to(device=device)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_args.dpr_ctx_encoder_model_name)
new_features = Features(
    {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
)  # optional, save as float32 instead of float64 to save space
dataset = dataset.map(
    partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
    batched=True,
    batch_size=processing_args.batch_size,
    features=new_features,
)

print("after map")
print_dataset(dataset)

# And finally save your dataset
passages_path = os.path.join(rag_args.output_dir, WIKI_DATASET_NAME)
dataset.save_to_disk(passages_path)


In [None]:
dataset.shape

In [None]:
print_dataset(dataset[5:])

In [None]:
print_dataset(dataset[200:])

In [None]:
# #debug
# from datasets import load_from_disk
# dataset = load_from_disk(passages_path)  # to reload the dataset
# print_dataset(dataset)


In [None]:
######################################
logger.info("Step 2 - Index the dataset")
######################################

# Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
dataset.add_faiss_index("embeddings", custom_index=index)

# And save the index
faiss_index_name = "wiki_index.faiss"
index_path = os.path.join(rag_args.output_dir, faiss_index_name)
dataset.get_index("embeddings").save(index_path)
# dataset.load_faiss_index("embeddings", index_path)  # to reload the index

In [None]:
######################################
logger.info("Step 3 - Load RAG")
######################################

# Easy way to load the model
retriever = RagRetriever.from_pretrained(
    rag_args.rag_model_name, index_name="custom", indexed_dataset=dataset
)
model = RagSequenceForGeneration.from_pretrained(rag_args.rag_model_name, retriever=retriever)
tokenizer = RagTokenizer.from_pretrained(rag_args.rag_model_name)

# For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
# retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)

In [None]:
######################################
logger.info("Step 4 - Have fun")
######################################

# question = "what is BRCA1 gene?"
question = "is CDKN1B a protein or a gene?"
input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
generated = model.generate(input_ids)
generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
logger.info("Q: " + question)
logger.info("A: " + generated_string)


In [None]:
print("Q: " + question)
print("A: " + generated_string)

In [None]:
input_ids

In [None]:
generated


In [None]:
generated_string

In [None]:
logger.info("Q: " + question)
logger.info("A: " + generated_string)

In [None]:
# TODO: how to print more data on the resources that were used to provide the answer

In [4]:
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer, RagSequenceForGeneration
from transformers import pipeline

# Initialize a RAG model with a retriever and a generator
retriever = RagRetriever.from_pretrained("facebook/rag-token-base")
generator = RagTokenForGeneration.from_pretrained("facebook/rag-token-base")
model = RagSequenceForGeneration(retriever=retriever, generator=generator)

# Define your added information
knowledge = [
    "BRCA1 is a human gene that produces a protein called BRCA1.",
    "Mutations in the BRCA1 gene can increase the risk of breast and ovarian cancer.",
    "BRCA1 mutations are inherited in an autosomal dominant manner."
]

# Add the knowledge to the retriever
retriever.index_documents(knowledge)

# Ask a question
question = "What is BRCA1?"
answer = pipeline("table-question-answering", model=model)
response = answer(question=question)

print("Q:", question)
print("A:", response[0]['answer'])


Downloading (…)lve/main/config.json:   0%|          | 0.00/4.55k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

Downloading (…)_tokenizer/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.


Downloading (…)okenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading (…)tokenizer/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)tokenizer/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizerFast'.


Downloading builder script:   0%|          | 0.00/9.62k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/67.5k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/14.6k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/4.69G [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/50 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.32G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.32G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

OSError: ignored