In [1]:
from syncode import SyncodeLogitsProcessor, Grammar
from transformers import AutoModelForCausalLM, AutoTokenizer
import lark

# Lark grammar for a comma-separated list of quoted names. Each name may be
# wrapped in single or double quotes; no whitespace is allowed between items.
# NOTE(review): the grammar also admits Alice, Carol and Eve although the prompt
# asks for male names -- presumably deliberate, to see what the constraint
# forces the model into; confirm.
grammar_str = r"""
start: item ("," item)* 

item: "'" name "'"
    | "\"" name "\""

name: "Alice" 
    | "Bob" 
    | "Carol" 
    | "Dave"
    | "Eve"
"""

# Single source of truth for placement: the whole model goes onto `device`,
# and the prompt tensor is later moved to wherever the model actually lives
# (`model.device`), so model weights and inputs can never end up on
# different devices.
device = "cuda"
model_name = "deepseek-ai/deepseek-coder-1.3b-base"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The same grammar drives both the constrained decoder (SynCode) and an
# independent Lark parser that is used afterwards to validate completions.
syncode_grammar = Grammar(grammar_str)
parser = lark.Lark(grammar_str)

prompt = "A list of male first names:\n"

# Token ids of the prompt, on the same device as the model's parameters.
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

# Toggle grammar-constrained decoding on/off to compare against the raw model.
constrain = True

# Generation settings passed straight to `model.generate`.
# do_sample=True together with num_beams > 1 selects beam-search sampling;
# Hugging Face requires num_return_sequences <= num_beams in that mode.
args = {
    "max_new_tokens" : 128,
    "do_sample" : True,
    "num_beams" : 2,
    "num_return_sequences" : 2,
    "pad_token_id" : tokenizer.eos_token_id,
}

# num_samples is tied to num_beams: with a single prompt, beam search decodes
# num_beams sequences in parallel.
# NOTE(review): presumably SynCode needs one parser state per parallel
# sequence -- confirm against the SynCode docs if batch size or decoding mode change.
syncode_logits_processor = SyncodeLogitsProcessor(
    grammar=syncode_grammar,
    tokenizer=tokenizer,
    parse_output_only=True,
    num_samples=args["num_beams"],
    mode="grammar_strict",
)

  from .autonotebook import tqdm as notebook_tqdm
Unrecognized keys in `rope_scaling` for 'rope_type'='linear': {'type'}


In [9]:
# Reset the SynCode processor before each generation run.
# NOTE(review): presumably this clears parser state left over from a previous
# execution of this cell; confirm against the SynCode docs.
syncode_logits_processor.reset(prompt)

# Raw sequences from generate(): for this decoder-only model each row is the
# prompt tokens followed by the newly generated tokens.
generated = model.generate(
    inputs,
    logits_processor=[syncode_logits_processor] if constrain else None,
    **args,
)

# Strip the prompt so only the model's continuation remains.
prompt_len = inputs.shape[1]
outputs = [seq[prompt_len:] for seq in generated]
completions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
# Per-token decode (special tokens kept) to inspect how each completion was
# split into tokens under the grammar constraint.
completions_tokens = [[tokenizer.decode(tok) for tok in seq] for seq in outputs]
print(completions)


# Re-check every completion with the independent Lark parser. Only Lark's own
# errors mean "does not parse"; anything else is a real bug and should surface.
for i, (completion, toks) in enumerate(zip(completions, completions_tokens)):
    print(f">>> COMPLETION {i}\n")
    try:
        parser.parse(completion)
        print("CAN PARSE\n")
    except lark.exceptions.LarkError as err:
        print("CANNOT PARSE\n")
        print(f"{err}\n")
    print(f"{completion}\n")
    print(f"{toks}\n")

["'Alice'", '"Alice"']
>>> COMPLETION 0

CAN PARSE

'Alice'

["'", 'A', 'lic', 'e', "'", '<｜end▁of▁sentence｜>']

>>> COMPLETION 1

CAN PARSE

"Alice"

['"', 'A', 'lic', 'e', '"', '<｜end▁of▁sentence｜>']

