# Retrieval Augmented Generation for Medical Question-Answering with Llama-2-7b

In this notebook, we use the Llama-2-7b model for medical question answering. Using Retrieval Augmented Generation (RAG), we combine the knowledge Llama-2-7b acquired during pretraining with a curated medical knowledge base. We retrieve passages with dense embeddings from MiniLM and sparse embeddings from SPLADE, so that answers are grounded in retrieved text and less prone to hallucination. SageMaker hosts the models and Pinecone stores the vectors, giving us one pipeline for embedding, retrieval, and text generation.

In [None]:
!pip install -qU \
    sagemaker==2.173.0 \
    pinecone-client==2.2.1 \
    ipywidgets==7.0.0

## Deploy Llama-2-7b in SageMaker JumpStart

Meta developed and publicly released the Llama 2 family of large language models (LLMs), a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters. The fine-tuned models, called Llama-2-Chat, are optimized for dialogue use cases. According to Meta, Llama-2-Chat models outperform open-source chat models on most benchmarks Meta tested and, in Meta's human evaluations for helpfulness and safety, are on par with some popular closed-source models such as ChatGPT and PaLM. Meta also published a detailed description of its approach to fine-tuning and safety improvements for Llama-2-Chat so that the community can build on the work and contribute to the responsible development of LLMs.

**Model Developers:** Meta AI

**Variations:** Llama 2 comes in a range of parameter sizes (7B, 13B, and 70B) as well as pretrained and fine-tuned variations.

**Input:** Models input text only.

**Output:** Models generate text only.

**Model Architecture:** Llama 2 is an auto-regressive language model that uses an optimized transformer architecture. The tuned versions use supervised fine-tuning (SFT) and reinforcement learning with human feedback (RLHF) to align to human preferences for helpfulness and safety.

***
This model supports the following inference payload parameters:

* **max_new_tokens:** Model generates text until the output length (excluding the input context length) reaches max_new_tokens. If specified, it must be a positive integer.
* **temperature:** Controls the randomness of the output. A higher temperature makes low-probability words more likely to be sampled; a lower temperature favors high-probability words. As `temperature` -> 0, sampling approaches greedy decoding. If specified, it must be a positive float.
* **top_p:** In each step of text generation, sample from the smallest possible set of words with cumulative probability `top_p`. If specified, it must be a float between 0 and 1.
* **return_full_text:** If True, input text will be part of the output generated text. If specified, it must be boolean. The default value for it is False.

You may specify any subset of the parameters mentioned above while invoking an endpoint.
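
To make the `temperature` and `top_p` descriptions above concrete, here is a minimal pure-Python sketch of temperature scaling followed by nucleus (top-p) filtering. It is an illustration only, not the endpoint's implementation; the function names `apply_temperature` and `nucleus` and the toy logits are our own assumptions.

```python
import math

def apply_temperature(logits, temperature):
    """Turn raw scores into probabilities; a lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def nucleus(probs, top_p):
    """Return indices of the smallest token set whose cumulative probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept

logits = [2.0, 1.0, 0.5, -1.0]  # toy scores for a hypothetical 4-token vocabulary
warm = apply_temperature(logits, 1.0)
cold = apply_temperature(logits, 0.3)
print(nucleus(warm, 0.9), nucleus(cold, 0.9))
```

With the lower temperature, the top token takes most of the probability mass, so fewer candidates survive the `top_p` cut.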

**NOTE**: If `max_new_tokens` is not defined, the model may generate up to the maximum total tokens allowed, which is 4K for these models. This may result in endpoint query timeout errors, so we recommend setting `max_new_tokens` whenever possible. For the 7B, 13B, and 70B models, we recommend setting `max_new_tokens` no greater than 1500, 1000, and 500 respectively, while keeping the total number of tokens below 4K.

**NOTE**: In order to support a 4k context length, this model has restricted query payloads to only utilize a batch size of 1. Payloads with larger batch sizes will receive an endpoint error prior to inference.
***
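
The parameter rules and notes above can be checked on the client before a request is sent. The sketch below is one possible validator, assuming the 7B model; the helper name `build_payload`, the default of 64 new tokens, and the choice to reject (rather than clamp) out-of-range values are our assumptions. The source does not say whether the `top_p` bounds are inclusive, so this sketch accepts both ends.

```python
MAX_NEW_TOKENS_7B = 1500  # recommended ceiling for the 7B model, taken from the note above

def build_payload(prompt, max_new_tokens=64, temperature=None, top_p=None,
                  return_full_text=False):
    """Validate the documented parameters and build a single-prompt payload."""
    if not isinstance(prompt, str):
        raise TypeError("send one prompt string per request (batch size 1)")
    if isinstance(max_new_tokens, bool) or not isinstance(max_new_tokens, int):
        raise ValueError("max_new_tokens must be an integer")
    if not 0 < max_new_tokens <= MAX_NEW_TOKENS_7B:
        raise ValueError(f"max_new_tokens must be in 1..{MAX_NEW_TOKENS_7B}")
    if not isinstance(return_full_text, bool):
        raise ValueError("return_full_text must be boolean")
    parameters = {"max_new_tokens": max_new_tokens,
                  "return_full_text": return_full_text}
    # Optional parameters are only sent when the caller sets them.
    if temperature is not None:
        if temperature <= 0:
            raise ValueError("temperature must be a positive float")
        parameters["temperature"] = float(temperature)
    if top_p is not None:
        if not 0 <= top_p <= 1:
            raise ValueError("top_p must be between 0 and 1")
        parameters["top_p"] = float(top_p)
    return {"inputs": prompt, "parameters": parameters}

print(build_payload("What is PCD?", max_new_tokens=10, top_p=0.9, temperature=0.3))
```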

## Ask the LLM a question with and without context

In [None]:
import json
import boto3

In [None]:
def create_payload(question):
    # Build a single-prompt payload; the endpoint only accepts a batch size of 1.
    return {
        "inputs": question,
        "parameters": {"max_new_tokens": 10, "top_p": 0.9, "temperature": 0.3, "return_full_text": False},
    }

endpoint_name = 'jumpstart-dft-meta-textgeneration-llama-2-7b'

def query_llama2_7b_endpoint(payload):
    client = boto3.client("sagemaker-runtime")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
        CustomAttributes="accept_eula=true",
    )
    response = response["Body"].read().decode("utf8")
    response = json.loads(response)
    return response[0]['generation']

In [None]:
question = "Which lace plant produces perforations in its leaves through PCD?"
payload = create_payload(question)
query_response = query_llama2_7b_endpoint(payload)
print(payload["inputs"])
print(query_response)

Which lace plant produces perforations in its leaves through PCD?

A. Mimosa pudica



In [None]:
context = """Programmed cell death (PCD) is the regulated death of cells within an organism.
The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD.
The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing
areoles. PCD occurs in the cells at the center of these areoles and progresses outwards,
stopping approximately five cells from the vasculature.  The possible importance of mitochondrial
permeability transition pore (PTP) formation during PCD was indirectly examined via in vivo
cyclosporine A (CsA) treatment. This treatment resulted in lace plant leaves with a significantly
lower number of perforations compared to controls, and that displayed mitochondrial dynamics
similar to that of non-PCD cells."""

In [None]:
prompt_template = """Answer the following QUESTION based on the CONTEXT given.

CONTEXT:
{context}

QUESTION:
{question}

ANSWER:
"""

text_input = prompt_template.replace("{context}", context).replace("{question}", question)
payload = create_payload(text_input)
query_response = query_llama2_7b_endpoint(payload)
query_response

'Aponogeton madagascariensis\n'

## Dataset Preparation

In [None]:
!pip install -qU datasets

In [None]:
from datasets import load_dataset

pubmed = load_dataset(
    'pubmed_qa',
    'pqa_labeled',
    split='train'
)
pubmed

Downloading and preparing dataset pubmed_qa/pqa_labeled to /root/.cache/huggingface/datasets/pubmed_qa/pqa_labeled/1.0.0/dd4c39f031a958c7e782595fa4dd1b1330484e8bbadd4d9212e5046f27e68924...


Dataset pubmed_qa downloaded and prepared to /root/.cache/huggingface/datasets/pubmed_qa/pqa_labeled/1.0.0/dd4c39f031a958c7e782595fa4dd1b1330484e8bbadd4d9212e5046f27e68924. Subsequent calls will reuse this data.


Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 1000
})

In [None]:
pubmed[0]['pubid'], pubmed[0]['context']

(21645374,
 {'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.',
   'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in late stages of PCD (LPCD). Window stage leaves were stained with the mitochondr

In [None]:
limit = 384

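def chunker(contexts: list):
    # Split the joined contexts into sentences on '.', then group them into
    # overlapping chunks of at least 3 sentences and more than `limit` characters.
    # Splitting on '.' also cuts abbreviations such as "A. madagascariensis".
    chunks = []
    all_contexts = ' '.join(contexts).split('.')
    chunk = []
    for context in all_contexts:
        chunk.append(context)
        if len(chunk) >= 3 and len('.'.join(chunk)) > limit:
            chunks.append('.'.join(chunk).strip()+'.')
            # Carry the last two sentences over so consecutive chunks overlap.
            chunk = chunk[-2:]
    # Add whatever remains as the final chunk.
    if chunk:
        chunks.append('.'.join(chunk))
    return chunks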

chunks = chunker(pubmed[0]['context']['contexts'])

ids = []
for i in range(len(chunks)):
    ids.append(f"{pubmed[0]['pubid']}-{i}")

data = []
for record in pubmed:
    chunks = chunker(record['context']['contexts'])
    for i, context in enumerate(chunks):
        data.append({
            'id': f"{record['pubid']}-{i}",
            'context': context
        })

data[:2]

[{'id': '21645374-0',
  'context': 'Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature.'},
 {'id': '21645374-1',
  'context': 'The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants. The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A.'}]
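
The two records above share two sentences because each chunk carries the tail of the previous one forward. The sketch below shows the same sliding-window idea on toy sentences. It is a simplified variant rather than the notebook's `chunker`: it takes a list of sentences instead of splitting on `.`, and it skips a trailing chunk that would contain only carried-over sentences. The name `sentence_windows` and its parameters are our own.

```python
def sentence_windows(sentences, limit=40, min_sentences=3, overlap=2):
    """Group sentences into chunks of at least `min_sentences` sentences whose joined
    length exceeds `limit`, carrying the last `overlap` sentences into the next chunk."""
    chunks, window, fresh = [], [], 0
    for sentence in sentences:
        window.append(sentence)
        fresh += 1  # sentences not yet emitted in any chunk
        if len(window) >= min_sentences and len(' '.join(window)) > limit:
            chunks.append(' '.join(window))
            window = window[-overlap:] if overlap else []
            fresh = 0
    if fresh:  # emit the tail only if it holds sentences not yet in a chunk
        chunks.append(' '.join(window))
    return chunks

toy = ["Alpha one.", "Beta two.", "Gamma three.", "Delta four.", "Epsilon five."]
for chunk in sentence_windows(toy, limit=20):
    print(chunk)
```

The overlap trades some index size for a better chance that an answer spanning a chunk boundary is retrievable from a single chunk.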

## Embeddings

### Dense Vectors

In [None]:
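import sagemaker
from sagemaker.huggingface import HuggingFaceModel

# IAM role that SageMaker uses to create and run the endpoint
role = sagemaker.get_execution_role()

hub_config = {
    'HF_MODEL_ID': 'sentence-transformers/all-MiniLM-L6-v2',
    'HF_TASK': 'feature-extraction'
}

huggingface_model = HuggingFaceModel(
    env=hub_config,
    role=role,
    transformers_version="4.6",
    pytorch_version="1.7",
    py_version="py36",
)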

In [None]:
encoder = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.t2.large",
    endpoint_name="minilm-demo"
)

-----!

In [None]:
out = encoder.predict({"inputs": ["some text here", "some more text goes here too"]})

In [None]:
len(out[0][0])

384
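
The check above suggests that the response is nested as `[input][token][dimension]`, with 384 values per token. To get one vector per text, the token vectors have to be pooled; later cells average over the token axis. The sketch below shows mean pooling in plain Python on toy data, assuming that layout. Pooling each sequence separately also works when inputs tokenize to different lengths. The name `mean_pool` and the toy values are ours.

```python
def mean_pool(token_vectors):
    """Average a [tokens][dim] list of token vectors into a single [dim] vector."""
    count = len(token_vectors)
    return [sum(column) / count for column in zip(*token_vectors)]

# Two toy "sentences" with 2 and 3 tokens, using 2 dimensions instead of 384.
toy_out = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 0.0], [2.0, 2.0], [4.0, 4.0]],
]
pooled = [mean_pool(sequence) for sequence in toy_out]
print(pooled)
```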

### Sparse Vectors

In [None]:
!pip install -qU git+https://github.com/naver/splade.git
!pip install torch

In [None]:
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device != 'cuda':
    print("==========\n"+
          "WARNING: You are not running on GPU so this may be slow.\n")




In [None]:
from splade.models.transformer_rep import Splade

sparse_model_id = 'naver/splade-cocondenser-ensembledistil'

sparse_model = Splade(sparse_model_id, agg='max')
sparse_model.to(device)
sparse_model.eval()


In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)
tokens = tokenizer(data[0]['context'], return_tensors='pt')

with torch.no_grad():
    sparse_emb = sparse_model(
        d_kwargs=tokens.to(device)
    )['d_rep'].squeeze()
sparse_emb.shape



torch.Size([30522])

In [None]:
indices = sparse_emb.nonzero().squeeze().cpu().tolist()
print(len(indices))

174


In [None]:
values = sparse_emb[indices].cpu().tolist()
sparse = {'indices': indices, 'values': values}
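
The conversion above keeps only the non-zero positions of the 30522-dimensional vector. The sketch below shows the same idea in plain Python on a toy vector. One difference is worth noting: with exactly one non-zero entry, `nonzero().squeeze().tolist()` in PyTorch returns a bare integer rather than a list, whereas this list-based version always returns lists. The helper name `to_sparse` is our own.

```python
def to_sparse(vector):
    """Keep only non-zero entries as parallel `indices` and `values` lists."""
    indices = [i for i, value in enumerate(vector) if value != 0]
    values = [vector[i] for i in indices]
    return {"indices": indices, "values": values}

print(to_sparse([0.0, 1.5, 0.0, 0.2]))
print(to_sparse([0.0, 0.0, 3.0]))  # a single non-zero entry still yields lists
```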

## Indexing

In [None]:
import numpy as np
import pinecone

def builder(records: list):
    ids = [x['id'] for x in records]
    contexts = [x['context'] for x in records]
    dense_contexts = {"inputs": contexts}
    dense_embeddings = encoder.predict(dense_contexts)
    dense_vecs = np.mean(np.array(dense_embeddings), axis=1)
    dense_vecs = dense_vecs.tolist()
    input_ids = tokenizer(
        contexts, return_tensors='pt',
        padding=True, truncation=True
    )
    with torch.no_grad():
        sparse_vecs = sparse_model(
            d_kwargs=input_ids.to(device)
        )['d_rep'].squeeze()
    upserts = []
    for _id, dense_vec, sparse_vec, context in zip(ids, dense_vecs, sparse_vecs, contexts):
        indices = sparse_vec.nonzero().squeeze().cpu().tolist()
        values = sparse_vec[indices].cpu().tolist()
        sparse_values = {
            "indices": indices,
            "values": values
        }
        metadata = {'context': context}
        upserts.append({
            'id': _id,
            'values': dense_vec,
            'sparse_values': sparse_values,
            'metadata': metadata
        })
    return upserts

In [None]:
from creds import creds

pinecone.init(
    api_key=creds['PINECONE_API_KEY'],
    environment=creds['PINECONE_ENV']
)

In [None]:
index_name = 'pubmed-splade'

pinecone.create_index(
    index_name,
    dimension=384,
    metric="dotproduct"
)
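
The index uses the `dotproduct` metric because each record stores both a dense and a sparse vector. As a rough mental model, and assuming the combined score is the dense dot product plus the sparse dot product (please check the Pinecone documentation for the exact behavior), scoring can be sketched as follows. All function names and toy vectors are ours.

```python
def dense_dot(a, b):
    """Dot product of two equal-length dense vectors."""
    return sum(x * y for x, y in zip(a, b))

def sparse_dot(a, b):
    """Dot product of two sparse vectors stored as {'indices': [...], 'values': [...]}."""
    lookup = dict(zip(b["indices"], b["values"]))
    return sum(value * lookup.get(index, 0.0)
               for index, value in zip(a["indices"], a["values"]))

def hybrid_score(query_dense, query_sparse, doc_dense, doc_sparse):
    """Assumed combined score: dense dot product plus sparse dot product."""
    return dense_dot(query_dense, doc_dense) + sparse_dot(query_sparse, doc_sparse)

score = hybrid_score(
    [1.0, 2.0], {"indices": [3, 7], "values": [2.0, 1.0]},
    [0.5, 0.5], {"indices": [7, 9], "values": [4.0, 1.0]},
)
print(score)
```

Only index 7 is shared by the two sparse vectors, so it is the only sparse term that contributes.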

In [None]:
pinecone.list_indexes()

['pubmed-splade']

In [None]:
index = pinecone.Index(index_name)

In [None]:
from tqdm.auto import tqdm

batch_size = 2
# Note: the hard-coded 1000 means only the first 1000 chunks in `data` are indexed.
for i in tqdm(range(0, 1000, batch_size)):
    index.upsert(builder(data[i:i+batch_size]))

In [None]:
index.describe_index_stats()

{'dimension': 384,
 'index_fullness': 0.0,
 'namespaces': {'': {'vector_count': 1000}},
 'total_vector_count': 1000}
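
The count of 1000 matches the hard-coded upper bound of the upsert loop, not the size of `data`: the dataset has 1000 records and the first record alone yields more than one chunk, so `data` holds more than 1000 chunks. A minimal sketch of a batching helper that covers every item is shown below; `batched` is a hypothetical name, and the list of integers stands in for `data`.

```python
def batched(items, batch_size):
    """Yield consecutive slices of `items` with at most `batch_size` elements each."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

toy_data = list(range(5))  # stand-in for the list of chunk records
for batch in batched(toy_data, 2):
    print(batch)
```

If the number of items is odd, the last batch holds a single item. The `.squeeze()` call on the batched sparse output in `builder` would then drop the batch dimension, so that case may need separate handling.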

## Combine the retrieved documents, prompt, and question to query the LLM

In [None]:
def encode(text: str):
    dense_embeddings = encoder.predict({"inputs": [text]})
    dense_vec = np.mean(np.array(dense_embeddings), axis=1)
    dense_vec = dense_vec.tolist()
    input_ids = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        sparse_vec = sparse_model(
            d_kwargs=input_ids.to(device)
        )['d_rep'].squeeze()
    indices = sparse_vec.nonzero().squeeze().cpu().tolist()
    values = sparse_vec[indices].cpu().tolist()
    sparse_dict = {"indices": indices, "values": values}
    return dense_vec, sparse_dict
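
`encode` returns the dense and sparse parts with equal weight. If one side should count more at query time, a common approach is to scale both by a convex weight before querying. The sketch below assumes a flat list of floats for the dense vector and the `indices`/`values` dictionary used above; the helper name `hybrid_scale` and the parameter `alpha` are our own and are not part of this notebook's pipeline.

```python
def hybrid_scale(dense, sparse, alpha):
    """Weight a dense/sparse pair: alpha=1.0 is dense only, alpha=0.0 is sparse only."""
    if not 0 <= alpha <= 1:
        raise ValueError("alpha must be between 0 and 1")
    scaled_dense = [value * alpha for value in dense]
    scaled_sparse = {
        "indices": list(sparse["indices"]),
        "values": [value * (1 - alpha) for value in sparse["values"]],
    }
    return scaled_dense, scaled_sparse

toy_dense = [2.0, 4.0]
toy_sparse = {"indices": [10], "values": [1.0]}
print(hybrid_scale(toy_dense, toy_sparse, alpha=0.5))
```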

In [None]:
question = "Which lace plant produces perforations in its leaves through PCD?"
dense, sparse = encode(question)
xc = index.query(
    vector=dense,
    sparse_vector=sparse,
    top_k=2,
    include_metadata=True
)
xc

{'matches': [{'id': '21645374-0',
              'metadata': {'context': 'Programmed cell death (PCD) is the '
                                      'regulated death of cells within an '
                                      'organism. The lace plant (Aponogeton '
                                      'madagascariensis) produces perforations '
                                      'in its leaves through PCD. The leaves '
                                      'of the plant consist of a latticework '
                                      'of longitudinal and transverse veins '
                                      'enclosing areoles. PCD occurs in the '
                                      'cells at the center of these areoles '
                                      'and progresses outwards, stopping '
                                      'approximately five cells from the '
                                      'vasculature.'},
              'score': 40.5750847,
              'values':

In [None]:
context_str = xc['matches'][0]['metadata']['context'] + ' ' + xc['matches'][1]['metadata']['context']
context_str

'Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature.  The possible importance of mitochondrial permeability transition pore (PTP) formation during PCD was indirectly examined via in vivo cyclosporine A (CsA) treatment. This treatment resulted in lace plant leaves with a significantly lower number of perforations compared to controls, and that displayed mitochondrial dynamics similar to that of non-PCD cells.'

In [None]:
text_input = prompt_template.replace("{context}", context_str).replace("{question}", question)

payload = create_payload(text_input)
generated_text = query_llama2_7b_endpoint(payload)
print(f"[Input]: {question}\n[Output]: {generated_text}")

[Input]: Which lace plant produces perforations in its leaves through PCD?
[Output]: Aponogeton madagascariensis



## End to End RAG

In [None]:
def rag_query(question: str) -> str:
    dense, sparse = encode(question)
    xc = index.query(
        vector=dense,
        sparse_vector=sparse,
        top_k=2,
        include_metadata=True
    )
    context_str = xc['matches'][0]['metadata']['context'] + ' ' + xc['matches'][1]['metadata']['context']
    text_input = prompt_template.replace("{context}", context_str).replace("{question}", question)
    payload = create_payload(text_input)
    generated_text = query_llama2_7b_endpoint(payload)
    return generated_text
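
`rag_query` indexes `matches[0]` and `matches[1]` directly, so it would fail if the index returned fewer than two matches. The sketch below builds the prompt from however many matches come back and skips exact duplicates. It uses plain dictionaries as stand-ins for the query response; `PROMPT_TEMPLATE`, `build_prompt`, and the toy matches are hypothetical names and data.

```python
PROMPT_TEMPLATE = """Answer the following QUESTION based on the CONTEXT given.

CONTEXT:
{context}

QUESTION:
{question}

ANSWER:
"""

def build_prompt(question, matches):
    """Join the context of every returned match, skipping exact duplicates."""
    contexts = []
    for match in matches:
        text = match["metadata"]["context"]
        if text not in contexts:
            contexts.append(text)
    return PROMPT_TEMPLATE.format(context=" ".join(contexts), question=question)

toy_matches = [
    {"id": "a-0", "metadata": {"context": "First passage."}},
    {"id": "a-1", "metadata": {"context": "Second passage."}},
]
print(build_prompt("What is said?", toy_matches))
```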

In [None]:
rag_query("Which lace plant produces perforations in its leaves through PCD?")

'A. Aponogeton madagascari'

In [None]:
rag_query("Is trabecular bone related to primary stability of miniscrews?")

'Yes\n\n### 1.3.'
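
The last output shows the base model running on past the answer (`### 1.3.`), and the earlier one is cut off by `max_new_tokens`. A light post-processing step can trim run-on text. The sketch below cuts at the first blank line or `###` marker; the helper name `clean_answer` and the choice of stop markers are our assumptions and may need tuning for other prompts.

```python
def clean_answer(generated, stop_markers=("\n\n", "###")):
    """Cut the generation at the first stop marker and strip surrounding whitespace."""
    text = generated
    for marker in stop_markers:
        cut = text.find(marker)
        if cut != -1:
            text = text[:cut]
    return text.strip()

print(clean_answer('Yes\n\n### 1.3.'))
print(clean_answer('Aponogeton madagascariensis\n'))
```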