# Text Generation: How to run inference on the endpoint you have created?

In [297]:
import json
import boto3

In [12]:
!apt-get update -qq && apt-get install -y build-essential -qq

debconf: delaying package configuration, since apt-utils is not installed
Selecting previously unselected package liblocale-gettext-perl.
(Reading database ... 21267 files and directories currently installed.)
Preparing to unpack .../liblocale-gettext-perl_1.07-3+b4_amd64.deb ...
Unpacking liblocale-gettext-perl (1.07-3+b4) ...
Preparing to unpack .../gpgv_2.2.12-1+deb10u2_amd64.deb ...
Unpacking gpgv (2.2.12-1+deb10u2) over (2.2.12-1+deb10u1) ...
Setting up gpgv (2.2.12-1+deb10u2) ...
Selecting previously unselected package libstdc++-8-dev:amd64.
(Reading database ... 21282 files and directories currently installed.)
Preparing to unpack .../00-libstdc++-8-dev_8.3.0-6_amd64.deb ...
Unpacking libstdc++-8-dev:amd64 (8.3.0-6) ...
Selecting previously unselected package g++-8.
Preparing to unpack .../01-g++-8_8.3.0-6_amd64.deb ...
Unpacking g++-8 (8.3.0-6) ...
Selecting previously unselected package g++.
Preparing to unpack .../02-g++_4%3a8.3.0-1_amd64.deb ...
Unpacking g++ (4:8.3.0-1) ...

In [13]:
# Optional: update pip
%pip install pip -Uq

# Install latest versions of other required packages
%pip install ipywidgets langchain awscli boto3 python-dotenv SQLAlchemy psycopg2-binary chromadb -Uq
%pip install pyyaml -q

# Avoid issues with install
# https://github.com/aws/amazon-sagemaker-examples/issues/1890#issuecomment-758871546
%pip install sentence-transformers -Uq --no-cache-dir #--force-reinstall

