## Encode legal passages and create embeddings index

##### Prerequisites

In [3]:
%%capture 

!pip install PyYAML

#### Imports

In [4]:
from requests.auth import HTTPBasicAuth
from tqdm import tqdm
import requests
import logging 
import boto3
import yaml
import json
import os

##### Setup logging

In [5]:
# Notebook-wide logger; DEBUG so that every message emitted below is shown
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)

# Attach the console handler only once: re-running this cell would otherwise
# stack another StreamHandler and print every log message multiple times.
if not any(type(handler) is logging.StreamHandler for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies 

In [6]:
# Build the version banners first, then emit them in the same order
version_messages = (
    f'Using requests=={requests.__version__}',
    f'Using pyyaml=={yaml.__version__}',
)
for version_message in version_messages:
    logger.info(version_message)

Using requests==2.28.2
Using pyyaml==6.0


#### Setup essentials

In [7]:
# Directory that is walked for the passage chunk files
CHUNKS_DIR_PATH = './data/chunks'
# Name of the SageMaker endpoint invoked to embed each chunk
TEXT_EMBEDDING_MODEL_ENDPOINT_NAME = 'huggingface-textembedding-gpt-j-6b-fp16-1685703520'

# SageMaker runtime client used for the invoke_endpoint calls
sagemaker_client = boto3.client('runtime.sagemaker')

In [8]:
# Parse the YAML configuration that holds the credentials and domain settings
with open('config.yml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Basic-auth credentials for the search domain
credentials_cfg = config['credentials']
es_username, es_password = credentials_cfg['username'], credentials_cfg['password']

# Endpoint of the search domain and name of the index to work with
domain_cfg = config['domain']
domain_endpoint, domain_index = domain_cfg['endpoint'], domain_cfg['index']

In [9]:
# Index URL: the configured domain endpoint and the index name, joined by '/'
URL = f'{domain_endpoint}/{domain_index}'
# Log the resolved URL so the target index is visible in the notebook output
logger.info(f'URL for Elasticsearch index = {URL}')

URL for Elasticsearch index = https://search-natwest-document-z3nshiiifsblsvej5fjq2qzwny.us-east-1.es.amazonaws.com/annual-report


#### Define the index mapping with a k-NN vector field

In [10]:
# Index-level settings: switch on k-NN search for this index
index_settings = {
    'index': {
        'knn': True
    }
}

# Field definitions for every document stored in the index
field_properties = {
    # k-NN vector field holding the passage embedding (4096 dimensions)
    'embedding': {'type': 'knn_vector', 'dimension': 4096},
    'passage_id': {'type': 'long'},
    'passage': {'type': 'text'},
    'doc_id': {'type': 'keyword'},
}

# Request body used when the index is created
mapping = {
    'settings': index_settings,
    'mappings': {'properties': field_properties},
}

#### Create the index with the specified mapping

In [11]:
# One auth object shared by both requests in this cell
es_auth = HTTPBasicAuth(es_username, es_password)

# Check if the index exists using an HTTP HEAD request
response = requests.head(URL, auth=es_auth)

if response.status_code == 404:
    # The index does not exist yet: create it with the k-NN mapping
    response = requests.put(URL, auth=es_auth, json=mapping)
    if response.ok:
        logger.info(f'Index created: {response.text}')
    else:
        # Do not report success when the create request was rejected. Stop here,
        # because ingesting without the mapping would index embeddings without
        # the knn_vector field type.
        logger.error(f'Index creation failed ({response.status_code}): {response.text}')
        response.raise_for_status()
elif response.status_code == 200:
    logger.error('Index already exists!')
else:
    # Any other status (e.g. 401/403/5xx) says nothing about whether the index
    # exists; report it instead of claiming the index is already there.
    logger.error(f'Could not check index ({response.status_code} {response.reason})')
    response.raise_for_status()

Index created: {"acknowledged":true,"shards_acknowledged":true,"index":"annual-report"}


#### Encode passages (chunks) using JumpStart's GPT-J text embedding model and ingest them into OpenSearch

In [12]:
def chunk_iterator(dir_path: str):
    """Yield ``(filename, text)`` for every regular file found under ``dir_path``.

    The directory tree is walked recursively. Sub-directories and file names
    are visited in sorted order so that repeated runs produce the chunks in
    the same sequence; the ingest loop numbers documents by position, so a
    stable order keeps those numbers reproducible.

    Args:
        dir_path: Root directory to search for chunk files.

    Yields:
        Tuples of the bare file name (without its directory part) and the
        file's contents decoded as UTF-8.
    """
    for root, dirnames, filenames in os.walk(dir_path):
        # os.walk gives no ordering guarantee; sorting dirnames in place fixes
        # the order in which sub-directories are descended into.
        dirnames.sort()
        for filename in sorted(filenames):
            file_path = os.path.join(root, filename)
            if not os.path.isfile(file_path):
                continue
            # Explicit encoding: the platform default is not always UTF-8
            with open(file_path, 'r', encoding='utf-8') as file:
                file_contents = file.read()
            # Yield only after the file is closed so no handle stays open
            # while the consumer processes the chunk.
            yield filename, file_contents

In [13]:
%%time

i = 1
for chunk_name, chunk in tqdm(chunk_iterator(CHUNKS_DIR_PATH)):
    doc_id, chunk_id = chunk_name.split('_')
    payload = {'text_inputs': [chunk]}
    payload = json.dumps(payload).encode('utf-8')
    
    response = sagemaker_client.invoke_endpoint(EndpointName=TEXT_EMBEDDING_MODEL_ENDPOINT_NAME, 
                                                ContentType='application/json',  
                                                Body=payload)
    
    model_predictions = json.loads(response['Body'].read())
    embedding = model_predictions['embedding'][0]
   
    document = { 
        'doc_id': doc_id, 
        'passage_id': chunk_id,
        'passage': chunk, 
        'embedding': embedding}
    
    response = requests.post(f'{URL}/_doc/{i}', auth=HTTPBasicAuth(es_username, es_password), json=document)
    i += 1
    
    if response.status_code not in [200, 201]:
        logger.error(response.status_code)
        logger.error(response.text)
        break

44it [00:14,  2.94it/s]

CPU times: user 1.1 s, sys: 71.9 ms, total: 1.17 s
Wall time: 15 s



