In [1]:
from string import Template
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import guidance
from textwrap import dedent
import grammar_guide as gg
from transformers import set_seed

In [2]:
# Strings handed to gg.guide through its `stop_at` argument in the generation cells.
STOP_STRING_LIST = ["```", "}"]


def load_model(model_name_or_path: str) -> tuple:
    """Load a causal language model and its tokenizer, reusing EOS as the pad token.

    Args:
        model_name_or_path: Identifier or path forwarded unchanged to both
            ``AutoModelForCausalLM.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer's ``pad_token`` is set to its
        ``eos_token`` and the model's ``generation_config.pad_token_id`` is set to a
        single EOS id.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path, device_map="cuda" if torch.cuda.is_available() else None
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    # generation_config.eos_token_id can be a single id, a list of ids, or None,
    # whereas pad_token_id has to be one id. Pick the first listed id, and fall
    # back to the tokenizer's EOS id when the generation config does not define one.
    pad_token_id = model.generation_config.eos_token_id
    if isinstance(pad_token_id, (list, tuple)):
        pad_token_id = pad_token_id[0] if pad_token_id else None
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    model.generation_config.pad_token_id = pad_token_id
    return (model, tokenizer)

In [3]:
# Number of keys the prompt asks the model to put in the JSON object.
num_json_keys = 10

# NOTE: the prompt text is kept exactly as written (including the "mimick" spelling):
# it is the literal model input, so editing it would change what gets generated.
prompt = dedent(
    f"""
        This is an introduction to a prompt. It is intended to mimick the lengthy few-shot prompts we tend to use.
        Anyways, now I will get to my real point.
        Here is a JSON object, with {num_json_keys} keys, using only string values:\n\n```json\n
        """
)

# Path.read_text opens, reads and closes the grammar file in one step, so no file
# handle is left dangling. The template gets its own name rather than sharing
# `lark_grammar_str`, which always refers to the final grammar string.
lark_grammar_template = Template(Path("./benchmarks/json.lark").read_text())

# Fill the grammar's $NUM_REPEATS placeholder with num_json_keys - 1
# (presumably the first key/value pair is matched separately — confirm in json.lark).
# safe_substitute leaves any other `$` placeholders in the grammar untouched.
lark_grammar_str = lark_grammar_template.safe_substitute(
    NUM_REPEATS=str(num_json_keys - 1)
)

model_name_or_path = "HuggingFaceTB/SmolLM-135M"
model, tokenizer = load_model(model_name_or_path=model_name_or_path)

In [4]:
# Guided generation with token healing switched ON.
# Seed immediately before the call so this cell starts from the same RNG state each run.
set_seed(42)

# Scalar settings for this run, grouped so the one that matters (token_healing) stands out.
guide_settings_healing = {
    "max_grammar_corrections": 10,
    "max_new_tokens": 15,
    "temperature": 0.0,
    "token_healing": True,
    "verbose": True,
    "debug": False,
}

# The parser and draft model are built inline so they live only for this call.
res = gg.guide(
    model,
    tokenizer=tokenizer,
    parser=gg.load_parser(lark_grammar_str),
    prompt=prompt,
    draft_model=guidance.models.Transformers(model_name_or_path, echo=False),
    stop_at=STOP_STRING_LIST,
    **guide_settings_healing,
)

Forward pass:

        This is an introduction to a prompt. It is intended to mimick the lengthy few-shot prompts we tend to use.
        Anyways, now I will get to my real point.
        Here is a JSON object, with 10 keys, using only string values:

```json


[33mMade a single_candidate correction...[39m


Previous kv cache size: 78
New size: 75
Previous kv cache size: 75
New size: 75


[33mMade a single_candidate correction...[39m


Previous kv cache size: 92
New size: 82
Forward pass:
 
Previous kv cache size: 83
New size: 83


[33mMade a single_candidate correction...[39m


Previous kv cache size: 100
New size: 99


AssertionError: 

In [8]:
# Guided generation with token healing switched OFF; every other setting matches
# the token_healing=True cell.
# Re-seed so this run starts from the same RNG state as that one.
set_seed(42)

# Scalar settings for this run, grouped so the one that matters (token_healing) stands out.
guide_settings_no_healing = {
    "max_grammar_corrections": 10,
    "max_new_tokens": 15,
    "temperature": 0.0,
    "token_healing": False,
    "verbose": True,
    "debug": False,
}

# The parser and draft model are built inline so they live only for this call.
res = gg.guide(
    model,
    tokenizer=tokenizer,
    parser=gg.load_parser(lark_grammar_str),
    prompt=prompt,
    draft_model=guidance.models.Transformers(model_name_or_path, echo=False),
    stop_at=STOP_STRING_LIST,
    **guide_settings_no_healing,
)

Forward pass:

        This is an introduction to a prompt. It is intended to mimick the lengthy few-shot prompts we tend to use.
        Anyways, now I will get to my real point.
        Here is a JSON object, with 10 keys, using only string values:

```json


[33mMade a single_candidate correction...[39m


Previous kv cache size: 78
New size: 75


[33mMade a draft_gen correction...[39m
[33mMade a single_candidate correction...[39m


Previous kv cache size: 107
New size: 96


[33mMade a single_candidate correction...[39m


Previous kv cache size: 102
New size: 99


[33mMade a single_candidate correction...[39m
[33mMade a single_candidate correction...[39m
[33mMade a single_candidate correction...[39m


Previous kv cache size: 103
New size: 100
Previous kv cache size: 103
New size: 104


[33mMade a single_candidate correction...[39m
[33mMade a single_candidate correction...[39m


Previous kv cache size: 121
New size: 123


[33mMade a single_candidate correction...[39m
[31mCannot find a valid prediction after 10 retries[39m


In [25]:
# Show the generated text stored on `res` (the value returned by whichever gg.guide
# cell ran last). print() is used so the newlines are rendered, not shown as "\n".
print(res.response)

{
 "name": "John",
 "age": "20",
 "city": "New York",
 "email": "<","
 ":"
 ",":
 ":"
 "