[33mDEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.
[33mDEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 24.1 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
spyder 5.1.5 requires py

In [14]:
%pip list | grep 'langchain\|sentence-transformers\|SQLAlchemy'

langchain                                0.1.12
langchain-community                      0.0.28
langchain-core                           0.1.32
langchain-text-splitters                 0.0.1
sentence-transformers                    2.5.1
SQLAlchemy                               2.0.28
Note: you may need to restart the kernel to use updated packages.


In [15]:
REGION_NAME = "us-east-1"

In [435]:
import pandas as pd
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///moma.db")
print(db.dialect)
print(db.get_usable_table_names())

sqlite
['artists', 'artworks']


### Query endpoint that you have created
***
The following cell provides a helper function that will be used to query your endpoint with boto3.
***

In [3]:
endpoint_name = 'jumpstart-dft-meta-textgeneration-llama-codellama-7b'

def query_endpoint(payload):
    client = boto3.client('runtime.sagemaker')
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=json.dumps(payload).encode('utf-8'),
    )
    response = response["Body"].read().decode("utf8")
    response = json.loads(response)
    return response


In [429]:
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs={}) -> bytes:
        #input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
        
        #input_str = json.dumps({"inputs": prompt, **model_kwargs})
        input_str = json.dumps({"inputs": prompt,  "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
       # return response_json["generated_texts"][0]
        #print(response_json)
        #print(response_json[0]['generated_text']['SQLQuery'])
        #print(response_json[0]['generated_text'])
        #print("First print statement")
        print(response_json[0]['generated_text'])
        #print("Second print statement")
        return response_json[0]["generated_text"]


content_handler = ContentHandler()

parameters = {
    "max_length": 50,
    "temperature": 0.01,"early_stopping":True,
    "top_p":0.8,
    "top_k":1,
    "stop":["\nQuestion:","\nSQLResult:","\n\nQuestion:","\nSQLQuery:"]
}

# smaller model endpoint
llm_sml = SagemakerEndpoint(
    endpoint_name=endpoint_name,
    region_name=REGION_NAME,
    model_kwargs=parameters,
    
    content_handler=content_handler,
)

# larger model endpoint
llm_lrg = SagemakerEndpoint(
    endpoint_name=endpoint_name,
    region_name=REGION_NAME,
    model_kwargs=parameters,
    content_handler=content_handler,
)

In [430]:
#!pip install langchain-experimental

In [431]:
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import  SQLDatabaseChain
from langchain_experimental.sql import SQLDatabaseSequentialChain

In [432]:
custom_table_info = {
    "artists": """CREATE TABLE artists (
        artist_id integer NOT NULL,
        name character varying(200),
        nationality character varying(50),
        gender character varying(25),
        birth_year integer,
        death_year integer,
        CONSTRAINT artists_pk PRIMARY KEY (artist_id))

/*
3 rows from artists table:
"artist_id"	"name"	"nationality"	"gender"	"birth_year"	"death_year"
12	"Jüri Arrak"	"Estonian"	"Male"	1936	
19	"Richard Artschwager"	"American"	"Male"	1923	2013
22	"Isidora Aschheim"	"Israeli"	"Female"		
*/""",
    "artworks": """CREATE TABLE artworks (
        artwork_id integer NOT NULL,
        title character varying(500),
        artist_id integer NOT NULL,
        name character varying(500),
        date integer,
        medium character varying(250),
        dimensions text,
        acquisition_date text,
        credit text,
        catalogue character varying(250),
        department character varying(250),
        classification character varying(250),
        object_number text,
        diameter_cm text,
        circumference_cm text,
        height_cm text,
        length_cm text,
        width_cm text,
        depth_cm text,
        weight_kg text,
        durations integer,
        CONSTRAINT artworks_pk PRIMARY KEY (artwork_id))

/*
3 rows from artworks table:
"artwork_id"	"title"	"artist_id"	"name"	"date"	"medium"	"dimensions"	"acquisition_date"	"credit"	"catalogue"	"department"	"classification"	"object_number"	"diameter_cm"	"circumference_cm"	"height_cm"	"length_cm"	"width_cm"	"depth_cm"	"weight_kg"	"durations"
102312	"Watching the Game"	2422	"John Gutmann"	1934	"Gelatin silver print"	"9 3/4 x 6 7/16' (24.8 x 16.4 cm)"	"2006-05-11"	"Purchase"	"N"	"Photography"	"Photograph"	"397.2006"			"24.8"		"16.4"			
103321	"Untitled (page from Sump)"	25520	"Jerome Neuner"	1994	"Page with chromogenic color print and text"	"12 x 9 1/2' (30.5 x 24.1 cm)"	"2006-05-11"	"E.T. Harmax Foundation Fund"	"N"	"Photography"	"Photograph"	"415.2006.12"			"30.4801"		"24.13"			
10	"The Manhattan Transcripts Project, New York, New York, Episode 1: The Park"	7056	"Bernard Tschumi"		"Gelatin silver photograph"	"14 x 18' (35.6 x 45.7 cm)"	"1995-01-17"	"Purchase and partial gift of the architect in honor of Lily Auchincloss"	"Y"	"Architecture & Design"	"Architecture"	"3.1995.11"			"35.6"		"45.7"			
*/""",
}

In [433]:
from langchain.prompts.prompt import PromptTemplate

_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use date('now') function to get the current date, if the question involves \"today\".
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"

Only use the following tables:
{table_info}

If someone asks for the art table, they really mean the artworks table.

Only single quotation marks, not double quotation marks in the SQL statement (SQLQuery). Never use " in SQL statement (SQLQuery).

Question: {input}"""
PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
)

In [412]:
db.run("select count(*) from artists")

'[(15542,)]'

In [231]:
# A few sample questions
QUESTION_01 = "How many artists are there?"
QUESTION_02 = "How many artworks are there?"
QUESTION_03 = "How many rows are in the artists table?"
QUESTION_04 = "How many rows are in the artworks table?"
QUESTION_05 = "How many artists are there whose nationality is French?"
QUESTION_06 = "How many artworks were created by artists whose nationality is Spanish?"
QUESTION_07 = "How many artist names start with 'M'?"
QUESTION_08 = "What nationality produced the most number of artworks?"
QUESTION_09 = "How many artworks are by Claude Monet?"
QUESTION_10 = "What is the oldest artwork in the collection?"

In [305]:
llm_lrg

SagemakerEndpoint(client=<botocore.client.SageMakerRuntime object at 0x7f89ee7ee910>, endpoint_name='jumpstart-dft-meta-textgeneration-llama-codellama-7b', region_name='us-east-1', content_handler=<__main__.ContentHandler object at 0x7f89eee157f0>, model_kwargs={'max_length': 2000, 'temperature': 0.01, 'early_stopping': True, 'top_p': 0.8, 'stop': ['Question']})

In [428]:
# Reference: https://python.langchain.com/en/latest/modules/chains/examples/sqlite.html#sqldatabasesequentialchain
from sqlalchemy.exc import ProgrammingError


db = SQLDatabase.from_uri('sqlite:///moma.db',sample_rows_in_table_info=3,custom_table_info = custom_table_info,include_tables=["artists","artworks"],)

#db_chain = SQLDatabaseSequentialChain.from_llm(
#    llm_lrg, db, verbose=True, use_query_checker=True
#)
db_chain = SQLDatabaseChain.from_llm(
    llm_lrg,
    db,
    prompt=PROMPT,
    verbose=True,
    use_query_checker=True,
    return_intermediate_steps=True,
)

try:
    db_chain.invoke(QUESTION_01)
except (ProgrammingError, ValueError) as exc:
    print(f"\n\n{exc}")



[1m> Entering new SQLDatabaseChain chain...[0m
How many artists are there?
SQLQuery: SELECT COUNT(*) FROM artists
SQLResult:


SELECT COUNT(*) FROM artists





























































































[32;1m[1;3mSELECT COUNT(*) FROM artists[0m
SQLResult: [33;1m[1;3m[(15542,)][0m
Answer: 15542

Question:
[32;1m[1;3m15542

Question:[0m
[1m> Finished chain.[0m


In [436]:
# Reference: https://python.langchain.com/en/latest/modules/chains/examples/sqlite.html#sqldatabasesequentialchain
from sqlalchemy.exc import ProgrammingError
from langchain.memory import ConversationBufferMemory
memory = ConversationBufferMemory(k = 0, input_key = "input")

db = SQLDatabase.from_uri('sqlite:///moma.db',sample_rows_in_table_info=3,custom_table_info = custom_table_info)

#db_chain = SQLDatabaseSequentialChain.from_llm(
#    llm_lrg, db, verbose=True, use_query_checker=True
#)
db_chain = SQLDatabaseChain.from_llm(
    llm_lrg,
    db,
    prompt=PROMPT,
    verbose=True,
    use_query_checker=True,
    return_intermediate_steps=True,memory = memory,
)

try:
    db_chain(QUESTION_02)
except (ProgrammingError, ValueError) as exc:
    print(f"\n\n{exc}")



[1m> Entering new SQLDatabaseChain chain...[0m
How many artworks are there?
SQLQuery: SELECT COUNT(*) FROM artworks
SQLResult:


SELECT COUNT(*) FROM artworks




























































































[32;1m[1;3mSELECT COUNT(*) FROM artworks[0m
SQLResult: [33;1m[1;3m[(151873,)][0m
Answer: 151873

Question:
[32;1m[1;3m151873

Question:[0m

One output key expected, got dict_keys(['result', 'intermediate_steps'])


In [427]:
# Reference: https://python.langchain.com/en/latest/modules/chains/examples/sqlite.html#sqldatabasesequentialchain
from sqlalchemy.exc import ProgrammingError
from langchain.memory import ConversationBufferMemory
memory = ConversationBufferMemory(k = 0, input_key = "input")

db = SQLDatabase.from_uri('sqlite:///moma.db',sample_rows_in_table_info=3,custom_table_info = custom_table_info)

#db_chain = SQLDatabaseSequentialChain.from_llm(
#    llm_lrg, db, verbose=True, use_query_checker=True
#)
db_chain = SQLDatabaseChain.from_llm(
    llm_lrg,
    db,
    prompt=PROMPT,
    verbose=True,
    use_query_checker=True,
    return_intermediate_steps=True,memory = memory,
)

try:
    db_chain(QUESTION_05)
except (ProgrammingError, ValueError) as exc:
    print(f"\n\n{exc}")



[1m> Entering new SQLDatabaseChain chain...[0m
How many artists are there whose nationality is French?
SQLQuery: SELECT COUNT(*) FROM artists WHERE nationality = 'French';
SQLResult:


SELECT COUNT(*) FROM artists WHERE nationality = 'French';





















































































[32;1m[1;3mSELECT COUNT(*) FROM artists WHERE nationality = 'French';[0m
SQLResult: [33;1m[1;3m[(860,)][0m
Answer: 860

Question:
[32;1m[1;3m860

Question:[0m

One output key expected, got dict_keys(['result', 'intermediate_steps'])


In [437]:
db.run('SELECT COUNT(*) FROM artists WHERE nationality = \'French\'')

'[(860,)]'

In [441]:
# Reference: https://python.langchain.com/en/latest/modules/chains/examples/sqlite.html#sqldatabasesequentialchain
from sqlalchemy.exc import ProgrammingError
from langchain.memory import ConversationBufferMemory
memory = ConversationBufferMemory(k = 0, input_key = "input")

db = SQLDatabase.from_uri('sqlite:///moma.db',sample_rows_in_table_info=3,custom_table_info = custom_table_info)

#db_chain = SQLDatabaseSequentialChain.from_llm(
#    llm_lrg, db, verbose=True, use_query_checker=True
#)
db_chain = SQLDatabaseChain.from_llm(
    llm_lrg,
    db,
    prompt=PROMPT,
    verbose=True,
    use_query_checker=True,
    return_intermediate_steps=True,memory = memory,
)

try:
    db_chain(QUESTION_10)
except (ProgrammingError, ValueError) as exc:
    print(f"\n\n{exc}")



[1m> Entering new SQLDatabaseChain chain...[0m
What is the oldest artwork in the collection?
SQLQuery: SELECT title, artist_id, name, date, medium, dimensions, acquisition_date, credit, catalogue, department, classification, object_number, diameter_cm, circumference_cm, height_cm, length_cm, width_cm, depth_cm, weight_kg, durations FROM artworks WHERE date = (SELECT MIN(date) FROM artworks)
SQLResult:


SELECT title, artist_id, name, date, medium, dimensions, acquisition_date, credit, catalogue, department, classification, object_number, diameter_cm, circumference_cm, height_cm, length_cm, width_cm, depth_cm, weight_kg, durations FROM artworks WHERE date = (SELECT MIN(date) FROM artworks)

















[32;1m[1;3mSELECT title, artist_id, name, date, medium, dimensions, acquisition_date, credit, catalogue, department, classification, object_number, diameter_cm, circumference_cm, height_cm, length_cm, width_cm, depth_cm, weight_kg, durations FROM artworks WHERE date = (SELECT MI

OperationalError: (sqlite3.OperationalError) no such column: artist_id
[SQL: SELECT title, artist_id, name, date, medium, dimensions, acquisition_date, credit, catalogue, department, classification, object_number, diameter_cm, circumference_cm, height_cm, length_cm, width_cm, depth_cm, weight_kg, durations FROM artworks WHERE date = (SELECT MIN(date) FROM artworks)]
(Background on this error at: https://sqlalche.me/e/20/e3q8)