In [2]:
import torch
from syncode import SyncodeLogitsProcessor
from syncode import Grammar
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# --- Configuration -----------------------------------------------------------
# Hugging Face cache directory and access token are read from the environment
# so that no credential is ever written into the notebook itself.
HF_CACHE = os.environ.get('HF_CACHE', 'cache/')
HF_ACCESS_TOKEN = os.environ.get('HF_ACCESS_TOKEN')  # None when the variable is unset

# Use the GPU when one is visible to torch; otherwise fall back to the CPU so
# the notebook still runs (slowly) instead of failing on `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = "meta-llama/Llama-3.2-1B-Instruct"
# model_name = "meta-llama/Llama-3.1-8B-Instruct"

# --- Model and tokenizer -----------------------------------------------------
# NOTE(review): trust_remote_code=True allows code shipped in the model repo to
# be executed locally. The Llama checkpoints presumably do not need it --
# confirm and drop the flag if so.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    cache_dir=HF_CACHE,
    token=HF_ACCESS_TOKEN,
    trust_remote_code=True,
).eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=HF_CACHE,
    token=HF_ACCESS_TOKEN,
    trust_remote_code=True,
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Target language for both the grammar and the request text.
# Alternatives tried in this notebook: "python", "go".
grammar_str = "java"

# Build the grammar and the logits processor that is handed to generate() below.
grammar = Grammar(grammar_str)
syncode_logits_processor = SyncodeLogitsProcessor(
    grammar=grammar,
    tokenizer=tokenizer,
    parse_output_only=True,
)

# Single-turn conversation, rendered to text with the model's chat template.
messages = [
    {
        "role": "user",
        "content": f"Write a {grammar_str} function that prints 'hello world' in reverse.",
    }
]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print("[PROMPT]", prompt, "\n")

# The processor is reset with the exact prompt string that is tokenized next.
# NOTE(review): the templated prompt already starts with a begin-of-text marker
# and is tokenized again here with default settings; before changing the
# tokenizer arguments, confirm how reset() tokenizes the prompt internally.
syncode_logits_processor.reset(prompt)

inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
attention_mask = torch.ones_like(inputs)

generation_kwargs = dict(
    attention_mask=attention_mask,
    max_length=512,
    num_return_sequences=1,
    pad_token_id=tokenizer.eos_token_id,
    logits_processor=[syncode_logits_processor],
)
output = model.generate(inputs, **generation_kwargs)

# Slice off the prompt tokens so only the newly generated text is decoded.
prompt_token_count = inputs.shape[1]
output_str = tokenizer.decode(output[0][prompt_token_count:], skip_special_tokens=True)
print("[OUTPUT]", output_str)

[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

Write a java function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

 

[OUTPUT] public class HelloWorld {
    public static void main(String[] args) {
        System.out.println("Hello World");
    } 

    public static void printReverse(String str) {
        char[] arr = str.toCharArray();
        int start = 0;
        int end = arr.length - 1;

        while (start < end) {
            System.out.print(arr[start]);
            System.out.print(arr[end]);
            start++;
            end--;
        } 
        System.out.println();
    } 
}
