## SBERT endpoint

In [65]:
!pip install transformers



In [92]:
import logging
import json
import boto3
import io
import os
import time
import logging
import sagemaker
from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import IdentitySerializer

from scipy.spatial import distance
# import torch

In [67]:
# Resolve the IAM role ARN that SageMaker calls in this notebook run under.
try:
    # On SageMaker-managed notebooks the SDK can infer the role by itself.
    role = sagemaker.get_execution_role()
except ValueError:
    # The SDK could not infer it (ValueError): fall back to looking up a named IAM role.
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# NOTE(review): the printed ARN contains the AWS account id — clear this output before sharing.
print("sagemaker role arn:", role)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
sagemaker role arn: arn:aws:iam::571667364805:role/service-role/AmazonSageMaker-ExecutionRole-20231103T080028


## Parameters

In [87]:
# SageMaker endpoint queried by lambda_handler below.
# Endpoint names noted by the author: sm-llm-aws, sm-gec-aws, sm-cc-aws.
# Earlier auto-named deployment: 'huggingface-pytorch-inference-2023-11-11-18-18-08-839'
ENDPOINT_NAME = "sm-cc-aws"

In [88]:
# Make INFO records (e.g. lambda_handler's profiling line) actually reach the output.
# setLevel alone is not enough: if the root logger has no handler, records fall through
# to logging's last-resort handler, which ignores anything below WARNING.
# basicConfig attaches a stderr handler only when the root logger has none, so it is a
# no-op in environments that already configured logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.INFO)  # still applied when a handler was configured elsewhere
# NOTE(review): sagemaker_session is not used by any cell shown here.
sagemaker_session = sagemaker.Session()

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


In [89]:
"""
FunctionName: invoke_endpoint
Input: transcript_item (sentence), label_map
    transcript_item type: string
    label_map type: dict
Output: Question
    type: string
"""
# @tracer.capture_method
def invoke_endpoint(payload, endpoint_name):
    runtime = boto3.client('runtime.sagemaker')
    response = runtime.invoke_endpoint(EndpointName=endpoint_name,
                                      ContentType="application/json",
                                      Body=json.dumps(payload).encode())
                                       # Body = json.loads(json.dumps(payload)))
                                     
    embeddings = json.loads((response["Body"].read()))
    return embeddings

# @tracer.capture_lambda_handler
def lambda_handler(event, context=None):
    """Fetch embeddings for ``event`` from ``ENDPOINT_NAME`` and log how long it took.

    Args:
        event: Request payload forwarded unchanged to ``invoke_endpoint``,
            e.g. ``{"inputs": "Do you want a cup of coffee?"}``.
        context: Unused. Accepted so the function also fits the AWS Lambda
            ``handler(event, context)`` calling convention.

    Returns:
        The decoded endpoint response (the embeddings) from ``invoke_endpoint``.
    """
    # perf_counter is monotonic, so the measurement is immune to wall-clock adjustments.
    start = time.perf_counter()
    embeddings = invoke_endpoint(event, ENDPOINT_NAME)
    elapsed_ms = 1000 * (time.perf_counter() - start)
    # Lazy %-formatting: the message is only rendered if INFO is enabled.
    logger.info("Profiling: \n Getting Embeddings: %s milliseconds", elapsed_ms)
    return embeddings

In [78]:
# json_event =  {
#    "inputs": "Do you want a cup of coffee?"
# }

# embedding_user_answer = lambda_handler(json_event)

In [79]:
# json_event =  {
#    "inputs": "No. can i have a cup of tea instead."
# }

# embedding_question = lambda_handler(json_event)

In [80]:
# cos = torch.nn.CosineSimilarity(dim=0, eps=1e-08)
# similarity_score = cos(torch.Tensor(embedding_question[0][0]), torch.Tensor(embedding_user_answer[0][0]))

In [81]:
# similarity_score

In [82]:
# json_event =  {
#    "inputs": "Do you want a cup of coffee?"
# }

# embedding_user_answer = lambda_handler(json_event)

In [83]:
# json_event =  {
#    "inputs": "the weather is bad"
# }

# embedding_question = lambda_handler(json_event)

In [84]:
# cos = torch.nn.CosineSimilarity(dim=0, eps=1e-08)
# similarity_score = cos(torch.Tensor(embedding_question[0][0]), torch.Tensor(embedding_user_answer[0][0]))
# similarity_score

