### 1. Set Variables

In [None]:
# Model deployment settings.
# NOTE(review): confirm that the chosen instance type is large enough to host
# the selected model before deploying.
model_id = 'meta-textgeneration-llama-2-70b-f'
instance_type = 'ml.g5.2xlarge'
instance_count = 1

# S3 location settings; both are left as empty placeholders in this cell.
bucket_name = ""
prefix = ""

### 2. Install Dependencies

In [None]:
!pip3 install opensearch-py --quiet
!pip3 install requests_aws4auth --quiet
!pip3 install langchain --quiet

In [None]:
import sagemaker, boto3, json
from sagemaker.session import Session
from sagemaker.jumpstart.model import JumpStartModel
from requests_aws4auth import AWS4Auth
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth, helpers
from langchain.document_loaders import DirectoryLoader, UnstructuredFileLoader, PyPDFLoader, Docx2txtLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import random
import string

In [None]:
# Systems Manager client (Parameter Store lookups further down the notebook).
ssm_client = boto3.client("ssm")

# S3 client.
s3client = boto3.client("s3")

In [None]:
# A single SageMaker session is shared under both names used in this notebook
# (`sagemaker_session` and `sess`) rather than building two equivalent
# default sessions.
sagemaker_session = Session()
sess = sagemaker_session

# ARN of the caller identity as resolved by the SageMaker session.
aws_role = sagemaker_session.get_caller_identity_arn()

# Region name taken from the default boto3 session configuration.
aws_region = boto3.Session().region_name

### Run Transform 

In [None]:
import time
import json
import logging
from typing import List
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler

# Logger named after the current module; the embedding class and the
# document-loading cell below log through it.
logger = logging.getLogger(__name__)

