In [1]:
import os
from string import Template

import numpy as np
import openai

# Location of the file that holds the OpenAI API key.  Set OPENAI_API_KEY_PATH in the
# environment to point somewhere else; the original hard-coded location is kept as the
# fallback so existing setups keep working.
openai.api_key_path = os.environ.get(
    'OPENAI_API_KEY_PATH', '/home/tim/projects/openai/apikey.txt'
)

## Exploring two different kinds of natural language -> query language translation

* NL -> SQL
* NL -> filter description
  * Filter descriptions are tuples of (column, operator, value)
  
### A lot of this is inspired by the cool stuff in the OpenAI cookbook:
https://github.com/openai/openai-cookbook

In [2]:
# Which completion model to use?  Will the code model be better 
# at this than the bigger-but-more-general text model?

# Short alias -> OpenAI completion model name; translate_question looks models up here.
MODELS = dict(
    # NOTE(review): the "code" alias maps to an empty model name, so selecting it sends
    # model='' to the API -- fill in a real model before using completion_model='code'.
    code='',
    text1="text-davinci-003",
)

## These strings are prompt templates that describe a toy database table to be queried

In [3]:
# Look at me prompt engineering.

# Prompt for NL -> SQL translation over the toy cell_metadata table.  `$q` is replaced
# with the user's question via Template.substitute.
# Fixes relative to the earlier wording: the schema now has the comma that was missing
# after the `sex` column, and the stray extra quote in 'F' is gone.
# NOTE(review): an age of exactly 14 is neither "above 14" nor "below 14", so it is not
# assigned to any age group by the text below -- confirm the intended boundary.
sql_prompt = Template("""
    Answer the question as truthfully as possible, and if you're unsure of the answer, say "Sorry, I don't know".

Assume we have a sql table with this schema: 

table cell_metadata (
  cell_id: int not null,
  age: float not null,
  sex: varchar(8) not null,
  roi: varchar(8) not null,
  data: jsonb
);

The possible values for the column 'sex' are 'M' or 'F' or 'UNKNOWN'.
'M' is an abbreviation for Male, 'F' for female.

roi stands for Region Of Interest, which is sometimes understood to be a 'structure' or 'area'. Common 
structures are well known in brain anatomy, for example striatum or cortex or hypothalamus.

An age greater than or equal to 56 is considered adult.  Ages below 56 and above 14 are considered juvenile.  
Ages below 14 are pups.

Translate the question below into a sql query on the cell_metadata table

Q: $q
    """)

# Prompt for NL -> filter-description translation: the model is asked to answer with a
# JSON array of {"column", "op", "value"} objects.  `$q` is replaced with the question.
# Fixes relative to the earlier wording: the schema has the comma that was missing after
# the `sex` column, the stray extra quote in 'F' is gone, and the worked example now
# quotes "M" with double quotes so that the example itself is valid JSON (the model
# copies the quoting style of the example into its answers).
# NOTE(review): an age of exactly 14 is neither "above 14" nor "below 14", so it is not
# assigned to any age group by the text below -- confirm the intended boundary.
cell_filter_prompt = Template("""
Answer the question as truthfully as possible, and if you're unsure of the answer, say "Sorry, I don't know".

Assume we have a sql table with this schema: 

table cell_metadata (
  cell_id: int not null,
  age: float not null,
  sex: varchar(8) not null,
  roi: varchar(8) not null,
  data: jsonb
);

The possible values for the column 'sex' are 'M' or 'F' or 'UNKNOWN'. 
'M' is an abbreviation for Male, 'F' for female.

roi stands for Region Of Interest, which is sometimes understood to be a 'structure' or 'area'.  Common 
structures are well known in brain anatomy, for example striatum or cortex or hypothalamus.

An age greater than or equal to 56 is considered adult.  Ages below 56 and above 14 are considered juvenile.  
Ages below 14 are pups.

We can filter this table by specifying the column name, a comparison operator, and the value to compare it to, 
returning all results that match the filter.  Each WHERE clause in this filter can be specified as a JSON
object, like so:

{"column": column_name, "op": comparison_operator, "value": value_to_filter_by}

for example, the sql query "select * from cell_metadata where age >= 56" could be represented as 
{"column": "age", "op": ">=", "value": 56}

We can combine these JSON-encoded WHERE clauses in an array, for example, the sql query
"select * from cell_metadata where age >= 56 and sex = 'M'" would be translated as:
[
    {"column": "age", "op": ">=", "value": 56},
    {"column": "sex", "op": "=", "value": "M"}
]

Translate the question below on the cell_metadata table. 

Q: $q
""")