In [97]:
def is_on_topic(user_answer, question, threshold=0.75, embed=None):
    """Decide whether ``user_answer`` is semantically on topic for ``question``.

    Both arguments are endpoint request payloads (e.g. ``{"inputs": "Hola."}``), not
    raw strings. Each is embedded, and the cosine similarity of the two selected
    vectors is compared against ``threshold``.

    Args:
        user_answer: Request payload holding the learner's reply.
        question: Request payload holding the prompt being answered.
        threshold: Similarity cutoff. The answer counts as on topic only when the
            similarity is strictly greater than this. Defaults to the original 0.75.
        embed: Callable mapping a payload to the endpoint response. Defaults to
            ``lambda_handler``. Pass another callable to use a different model or
            precomputed embeddings.

    Returns:
        ``True`` if the cosine similarity exceeds ``threshold``, otherwise ``False``.
    """
    if embed is None:
        embed = lambda_handler
    user_answer_embedding = embed(user_answer)
    question_embedding = embed(question)
    # NOTE(review): [0][0] picks the first vector of the first item. If the endpoint
    # returns per-token embeddings (item -> token -> vector), this is the first token,
    # not a pooled sentence embedding (SBERT is usually mean-pooled) — confirm.
    similarity_score = 1 - distance.cosine(user_answer_embedding[0][0], question_embedding[0][0])
    # bool(): distance.cosine may yield a NumPy scalar; keep a plain True/False return.
    return bool(similarity_score > threshold)


# Examples of calling is_on_topic

In [98]:
# The prompt asks about coffee; the reply is unrelated (expected: off topic).
question = {
    "inputs": "Do you want a cup of coffee?"
}

user_answer = {
    "inputs": "the weather is bad"
}

is_on_topic(user_answer, question)

False

In [99]:
# The barista's greeting is the prompt; the learner answers only with a greeting back.
question = {
    "inputs": " ¡Hola! Bienvenido a la cafetería Brew Haven. ¿Qué quieres? "
}

user_answer = {
    "inputs": "Hola."
}

is_on_topic(user_answer, question)

False

In [100]:
# The barista asks what the learner wants; the learner orders a coffee.
question = {
    "inputs": " ¡Hola! Bienvenido a la cafetería Brew Haven. ¿Qué quieres? "
}

user_answer = {
    "inputs": "Un café, por favor."
}

is_on_topic(user_answer, question)

True

In [58]:
# json_event =  {
#    "inputs": "Welcome to Cafe Strada. What would you like to order today?"
# }

# embedding_user_answer = lambda_handler(json_event)

In [59]:
# json_event =  {
#    "inputs": "planet earth is 8 billion years old"
# }

# embedding_question = lambda_handler(json_event)

In [60]:
# embedding_user_answer = torch.Tensor(embedding_user_answer)
# embedding_question = torch.Tensor(embedding_question)

In [61]:
# cos = torch.nn.CosineSimilarity(dim=0, eps=1e-08)
# similarity_score = cos(torch.Tensor(embedding_question[0][0]), torch.Tensor(embedding_user_answer[0][0]))
# similarity_score

tensor(0.5651)

In [None]:
# scores

In [30]:
# def cosine_distance(x1, x2=None, eps=1e-8):
#     x2 = x1 if x2 is None else x2
#     w1 = x1.norm(p=2, dim=1, keepdim=True)
#     w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
#     return 1 - torch.mm(x1, x2.t()) / (w1 * w2.t()).clamp(min=eps)

In [37]:
# cosine_distance(torch.Tensor(embedding_user_answer[0]),torch.Tensor(embedding_question[0]))

tensor([[0.3432, 0.9434, 0.9913, 0.9330, 1.0292, 1.0163, 0.9976, 1.0278, 1.0101,
         0.9960],
        [1.1297, 0.7703, 0.6167, 0.9172, 0.8116, 0.8216, 1.0984, 1.0448, 1.0955,
         0.7674],
        [1.0663, 0.4952, 0.9067, 0.8126, 0.9818, 0.9179, 1.0873, 1.0781, 1.1205,
         0.8165],
        [1.0197, 0.8418, 0.7360, 0.9279, 0.8099, 0.6818, 1.0308, 1.0142, 0.9350,
         0.8931],
        [1.0917, 0.8729, 0.9362, 0.8956, 0.9882, 0.9646, 1.0218, 1.0129, 1.0678,
         0.8121],
        [1.0469, 1.0074, 0.9194, 1.0269, 0.9928, 0.9111, 0.8760, 1.0229, 0.9730,
         0.9099],
        [1.1047, 0.9373, 0.9124, 0.9553, 0.9765, 0.9312, 1.0226, 1.0064, 1.0601,
         0.7865],
        [1.0797, 1.0149, 1.0014, 1.0044, 1.0659, 0.9755, 0.9368, 0.9805, 0.9897,
         0.9404],
        [1.1047, 0.8199, 0.9624, 0.8284, 1.0355, 1.0155, 1.1058, 0.9998, 1.0881,
         0.7319],
        [1.1032, 0.8173, 0.9289, 0.8170, 1.0018, 0.9720, 1.0672, 0.9929, 1.0731,
         0.6946]])