# extend the SagemakerEndpointEmbeddings class from langchain to provide a custom embedding function
class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
    """SageMaker endpoint embeddings that send documents in small batches."""

    def embed_documents(
        self, texts: List[str], chunk_size: int = 5
    ) -> List[List[float]]:
        """Compute doc embeddings using a SageMaker Inference Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: Maximum number of input texts grouped together
                into a single request. Must be at least 1.

        Returns:
            List of embeddings, one for each text. An empty ``texts``
            yields an empty list without calling the endpoint.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        # A zero step would make range() fail and a negative step would
        # silently skip every text, so reject both up front.
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be at least 1, got {chunk_size}")
        if not texts:
            return []

        results: List[List[float]] = []
        started = time.time()
        # Slicing past the end of the list is safe, so the last (or only)
        # batch simply contains whatever texts remain.
        for start in range(0, len(texts), chunk_size):
            results.extend(self._embedding_func(texts[start:start + chunk_size]))
        time_taken = time.time() - started
        logger.info(
            "got results for %s in %ss, length of embeddings list is %s",
            len(texts),
            time_taken,
            len(results),
        )
        return results


# class for serializing/deserializing requests/responses to/from the embeddings model
class ContentHandler(EmbeddingsContentHandler):
    """Serialize embedding requests to JSON and parse the JSON responses."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: List[str], model_kwargs=None) -> bytes:
        """Build the UTF-8 encoded JSON request body.

        Args:
            prompt: The batch of input texts to embed; sent under the
                ``text_inputs`` key.
            model_kwargs: Optional extra key/value pairs merged into the
                top level of the request payload.

        Returns:
            The JSON payload encoded as UTF-8 bytes.
        """
        # None (rather than a shared mutable {}) is the default; treat it as
        # "no extra arguments".
        payload = {"text_inputs": prompt, **(model_kwargs or {})}
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output) -> List[List[float]]:
        """Extract the embeddings from the endpoint response.

        Args:
            output: Response body object whose ``read()`` returns the
                UTF-8 encoded JSON document.

        Returns:
            The value stored under the ``embedding`` key of the response.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["embedding"]
    

def create_sagemaker_embeddings_from_js_model(embeddings_model_endpoint_name: str, aws_region: str) -> SagemakerEndpointEmbeddingsJumpStart:
    """Build the embeddings object for a SageMaker embeddings endpoint.

    Args:
        embeddings_model_endpoint_name: Name of the SageMaker endpoint that
            generates the embeddings.
        aws_region: Region passed to the embeddings object as ``region_name``.

    Returns:
        A SagemakerEndpointEmbeddingsJumpStart wired to a new ContentHandler,
        which serializes requests and deserializes responses.
    """
    return SagemakerEndpointEmbeddingsJumpStart(
        content_handler=ContentHandler(),
        endpoint_name=embeddings_model_endpoint_name,
        region_name=aws_region,
    )

In [None]:

# Settings for loading and splitting the documents. The empty strings are
# placeholders (same convention as the variables in section 1) and must be
# filled in before running this cell.
documents_dir = ''  # local directory that holds the documents to index
embeddings_model_endpoint_name = ''  # SageMaker endpoint that generates the embeddings
chunk_size_for_doc_split = 500  # maximum length of a chunk, measured with len()
chunk_overlap_for_doc_split = 30  # overlap between neighbouring chunks

# DirectoryLoader needs the directory to read from.
loader = DirectoryLoader(documents_dir)

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size_for_doc_split,
    chunk_overlap=chunk_overlap_for_doc_split,
    length_function=len,
)

# Stage one: read all the docs, split them into chunks.
logger.info('Loading documents ...')
docs = loader.load()

# Add custom metadata fields: the load timestamp and the embeddings model used.
for doc in docs:
    doc.metadata['timestamp'] = time.time()
    doc.metadata['embeddings_model'] = embeddings_model_endpoint_name
chunks = text_splitter.create_documents(
    [doc.page_content for doc in docs],
    metadatas=[doc.metadata for doc in docs],
)

# `aws_region` is the region resolved in the session-setup cell above.
embeddings = create_sagemaker_embeddings_from_js_model(embeddings_model_endpoint_name, aws_region)


### Create an OpenSearch Index

Here, we use OpenSearch as a Vector Store. The first step is to create an Index.

In [None]:
# Read the OpenSearch credentials and host from SSM Parameter Store.
# WithDecryption=True makes SecureString parameters come back as plaintext;
# the flag is ignored for plain String parameters.
access_key = ssm_client.get_parameter(Name='AccessKey', WithDecryption=True)['Parameter']['Value']
secret_key = ssm_client.get_parameter(Name='SecretAccessKey', WithDecryption=True)['Parameter']['Value']
host = ssm_client.get_parameter(Name='OpenSearchHost')['Parameter']['Value']

In [None]:
# Service name used when signing the requests.
service = 'aoss'

# Index name with a random 8-letter suffix, so each run gets its own index.
INDEX_NAME = 'sm_docs_' + ''.join(random.choices(string.ascii_lowercase, k=8))
VECTOR_FIELD = 'vectors'
# NOTE(review): this must equal the length of the vectors produced by the
# embeddings endpoint - confirm 1536 matches the deployed model.
VECTOR_DIMENSION = 1536

# Request signing with the access key pair read from Parameter Store.
# NOTE(review): temporary credentials would also need a session token here.
awsauth = AWS4Auth(access_key, secret_key, aws_region, service)

# Create the OpenSearch client
aoss_client = OpenSearch(
    hosts=[{'host': host, 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    ssl_assert_hostname=False,
    ssl_show_warn=False,
    connection_class=RequestsHttpConnection,
    timeout=300,
)

## Delete the index if exists
#response = aoss_client.indices.delete(
#    index = INDEX_NAME
#)

# Create the index with k-NN enabled and a knn_vector mapping for the field
# that the indexing loop writes the embeddings to.
aoss_client.indices.create(
    INDEX_NAME,
    body={
        "settings": {
            "index.knn": True,
        },
        "mappings": {
            "properties": {
                VECTOR_FIELD: {
                    "type": "knn_vector",
                    "dimension": VECTOR_DIMENSION,  # dimension of the embedding vector
                },
            }
        },
    },
)

In [None]:
# Embed every chunk produced by the text splitter and store it in the index,
# one OpenSearch document per chunk (vector plus the chunk's text).
for chunk in chunks:
    # The text data of this chunk
    chunk_text = chunk.page_content

    # NOTE(review): the text is wrapped in a JSON envelope before it is
    # embedded, so the stored vector represents the JSON string rather than
    # the bare text - confirm that queries are embedded the same way.
    embedding_input = json.dumps({"inputText": chunk_text})
    chunk_vector = embeddings.embed_query(embedding_input)

    indexDocument = {VECTOR_FIELD: chunk_vector, 'text': chunk_text}

    response = aoss_client.index(
        index=INDEX_NAME,
        body=indexDocument,
        refresh=False,
    )