Skip to content

Commit

Permalink
Fix Sagemaker Batch Endpoints (langchain-ai#3249)
Browse files Browse the repository at this point in the history
Add different typing for @evandiewald's helpful PR

---------

Co-authored-by: Evan Diewald <evandiewald@gmail.com>
  • Loading branch information
2 people authored and yanghua committed May 9, 2023
1 parent 38e7e89 commit b512fe2
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@
"\n",
"Let's load the SageMaker Endpoints Embeddings class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker.\n",
"\n",
"For instrucstions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)"
"For instructions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker). **Note**: In order to handle batched requests, you will need to adjust the return line in the `predict_fn()` function within the custom `inference.py` script:\n",
"\n",
"Change from\n",
"\n",
"`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
"\n",
"to:\n",
"\n",
"`return {\"vectors\": sentence_embeddings.tolist()}`."
]
},
{
Expand All @@ -29,7 +37,7 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict\n",
"from typing import Dict, List\n",
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
"import json\n",
Expand All @@ -39,13 +47,13 @@
" content_type = \"application/json\"\n",
" accepts = \"application/json\"\n",
"\n",
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
" input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
" def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:\n",
" input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
" return input_str.encode('utf-8')\n",
" \n",
" def transform_output(self, output: bytes) -> str:\n",
"\n",
" def transform_output(self, output: bytes) -> List[List[float]]:\n",
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
" return response_json[\"embeddings\"]\n",
" return response_json[\"vectors\"]\n",
"\n",
"content_handler = ContentHandler()\n",
"\n",
Expand Down
30 changes: 17 additions & 13 deletions langchain/embeddings/sagemaker_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from langchain.llms.sagemaker_endpoint import ContentHandlerBase


class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]):
"""Content handler for embeddings class."""


class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
"""Wrapper around custom Sagemaker Inference Endpoints.
Expand Down Expand Up @@ -62,7 +66,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""

content_handler: ContentHandlerBase
content_handler: EmbeddingsContentHandler
"""The content handler class that provides input and
output transform functions to handle formats between the
embeddings class and the endpoint.
Expand All @@ -71,21 +75,21 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
"""
Example:
.. code-block:: python
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
class ContentHandler(ContentHandlerBase):
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
input_str = json.dumps({prompt: prompt, **model_kwargs})
def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:
input_str = json.dumps({"inputs": prompts, **model_kwargs})
return input_str.encode('utf-8')
def transform_output(self, output: bytes) -> str:
def transform_output(self, output: bytes) -> List[List[float]]:
response_json = json.loads(output.read().decode("utf-8"))
return response_json[0]["generated_text"]
"""
return response_json["vectors"]
""" # noqa: E501

model_kwargs: Optional[Dict] = None
"""Key word arguments to pass to the model."""
Expand Down Expand Up @@ -135,7 +139,7 @@ def validate_environment(cls, values: Dict) -> Dict:
)
return values

def _embedding_func(self, texts: List[str]) -> List[float]:
def _embedding_func(self, texts: List[str]) -> List[List[float]]:
"""Call out to SageMaker Inference embedding endpoint."""
# replace newlines, which can negatively affect performance.
texts = list(map(lambda x: x.replace("\n", " "), texts))
Expand Down Expand Up @@ -179,7 +183,7 @@ def embed_documents(
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
for i in range(0, len(texts), _chunk_size):
response = self._embedding_func(texts[i : i + _chunk_size])
results.append(response)
results.extend(response)
return results

def embed_query(self, text: str) -> List[float]:
Expand All @@ -191,4 +195,4 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embeddings for the text.
"""
return self._embedding_func([text])
return self._embedding_func([text])[0]
25 changes: 16 additions & 9 deletions langchain/llms/sagemaker_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""Wrapper around Sagemaker InvokeEndpoint API."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Union
from abc import abstractmethod
from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union

from pydantic import Extra, root_validator

from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens

INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]])

class ContentHandlerBase(ABC):

class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
"""A handler class to transform input from LLM to a
format that SageMaker endpoint expects. Similarly,
the class also handles transforming output from the
Expand Down Expand Up @@ -39,22 +42,24 @@ def transform_output(self, output: bytes) -> str:
"""The MIME type of the response data returned from endpoint"""

@abstractmethod
def transform_input(
self, prompt: Union[str, List[str]], model_kwargs: Dict
) -> bytes:
def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes:
"""Transforms the input to a format that model can accept
as the request Body. Should return bytes or seekable file
like object in the format specified in the content_type
request header.
"""

@abstractmethod
def transform_output(self, output: bytes) -> Any:
def transform_output(self, output: bytes) -> OUTPUT_TYPE:
"""Transforms the output from the model to the format that
the calling LLM or embeddings class expects.
"""


class LLMContentHandler(ContentHandlerBase[str, str]):
"""Content handler for LLM class."""


class SagemakerEndpoint(LLM):
"""Wrapper around custom Sagemaker Inference Endpoints.
Expand Down Expand Up @@ -110,7 +115,7 @@ class SagemakerEndpoint(LLM):
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""

content_handler: ContentHandlerBase
content_handler: LLMContentHandler
"""The content handler class that provides input and
output transform functions to handle formats between the LLM
and the endpoint.
Expand All @@ -120,7 +125,9 @@ class SagemakerEndpoint(LLM):
Example:
.. code-block:: python
class ContentHandler(ContentHandlerBase):
from langchain.llms.sagemaker_endpoint import LLMContentHandler
class ContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"
Expand Down

0 comments on commit b512fe2

Please sign in to comment.