In [9]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import guidance
from textwrap import dedent
import json 
from openai import OpenAI
import os 
from dotenv import load_dotenv

import grammar_guide as gg

load_dotenv()

True

In [10]:
# Build the OpenAI client used by `openai_generate`. The API key is read from
# the OPENAI_API_KEY environment variable (load_dotenv() runs in the imports cell).
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)

In [11]:
# Checkpoint identifier shared by the tokenizer and model loaded here and by
# the guidance draft models constructed in later cells.
model_name_or_path = "HuggingFaceTB/SmolLM-135M"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
# Parser built from the JSON Lark grammar file (relative path).
parser = gg.load_parser(
    lark_grammar_filepath="../grammars/json.lark",
)

In [12]:
def openai_generate(s: str) -> str:
    """Send ``s`` to the OpenAI chat completions endpoint and return the reply text.

    The text is submitted as a single message with role ``"assistant"`` to the
    ``gpt-3.5-turbo`` model through the module-level ``client``.

    Args:
        s: Text to send as the message content.

    Returns:
        The content of the first returned choice, or ``""`` when the response
        carries no text (``message.content`` is ``None``), so the declared
        ``str`` return type always holds.
    """
    chat_completion = client.chat.completions.create(
        messages=[{"role": "assistant", "content": s}],
        model="gpt-3.5-turbo",
    )
    content = chat_completion.choices[0].message.content
    # `content` is optional in the response; normalise None to an empty string.
    return content if content is not None else ""

In [14]:
# JSON-extraction prompt text.
# NOTE(review): the gg.guide call below passes its own literal SQL prompt, so
# this variable is not consumed by that call — confirm whether it is still needed.
prompt = dedent("""
Here is a really long, nested JSON that extracts fields from this sentence:\n\nMy name is Joseph Smith, and I work at Apple. I'm 32 years old, and my interests include kayaking, skiing, snowboarding, and woodworking.\n\n```json\n
""")
# Guided generation against the SQL grammar: the OpenAI-backed callable is the
# main model, and a guidance Transformers model is passed as the draft model.
res = gg.guide(
    model=openai_generate,
    parser=gg.load_parser(lark_grammar_filepath="../grammars/sql.lark"),
    prompt="Here's a long, complex SQL function:",
    draft_model=guidance.models.Transformers(model_name_or_path, echo=False),
    max_grammar_corrections=20,
    verbose=True,
)
# Reported processing time, then tokens per second (token count taken from the
# local tokenizer applied to the response), then the response text itself.
print(res.process_time_seconds)
print(len(tokenizer(res.response)["input_ids"]) / res.process_time_seconds)
print(res.response)

SELECT 
DISPENSE_STATUS,
SUM(CASE WHEN DISPENSE_STATUS = 'completed' THEN 1 ELSE 0 END) AS completed_count,
SUM(CASE WHEN DISPENSE_STATUS = 'in progress' THEN 1 ELSE 0 END) AS in_progress_count,
SUM(CASE WHEN DISPENSE_STATUS = 'pending' THEN 1 ELSE 0 END) AS pending_count
FROM DISPENSE
GROUP BY DISPENSE_STATUS;
1.8693230152130127
63.65940986739509
SELECT 
DISPENSE_STATUS,
SUM(CASE WHEN DISPENSE_STATUS = 'completed' THEN 1 ELSE 0 END) AS completed_count,
SUM(CASE WHEN DISPENSE_STATUS = 'in progress' THEN 1 ELSE 0 END) AS in_progress_count,
SUM(CASE WHEN DISPENSE_STATUS = 'pending' THEN 1 ELSE 0 END) AS pending_count
FROM DISPENSE
GROUP BY DISPENSE_STATUS;


In [15]:
# Display the `correction_log` attribute of the most recent `res`.
res.correction_log

[]

In [17]:
# Same SQL prompt and grammar, with the local Hugging Face `model` as the main model.
# `guide` and `load_parser` are accessed through the `gg` alias: the imports
# cell only does `import grammar_guide as gg`, so the bare names raise
# NameError on a fresh kernel (Restart & Run All).
res = gg.guide(
    model,
    tokenizer=tokenizer,
    parser=gg.load_parser(lark_grammar_filepath="../grammars/sql.lark"),
    prompt="Here's a long, complex SQL function:",
    draft_model=guidance.models.Transformers(
        model_name_or_path, echo=False
    ),
    stop_at=['```', ';'],
    max_grammar_corrections=20,
    verbose=True,
    max_new_tokens=20,
    temperature=0.0,
)
# Reported processing time, then tokens per second of the response
# (token count taken from the local tokenizer).
print(res.process_time_seconds)
print(len(tokenizer(res.response)['input_ids']) / res.process_time_seconds)

[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a draft_gen correction...[39m
[33mMade a single_candidate correction...[39m


29.74561882019043
8.942493400723848


In [18]:
# Pretty-print the response when it parses as JSON; otherwise print it verbatim.
try:
    print(json.dumps(json.loads(res.response), indent=4))
except (ValueError, TypeError):
    # json.JSONDecodeError subclasses ValueError; TypeError covers a response
    # that is not str/bytes. A bare `except:` would also swallow
    # KeyboardInterrupt and SystemExit.
    print(res.response)

SELECT * FROM table WHERE column1 = 'a' AND column2 = 'b' AND column3 =('c'*2) AND column4 = 'd' AND column5 = 'e' AND column6 =('f'*2) AND column7 = 'g' AND column8 = 'h' AND(column1 = 'a' AND column2 = 'b' AND column3 = 'c'AND column4 = 'd' AND column5 = 'e' AND column6 = 'f' AND(column7 = 'g' AND column8 = 'h' AND column1 = 'a'AND column2 = 'b' AND column3 = 'c' AND column4 = 'd' AND(column7 = 'g' AND column8 = 'h' AND column1 = ('a' * 2)) AND column9 = 'i' AND column10 = 'j' AND column= 'k' AND column= 'l' AND column= 'm' AND column= ('' * 2)) AND column11 = 'n' AND column12 = ('' * 2)) AND column12 = 'o' AND column1 =  (('' * 2)+' AND column1 = ');
