In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import guidance
from textwrap import dedent
import json 

from grammar_guide import guide

In [2]:
model_name_or_path = "HuggingFaceTB/SmolLM-135M"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

In [3]:
res = guide(
    model,
    tokenizer,
    lark_grammar_filepath="../grammars/json.lark",
    seed_str="""{"name":""",
    prompt=dedent("""
    Here is a really long, nested JSON that extracts fields from this sentence:\n\nMy name is Joseph Smith, and I work at Apple. I'm 32 years old, and my interests include kayaking, skiing, snowboarding, and woodworking.\n\n```json\n
    """),
    draft_model=guidance.models.Transformers(
        model_name_or_path, echo=False
    ),
    max_grammar_corrections=10,
    max_new_tokens=10,
    temperature=0.3,
)
res.process_time_seconds

[33mMade a single_candidate correction...[39m
[33mMade a single_candidate correction...[39m
[33mMade a single_candidate correction...[39m
[33mMade a single_candidate correction...[39m
[33mMade a single_candidate correction...[39m


3.856700897216797

In [4]:
print(json.dumps(json.loads(res.response),indent=4))

{
    "name": "Joseph Smith",
    "age": 32,
    "occupation": "Kayaking",
    "location": "Arizona",
    "location_code": "1234",
    "location_name": "Smith"
}


In [5]:
res = guide(
    model,
    tokenizer,
    lark_grammar_filepath="../grammars/sql.lark",
    seed_str="""{"name":""",
    prompt=dedent("""
    Hello, I am your teacher. Today I will write you a SQL query demonstrating `INNER JOIN` and `LIMIT`.\n\n```sql\n
    """),
    draft_model=guidance.models.Transformers(
        model_name_or_path, echo=False
    ),
    max_grammar_corrections=10,
    max_new_tokens=10,
    temperature=0.3,
)
res.process_time_seconds

[33mMade a single_candidate correction...[39m
[33mMade a draft_gen correction...[39m
No candidates left


2.077803134918213

In [6]:
print(res.response)

SELECT * FROM Student WHERE Name = 'John' AND Age>=2000;

