In [2]:
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional

import faiss
import pandas as pd
import torch
from datasets import Features, Sequence, Value, load_dataset

from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    HfArgumentParser,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
)


logger = logging.getLogger(__name__)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
def split_text(text: str, n=100, character=" ") -> List[str]:
    """Split the text every ``n``-th occurrence of ``character``"""
    text = text.split(character)
    return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]

def split_documents(documents: dict) -> dict:
    """Split documents into passages"""
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        if text is not None:
            for passage in split_text(text):
                titles.append(title if title is not None else "")
                texts.append(passage)
    return {"title": titles, "text": texts}


# def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
#     """Compute the DPR embeddings of document passages"""
#     input_ids = ctx_tokenizer(
#         documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
#     )["input_ids"]
#     embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
#     return {"embeddings": embeddings.detach().cpu().numpy()}

In [3]:
dataset = load_dataset("csv", data_files=["../data/my_knowledge_dataset.csv"], split="train",
                       delimiter="\t", column_names=["title", "text"])
# Then split the documents into passages of 100 words
# dataset = dataset.map(split_documents, batched=True)
dataset

Dataset({
    features: ['title', 'text'],
    num_rows: 2
})

#### Embed the Context Dataset

In [6]:
# dpr_ctx_encoder_model_name = "facebook/dpr-ctx_encoder-multiset-base"
# ctx_encoder = DPRContextEncoder.from_pretrained(dpr_ctx_encoder_model_name).to(device=device)
# ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(dpr_ctx_encoder_model_name)
# new_features = Features(
#     {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
# )  # optional, save as float32 instead of float64 to save space
# dataset = dataset.map(
#     partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
#     batched=True,
#     batch_size=16,
#     features=new_features,
# )

In [7]:
from sentence_transformers import SentenceTransformer
dim = 768
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", truncate_dim=dim)



In [8]:
def embed(batch):
    """
    adds a column to the dataset called 'embeddings'
    """
    # or you can combine multiple columns here
    # For example the title and the text
    combined_text = []
    for title, text in zip(batch['title'], batch['text']):
        combined_text.append(' [SEP] '.join([title, text]))
    print(combined_text)
    return {"embeddings" : ST.encode(combined_text)}

In [9]:
dataset = dataset.map(embed, batched=True, batch_size=16)

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