# Table-agnostic prompt: same "answer truthfully or decline" preamble, then the question.
general_question = Template(
    "\n"
    "Answer the question as truthfully as possible, and if you're unsure of the answer, "
    "say \"Sorry, I don't know\".\n"
    "\n"
    "Q: $q\n"
)

# Bare prompt: the question is sent on its own line with no instructions around it.
free = Template("\n$q\n")

# translation_type -> prompt template; translate_question selects its prompt from here.
PROMPTS = dict(
    cells_sql=sql_prompt,
    cells_filter=cell_filter_prompt,
    general=general_question,
    free=free,
)

In [4]:
# Given a prompt and a question, return a translation from NL.
# The key parameter that we're varying in this notebook is translation_type, which
# selects for one of the prompts designed above.

def translate_question(q, translation_type, completion_model='text1', temp=0, max_tokens=300):
    """Send question `q` to a completion model using one of the PROMPTS templates.

    Parameters
    ----------
    q : str
        The natural-language question; substituted for `$q` in the template.
    translation_type : str
        Key into PROMPTS selecting the prompt template.
    completion_model : str
        Key into MODELS selecting the completion model.
    temp : float
        Passed to the API as `temperature`.
    max_tokens : int
        Passed to the API as `max_tokens`.

    Returns
    -------
    str
        The text of the first completion choice with surrounding whitespace removed
        and a single leading "A:" label stripped, if present.

    Raises
    ------
    KeyError
        If `translation_type` or `completion_model` is not a known key.
    ValueError
        If the selected MODELS entry is an empty model name.
    """
    prompt = PROMPTS[translation_type]
    model = MODELS[completion_model]
    if not model:
        # Fail here with a clear message instead of sending model='' to the API.
        raise ValueError(
            "No model name configured for completion_model=%r" % (completion_model,)
        )

    response = openai.Completion.create(
        prompt=prompt.substitute({'q': q}),
        temperature=temp,
        max_tokens=max_tokens,
        model=model,
    )
    text = response["choices"][0]["text"].strip()

    # The prompts end with "Q: ...", so the model sometimes opens its reply with "A:".
    # Drop only that leading label: a blanket str.replace would also delete any "A:"
    # occurring inside the answer itself (e.g. within a quoted SQL literal).
    if text.startswith('A:'):
        text = text[len('A:'):].strip()

    return text

## That's surprisingly little work to make this happen

In [5]:
# Adult cells restricted to a single ROI, translated to SQL.
r = translate_question(
    q="Show me the data for just the adult cells, only for the roi foo",
    translation_type='cells_sql',
)
print(r)

SELECT data FROM cell_metadata WHERE age >= 56 AND roi = 'foo' AND sex != 'UNKNOWN';


In [6]:
# The same question, translated to the JSON filter description instead of SQL.
r = translate_question(
    q="Show me the data for just the adult cells, only for the roi foo",
    translation_type='cells_filter',
)
print(r)

[
    {"column": "age", "op": ">=", "value": 56},
    {"column": "roi", "op": "=", "value": 'foo'}
]


In [7]:
# Substring match on roi: the question asks for ROIs that "include the term" CTX.
translate_question(
    q='Show me the data for just the adult cells, if the roi includes the term CTX',
    translation_type='cells_filter',
)

'[\n    {"column": "age", "op": ">=", "value": 56},\n    {"column": "roi", "op": "LIKE", "value": \'%CTX%\'}\n]'

In [8]:
# Same request phrased as a question, with the ROI name quoted inside the question.
translate_question(
    q="what's the data for just the adult cells, only for the roi 'foo'",
    translation_type='cells_sql',
)

"SELECT data FROM cell_metadata WHERE age >= 56 AND roi = 'foo';"

In [9]:
# A disjunction: adults, or juveniles above an explicit age threshold.
translate_question(
    q="what's the data for adult cells, or for juveile cells where the age > 32",
    translation_type='cells_sql',
)

"SELECT data \nFROM cell_metadata \nWHERE (age >= 56 AND sex = 'M' OR sex = 'F') \nOR (age > 32 AND age < 56 AND sex = 'M' OR sex = 'F');"

## But people don't usually cooperate in asking well-formed questions:

In [10]:
# A sentence fragment rather than a full question.
translate_question(q='Just in the cortex', translation_type='cells_filter')

'[{"column": "roi", "op": "=", "value": "cortex"}]'

In [11]:
# "young" is not one of the age groups named in the prompt.
translate_question(q='How about young mice?', translation_type='cells_filter')

