In [7]:
import locale

import io

#Def for guessing the encoding of the SQL file

def guess_encoding(file):

    """guess the encoding of the given file"""

    with io.open(file, "rb") as f:

        data = f.read(5)

    if data.startswith(b"\xEF\xBB\xBF"):  # UTF-8 with a "BOM"

        return "utf-8-sig"

    elif data.startswith(b"\xFF\xFE") or data.startswith(b"\xFE\xFF"):

        return "utf-16"

    else:  # guessing utf-8 doesn't work in Windows, so we just give it a try:

        try:

            with io.open(file, encoding="utf-8") as f:

                return "utf-8"

        except:

            return locale.getdefaultlocale()[1]

In [1]:
import textwrap
def create_query_string(sql_file):

    with open(sql_file, 'r', encoding=guess_encoding(sql_file)) as f_in:
        lines = f_in.read()
        query_string = textwrap.dedent("""{}""".format(lines))
        return query_string

In [9]:
meta_data = create_query_string("../data/sample.sql")

In [10]:
query = 'Write a SQL query that fetches all the patients who were prescribed more than 5 different medications on 2023-04-01'

In [21]:
prompt_sql_data = f"""

Human: You're provided with a database schema representing any hospital's patient management system.
The system holds records about patients, their prescriptions, doctors, and the medications prescribed.  
Provide just the SQL Code.  Do not add additional information.

Here's the schema: {meta_data}

{query}


Assistant:
"""

In [12]:
from langchain.llms.bedrock import Bedrock

inference_modifier = {'max_tokens_to_sample':4096, 
                      "temperature":0.5,
                      "top_k":250,
                      "top_p":1,
                      "stop_sequences": ["\n\nHuman"]
                     }

textgen_llm = Bedrock(model_id = "anthropic.claude-v2",
                    model_kwargs = inference_modifier 
                    )

In [22]:
import boto3
import json

bedrock = boto3.client('bedrock-runtime')
model_id = "anthropic.claude-v2"

accept = 'application/json'
contentType = 'application/json'

body = json.dumps({
                    "prompt": prompt_sql_data,
                    "max_tokens_to_sample":4096,
                    "temperature":0.5,
                    "top_k":250,
                    "top_p":0.5,
                    "stop_sequences": ["\n\nHuman:"]
                  })

In [23]:
response = bedrock.invoke_model(body=body, modelId=model_id, accept=accept, contentType=contentType)
response_body = json.loads(response.get('body').read())

print(response_body.get('completion'))

 SELECT p.FirstName, p.LastName
FROM Patients p 
INNER JOIN Prescriptions pre ON p.PatientID = pre.PatientID
INNER JOIN PrescriptionDetails pd ON pre.PrescriptionID = pd.PrescriptionID
WHERE pre.DateIssued = '2023-04-01'
GROUP BY p.PatientID
HAVING COUNT(DISTINCT pd.MedicationID) > 5;


In [20]:
print(response_body.get('completion'))

 Here is a SQL query to fetch patients who were prescribed more than 5 different medications on 2023-04-01:

```sql
SELECT p.FirstName, p.LastName
FROM Patients p
JOIN Prescriptions pre ON p.PatientID = pre.PatientID
JOIN PrescriptionDetails pd ON pre.PrescriptionID = pd.PrescriptionID
WHERE pre.DateIssued = '2023-04-01'
GROUP BY p.PatientID
HAVING COUNT(DISTINCT pd.MedicationID) > 5
```

The key steps are:

1. Join the Patients, Prescriptions and PrescriptionDetails tables to connect patients with their prescriptions and medication details. 

2. Filter to only prescriptions issued on 2023-04-01.

3. Group by PatientID and count the distinct MedicationIDs per patient. 

4. Use HAVING clause to only keep patients with more than 5 distinct medications.

5. Select the patient FirstName and LastName in the final output.

This will give you all patients who were prescribed more than 5 different medications on the specified date.


In [1]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase

In [1]:
!pip install langchain_experimental

[0mCollecting langchain_experimental
  Obtaining dependency information for langchain_experimental from https://files.pythonhosted.org/packages/4d/6a/20f8ee4849a3d1cbc707a2f0cc504711932c7b4689498a7f9bf37cb0de95/langchain_experimental-0.0.27-py3-none-any.whl.metadata
  Downloading langchain_experimental-0.0.27-py3-none-any.whl.metadata (1.8 kB)
Collecting langchain>=0.0.308 (from langchain_experimental)
  Obtaining dependency information for langchain>=0.0.308 from https://files.pythonhosted.org/packages/1f/46/d82192ebc8d1f0e42b03b5c8a078737fba9fe8ec416722f5865ed9424d49/langchain-0.0.311-py3-none-any.whl.metadata
  Downloading langchain-0.0.311-py3-none-any.whl.metadata (15 kB)
Collecting langsmith<0.1.0,>=0.0.43 (from langchain>=0.0.308->langchain_experimental)
  Obtaining dependency information for langsmith<0.1.0,>=0.0.43 from https://files.pythonhosted.org/packages/9b/6c/fd466f647634ef4a668ea109bf0892d7f78882ffe09500429081ed6dae4a/langsmith-0.0.43-py3-none-any.whl.metadata
  Downlo

In [3]:
from langchain.utilities import SQLDatabase
from langchain.llms import OpenAI
from langchain_experimental.sql import SQLDatabaseChain

db = SQLDatabase.from_uri("sqlite:///../data/Chinook.db")


In [4]:
from langchain.llms.bedrock import Bedrock

inference_modifier = {'max_tokens_to_sample':4096, 
                      "temperature":0.5,
                      "top_k":250,
                      "top_p":1,
                      "stop_sequences": ["\n\nHuman"]
                     }

textgen_llm = Bedrock(model_id = "anthropic.claude-v2",
                    model_kwargs = inference_modifier 
                    )

db_chain = SQLDatabaseChain.from_llm(textgen_llm, db, verbose=True)

In [6]:
db_chain.run("How many employees are there?")



[1m> Entering new SQLDatabaseChain chain...[0m
How many employees are there?
SQLQuery:[32;1m[1;3mHere is the SQL query and result to find the number of employees:

SQLQuery:
SELECT COUNT(*) AS "NumEmployees" 
FROM "Employee"[0m

OperationalError: (sqlite3.OperationalError) near "Here": syntax error
[SQL: Here is the SQL query and result to find the number of employees:

SQLQuery:
SELECT COUNT(*) AS "NumEmployees" 
FROM "Employee"]
(Background on this error at: https://sqlalche.me/e/20/e3q8)