## Dense Passage Retrieval: Retrieve Top K matching passages 

In [3]:
%%capture

!pip install cohere-sagemaker


#### Imports 

In [4]:
from cohere_sagemaker import CohereError
from cohere_sagemaker import Client
from requests.auth import HTTPBasicAuth
import requests
import logging 
import boto3
import yaml
import json
import os

##### Setup logging

In [5]:
# Configure the 'sagemaker' logger to emit DEBUG-and-above records through a stream handler.
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
# getLogger returns the same logger object on every call, so re-running this cell
# would otherwise stack another handler and print every message multiple times.
if not any(type(handler) is logging.StreamHandler for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies

In [6]:
# Record the versions of the HTTP and YAML libraries this run is using.
for package_name, module in (('requests', requests), ('pyyaml', yaml)):
    logger.info(f'Using {package_name}=={module.__version__}')

Using requests==2.28.2
Using pyyaml==6.0


#### Setup essentials

In [8]:
# Name of the SageMaker endpoint invoked below to embed the question.
TEXT_EMBEDDING_MODEL_ENDPOINT_NAME = 'huggingface-textembedding-gpt-j-6b-fp16-1685703520'
# Name of the endpoint handed to the cohere_sagemaker Client for answer generation.
# NOTE(review): 'j2-jumbo-instruct' reads like an AI21 Jurassic-2 endpoint name, yet it is
# passed to the Cohere client — confirm the endpoint actually serves a Cohere model.
TEXT_GENERATION_MODEL_ENDPOINT_NAME = 'j2-jumbo-instruct'

# Path to the passage chunks directory; not referenced by any other cell visible in this notebook.
CHUNKS_DIR_PATH = './data/chunks'
# boto3 runtime client used to call invoke_endpoint on the embedding endpoint.
sagemaker_client = boto3.client('runtime.sagemaker')
# Client used to call generate() on the text generation endpoint.
cohere_client = Client(endpoint_name=TEXT_GENERATION_MODEL_ENDPOINT_NAME)

In [9]:
# Load search-domain credentials and location from the local YAML config file.
# The encoding is pinned because open() otherwise falls back to a locale-dependent
# default, which can mis-decode non-ASCII characters (e.g. in the password).
with open('config.yml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# Username/password used for HTTP basic auth against the search domain.
es_username = config['credentials']['username']
es_password = config['credentials']['password']

# Base URL of the domain and the name of the index to search.
domain_endpoint = config['domain']['endpoint']
domain_index = config['domain']['index']

In [10]:
# Build the _search URL for the configured index. Slashes at the join points are
# stripped so a config value such as 'https://host/' or '/index' cannot produce '//'.
URL = f"{domain_endpoint.rstrip('/')}/{domain_index.strip('/')}/_search"
logger.info(f'URL for Elasticsearch index = {URL}')

URL for Elasticsearch index = https://search-natwest-document-z3nshiiifsblsvej5fjq2qzwny.us-east-1.es.amazonaws.com/annual-report/_search


Refer to the Amazon OpenSearch Service k-NN search documentation for more information on the k-NN query used below: https://docs.aws.amazon.com/opensearch-service/latest/developerguide/knn.html

#### Encode question using SageMaker JumpStart's text embedding model endpoint

In [71]:
# Question to answer: it is embedded for retrieval and later inserted into each generation prompt.
prompt = "What is the Full year attributable profit?"

In [73]:
# Serialize the question as a one-element list of inputs and send it to the embedding endpoint.
payload = json.dumps({'text_inputs': [prompt]}).encode('utf-8')
response = sagemaker_client.invoke_endpoint(
    EndpointName=TEXT_EMBEDDING_MODEL_ENDPOINT_NAME,
    ContentType='application/json',
    Body=payload,
)
# Read and decode the JSON response body, then keep the first (only) embedding returned.
body = json.loads(response['Body'].read())
embedding = body['embedding'][0]

#### Find the top k (k=3) passages most similar to the encoded question

In [74]:
# Number of passages to retrieve; used as both the k-NN 'k' and the result 'size' in the query.
K = 3

In [75]:
# k-NN search body: look up the K nearest neighbours of the question embedding in the
# 'embedding' field and cap the number of returned hits at K.
knn_clause = {'vector': embedding, 'k': K}
query = {
    'size': K,
    'query': {'knn': {'embedding': knn_clause}},
}

In [76]:
# Run the k-NN search against the index using HTTP basic auth.
# requests applies no timeout by default, so one is set explicitly to keep the cell
# from hanging indefinitely if the domain is unreachable.
response = requests.post(
    URL,
    auth=HTTPBasicAuth(es_username, es_password),
    json=query,
    timeout=30,
)
# Fail here on HTTP errors (bad credentials, missing index, ...) instead of with a
# less informative KeyError when 'hits' is looked up in an error payload.
response.raise_for_status()
response_json = response.json()
hits = response_json['hits']['hits']

#### Generate answers using SageMaker JumpStart's text generation model by leveraging the previously matched passages 

In [77]:
# For each retrieved passage, ask the generation model to answer the question with that
# passage as context, then log the answer together with the passage it came from.
for hit in hits:
    score = hit['_score']
    source = hit['_source']
    passage = source['passage']
    doc_id = source['doc_id']
    passage_id = source['passage_id']
    qa_prompt = f'Context={passage}\nQuestion={prompt}\nAnswer='

    response = cohere_client.generate(prompt=qa_prompt,
                                      max_tokens=512,
                                      temperature=0.25,
                                      return_likelihoods='GENERATION')

    # Put the completion on a single line for logging.
    # NOTE(review): newlines are removed rather than replaced with a space, so words on
    # either side of a line break are joined together — confirm this is intended.
    answer = response.generations[0].text.strip().replace('\n', '')
    logger.info(f'Answer:\n{answer}')
    logger.info(f'Reference:\nDocument = {doc_id} | Passage = {passage_id} | Score = {score}')

if not hits:
    # Logger.warn is a deprecated alias; Logger.warning is the supported method.
    logger.warning('No matching documents found!')

Answer:

The full year attributable profit for the year ended 31 December 2022 is 2,454 million pounds.
Reference:
Document = 047 | Passage = 1 | Score = 0.62423235
Answer:

The full year attributable profit is 3340 million pounds.
Reference:
Document = 046 | Passage = 1 | Score = 0.6240503
Answer:

The full year attributable profit is £1,475m
Reference:
Document = 045 | Passage = 1 | Score = 0.6207235