"Sorry, I don't know."

In [12]:
# A comparison question, which is not a row filter.
translate_question(
    q='Is there any difference between male & female?',
    translation_type='cells_filter',
)

"Sorry, I don't know."

In [13]:
# An open-ended question sent through the SQL prompt.
translate_question(
    q='What do you know about the striatum?',
    translation_type='cells_sql',
)

"SELECT * FROM cell_metadata WHERE roi = 'striatum';"

In [14]:
# The same question could also be meant more broadly, so send it through the general prompt.
translate_question(
    q='What do you know about the striatum?',
    translation_type='general',
)

'The striatum is a part of the brain that is involved in reward, motivation, and motor control. It is composed of the caudate nucleus, the putamen, and the nucleus accumbens.'

## The column/operator/value format has limitations that the SQL version does not, such as answering aggregate questions

In [15]:
# An aggregate question sent through the column/operator/value filter prompt.
translate_question(
    q='What percentage of the cells are from males?',
    translation_type='cells_filter',
)

"Sorry, I don't know."

In [16]:
# The SQL prompt does produce a translation for the same aggregate question.
# Is the result valid SQL?  The intent, at least, appears to have been understood.
translate_question(
    q='What percentage of the cells are from males?',
    translation_type='cells_sql',
)

"SELECT COUNT(*) * 100.0 / (SELECT COUNT(*) FROM cell_metadata)\nFROM cell_metadata\nWHERE sex = 'M';"

In [17]:
# Interesting: it responds with two separate queries for the two separate questions.
translate_question(
    q='How many males are in the data set?  How many females?',
    translation_type='cells_sql',
)

"SELECT COUNT(*) FROM cell_metadata WHERE sex = 'M';\nSELECT COUNT(*) FROM cell_metadata WHERE sex = 'F';"

### There's a whole new world of sql injection attacks opening up

## This might be an instance where fine-tuning a model makes sense.  See:
https://beta.openai.com/docs/guides/fine-tuning

* provide 200+ examples of queries + answers, the more the better.
* these should be easy to come up with/generate.
* the goal of the examples would be to describe the kinds of queries and the domain of the table(s).
* perhaps the kinds of clarifying questions to respond with

## Adding a general response if the question doesn't fit the table

In [18]:
# Fascinating!  The model treats 'cell type' as a field inside the JSONB column and
# builds a matching query.
# Try the SQL prompt first; only if the model declines, ask the general prompt.
q = 'What are the canonical cell types in the brain?'
r = translate_question(q=q, translation_type='cells_sql')
r = translate_question(q=q, translation_type='general') if r == "Sorry, I don't know." else r

print(r)

SELECT DISTINCT data->'cell_type' FROM cell_metadata;


In [19]:

# Try the SQL prompt first; only if the model declines, ask the general prompt.
q = 'Is the hippocampus involved in memory formation?'
r = translate_question(q=q, translation_type='cells_sql')
r = translate_question(q=q, translation_type='general') if r == "Sorry, I don't know." else r

print(r)

Yes, the hippocampus is involved in memory formation.


In [20]:
# Try the SQL prompt first; only if the model declines, ask the general prompt.
q = 'What is the mechanism of memory formation in the hippocampus?'
r = translate_question(q=q, translation_type='cells_sql')
r = translate_question(q=q, translation_type='general') if r == "Sorry, I don't know." else r

print(r)

The mechanism of memory formation in the hippocampus is believed to involve the strengthening of synaptic connections between neurons through a process known as long-term potentiation (LTP).


In [21]:
# Try the SQL prompt first; only if the model declines, ask the general prompt.
q = 'Well then you better show me data for cells in the hippocampus too.'
r = translate_question(q=q, translation_type='cells_sql')
r = translate_question(q=q, translation_type='general') if r == "Sorry, I don't know." else r

print(r)

SELECT * FROM cell_metadata WHERE roi = 'hippocampus';


In [22]:
# It should know this!  Add an extra free-form query for things that can't be
# answered with confidence (lie to me).
# Fallback chain: SQL prompt -> general prompt -> bare question, with a warning
# appended to whatever the bare question returns.
q = 'What is the airspeed of a European swallow?'
r = translate_question(q=q, translation_type='cells_sql')
if r == "Sorry, I don't know.":
    r = translate_question(q=q, translation_type='general')
    if r == "Sorry, I don't know.":
        r = translate_question(q=q, translation_type='free') + " (Beware, this might be made up!)"

print(r)

The airspeed of a European swallow is approximately 11 meters per second (24 miles per hour). (Beware, this might be made up!)