['Aaron [SEP] Aaron Aaron ( or ; "Ahärôn") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother\'s spokesman ("prophet") to the Pharaoh. Part of the Law (Torah) that Moses received from God at Sinai granted Aaron the priesthood for himself and his male descendants, and he became the first High Priest of the Israelites. Aaron died before the Israelites crossed the North Jordan river and he was buried on Mount Hor (Numbers 33:39; Deuteronomy 10:6 says he died and was buried at Moserah). Aaron is also mentioned in the New Testament of the Bible. According to th

#### Index the dataset

In [10]:
index = faiss.IndexHNSWFlat(dim, 128, faiss.METRIC_INNER_PRODUCT)
dataset.add_faiss_index("embeddings")

  0%|          | 0/1 [00:00<?, ?it/s]

Dataset({
    features: ['title', 'text', 'embeddings'],
    num_rows: 2
})

#### Load RAG 

In [11]:
def search(query: str, k: int = 3):
    """a function that embeds a new query and returns the most probable results"""
    embedded_query = ST.encode(query) # embed new query
    scores, retrieved_examples = dataset.get_nearest_examples( # retrieve results
        "embeddings", embedded_query, # compare our new embedded query with the dataset embeddings
        k=k # get only top k results
    )
    return scores, retrieved_examples

In [19]:
scores, results = search("snake")
for i in range(len(scores)):
    print(f"Score: {scores[i]}: Title: {results['title'][i]}, Text: {results['text'][i]}")

Score: 245.41488647460938: Title: Pokémon, Text: Pokémon , also known as in Japan, is a media franchise managed by The Pokémon Company, a Japanese consortium between Nintendo, Game Freak, and Creatures. The franchise copyright is shared by all three companies, but Nintendo is the sole owner of the trademark. The franchise was created by Satoshi Tajiri in 1995, and is centered on fictional creatures called "Pokémon", which humans, known as Pokémon Trainers, catch and train to battle each other for sport. The English slogan for the franchise is "Gotta Catch 'Em All". Works within the franchise are set in the Pokémon universe. The franchise began as "Pokémon Red" and "Green" (released outside of Japan as "Pokémon Red" and "Blue"), a pair of video games for the original Game Boy that were developed by Game Freak and published by Nintendo in February 1996. "Pokémon" has since gone on to become the highest-grossing media franchise of all time, with over in revenue up until March 2017. The or

In [13]:
import os
hf_token = os.environ["HF_TOKEN"]
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
rag_model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(rag_model_name, token=hf_token)
tokenizer = AutoTokenizer.from_pretrained(rag_model_name, token=hf_token)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [29]:
SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer."""
def format_prompt(prompt,retrieved_documents):
    """using the retrieved documents we will prompt the model to generate our responses"""
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(len(retrieved_documents)) :
        PROMPT+= f"{retrieved_documents['text'][0]}\n"
    return PROMPT

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000] # to avoid GPU OOM
    # messages = [{"role":"system","content":SYS_PROMPT},{"role":"user","content":formatted_prompt}]
    # tell the model to generate
    input_ids = tokenizer(formatted_prompt,
        return_tensors="pt")["input_ids"]
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)

In [30]:
# question = "What does Moses' rod turn into ?"
# input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
# generated = model.generate(input_ids)
# generated_string = tokenizer.batch_decode(generated, skip_special_token=True)[0]
# print(f"Q: {question}")
# print(f"A: {generated_string}")

In [31]:
formatted_prompt = format_prompt("What does Moses' rod turn into ?", results)
res = generate(formatted_prompt)
print(res)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


KeyboardInterrupt: 

In [39]:
from pydantic import BaseModel, constr, Field
import outlines
from outlines.models import Transformers
import torch
from typing import Optional, Tuple
from enum import Enum

class Answer(BaseModel):
    answer: str
    reason: str

class Response(BaseModel):
    answers: Optional[Tuple[Answer]] = Field(
        default=None,
        description="A list of answers"
    )
    
outlines_model = Transformers(model, tokenizer)
generator = outlines.generate.json(outlines_model, Response)
rng = torch.Generator(device="cuda")
rng.manual_seed(42)

res = generator(formatted_prompt)
print(repr(res))

Compiling FSM index for all state transitions: 100%|████████████████████████████████████| 67/67 [00:03<00:00, 19.80it/s]


Response(answers=(Answer(answer='Staff', reason="According to the biblical account, as recorded in the book of Exodus, Moses' rod transforms into a serpent when he throws it down in his contest with the Egyptian magicians (Exodus 7:8-13; 8:5-7)."),))


### Semantic Search

In [3]:
enrich_table = pd.read_csv("../data/enrich_tbl.csv")
enrich_table.shape

(679, 5)

In [4]:
df = enrich_table.copy()
df.drop(columns=["ID", "Adjusted P-value", "Genes"], inplace=True)
df.head()

Unnamed: 0,Term,Desc
0,Cytoplasmic Translation,The chemical reactions and pathways resulting ...
1,Macromolecule Biosynthetic Process,GO
2,Translation,GO
3,Peptide Biosynthetic Process,The chemical reactions and pathways resulting ...
4,Gene Expression,The process in which a genes sequence is conve...


In [21]:
from datasets import Dataset
torch.set_grad_enabled(False)
dataset = Dataset.from_pandas(df)
dataset

Dataset({
    features: ['Term', 'Desc'],
    num_rows: 679
})

In [22]:
def concat_text(batch):
    term = batch["Term"] if batch["Term"] is not None else ""
    desc = batch["Desc"] if batch["Desc"] is not None else ""
    return {
        "text": term
                + "\n "
                + desc
    }
dataset = dataset.map(concat_text)
dataset

Map:   0%|          | 0/679 [00:00<?, ? examples/s]

Dataset({
    features: ['Term', 'Desc', 'text'],
    num_rows: 679
})

In [23]:
from sentence_transformers import SentenceTransformer
embed_model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = dataset.map(
    lambda x: {"embeddings": embed_model.encode(x["text"])}
)
dataset.select(range(10)).map(lambda x: print(f"Text: {x['text']}\n, emb dim: {len(x['embeddings'])}"))

Map:   0%|          | 0/679 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Text: Cytoplasmic Translation 
 The chemical reactions and pathways resulting in the formation of a protein in the cytoplasm. This is a ribosome-mediated process in which the information in messenger RNA (mRNA) is used to specify the sequence of amino acids in the protein.
, emb dim: 1024
Text: Macromolecule Biosynthetic Process 
 GO
, emb dim: 1024
Text: Translation 
 GO
, emb dim: 1024
Text: Peptide Biosynthetic Process 
 The chemical reactions and pathways resulting in the formation of peptides, compounds of 2 or more (but usually less than 100) amino acids where the alpha carboxyl group of one is bound to the alpha amino group of another. This may include the translation of a precursor protein and its subsequent processing into a functional peptide.
, emb dim: 1024
Text: Gene Expression 
 The process in which a genes sequence is converted into a mature gene product (protein or RNA). This includes the production of an RNA transcript and its processing, as well as translation and mat

Dataset({
    features: ['Term', 'Desc', 'text', 'embeddings'],
    num_rows: 10
})

In [32]:
dataset.add_faiss_index(column="embeddings")
query = "obesity"
# query = "A biologist is studying the causal relationship between SNP rs1421085 and obesity."
embedded_query = embed_model.encode([query])
embedded_query.shape

  0%|          | 0/1 [00:00<?, ?it/s]

(1, 1024)

In [33]:
scores, docs = dataset.get_nearest_examples("embeddings", embedded_query)
for score, doc in zip(scores, docs['Term']):
    print(f"Score: {score}, Doc: {doc}")

Score: 249.88510131835938, Doc: Carbohydrate Catabolic Process 
Score: 250.58338928222656, Doc: Negative Regulation Of Metabolic Process 
Score: 265.1728210449219, Doc: Negative Regulation Of Protein Metabolic Process 
Score: 274.4952697753906, Doc: Negative Regulation Of Inflammatory Response 
Score: 274.9875793457031, Doc: Pyruvate Metabolic Process 
Score: 276.41705322265625, Doc: Positive Regulation Of Metabolic Process 
Score: 276.830810546875, Doc: Adiponectin-Activated Signaling Pathway 
Score: 279.9755859375, Doc: Intestinal Lipid Absorption 
Score: 282.984375, Doc: Negative Regulation Of Protein Catabolic Process 
Score: 283.5004577636719, Doc: Intestinal Cholesterol Absorption 


In [30]:
docs

{'Term': ['Adiponectin-Activated Signaling Pathway ',
  'mRNA Catabolic Process ',
  'mRNA Metabolic Process ',
  'Positive Regulation Of Protein Metabolic Process ',
  'Positive Regulation Of Protein Localization To Chromosome, Telomeric Region ',
  'Negative Regulation Of Protein Metabolic Process ',
  'DNA Metabolic Process ',
  'Regulation Of Fat Cell Differentiation ',
  'Regulation Of Adipose Tissue Development ',
  'Integrated Stress Response Signaling '],
 'Desc': ['The series of molecular signals initiated by adiponectin binding to its receptor on the surface of a cell, and ending with the regulation of a downstream cellular process, e.g. transcription.',
  'The chemical reactions and pathways resulting in the breakdown of mRNA, messenger RNA, which is responsible for carrying the coded genetic message, transcribed from DNA, to sites of protein assembly at the ribosomes.',
  'The chemical reactions and pathways involving mRNA, messenger RNA, which is responsible for carrying t

In [27]:
dataset.get_nearest_examples("embeddings", embedded_query)

NearestExamplesResults(scores=array([288.51056, 289.72144, 291.4478 , 295.29388, 298.4854 , 298.66528,
       302.63095, 305.89227, 309.58698, 310.03046], dtype=float32), examples={'Term': ['Adiponectin-Activated Signaling Pathway ', 'mRNA Catabolic Process ', 'mRNA Metabolic Process ', 'Positive Regulation Of Protein Metabolic Process ', 'Positive Regulation Of Protein Localization To Chromosome, Telomeric Region ', 'Negative Regulation Of Protein Metabolic Process ', 'DNA Metabolic Process ', 'Regulation Of Fat Cell Differentiation ', 'Regulation Of Adipose Tissue Development ', 'Integrated Stress Response Signaling '], 'Desc': ['The series of molecular signals initiated by adiponectin binding to its receptor on the surface of a cell, and ending with the regulation of a downstream cellular process, e.g. transcription.', 'The chemical reactions and pathways resulting in the breakdown of mRNA, messenger RNA, which is responsible for carrying the coded genetic message, transcribed from 