In [1]:
import os
device = 'cuda'
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # set your cuda device
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
import ctrlg
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

# load the pretrained base_model and hmm_model; see README.md for a complete list of 
# released checkpoints. note that the hmm_model and base_model must share the same 
# vocabulary of tokens: i.e., one cannot apply hmm_gpt2-large_common-gen_4096 to 
# tulu2-7b_writing-prompts. To apply Ctrl-G to a custom base_model or to achieve 
# best performance on a specific domain, users would need to distill an hmm_model
# from the base_model. Please refer to tutorial_distillation.ipynb for details.
BASE_MODEL_PATH = f'ctrlg/gpt2-large_common-gen' # a gpt2-large checkpoint domain adapted to the common-gen corpus
HMM_MODEL_PATH = f'ctrlg/hmm_gpt2-large_common-gen_4096' # alternatively 'ctrlg/hmm_gpt2-large_common-gen_32768' for better quality

base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH).to(device)
base_model.eval()
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
hmm_model = ctrlg.HMM.from_pretrained(HMM_MODEL_PATH).to(device)

In [2]:
vocab_size = hmm_model.vocab_size
eos_token_id = hmm_model.eos_token_id


##################################### prefix, suffix, prompt #####################################
prefix = '' # generate text starting with nothing
suffix = '<|endoftext|>' # generate text ending with '<|endoftext|>'; a suffix must end with the eos token
prompt = '<|endoftext|>' # prompt the base model with the '<|endoftext|>' token

prefix_ids = tokenizer.encode(prefix)
suffix_ids = tokenizer.encode(suffix)
prompt_ids = tokenizer.encode(prompt)
##################################### prefix, suffix, prompt #####################################


##################################### DFA Construction #####################################
my_flat_json_builder = ctrlg.myFlatJsonBuilder(tokenizer, vocab_size)

dfa_graphs = []

dfa_graphs.append(my_flat_json_builder.build())

# taking the intersection of the DFAs, i.e., "logical and" of the constraints.
# This function also minimizes the constructed DFA, which is mainly CPU-based operations;
# Due to its pure python implemenation, DFA minimization can be slow for complex constraints
dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='intersection')

# compile the dfa_graph for efficient GPU execution
dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)
##################################### DFA Construction #####################################


##################################### token length #####################################
# specify the min_new_tokens and max_new_tokens to be generated (excluding
# the prefix and suffix) make sure that the numbers here would not conflict
# with the given constraint: e.g. ask the model to generate 10 words with
# max_new_tokens = 8
min_new_tokens = 5
max_new_tokens = 32
##################################### token length #####################################

In [3]:
# initialze the constraints logits processor
# Note: this part pre-computes & cache certain conditional probability tables;
# one simple optimization is to re-use the same constraint_logits_processor for
# base_model.generate if the constraints do not change.
constraint_logits_processor = ctrlg.ConstraintLogitsProcessor(
    hmm_model, dfa_model,
    min_new_tokens, max_new_tokens,
    prompt_ids, prefix_ids=prefix_ids, suffix_ids=suffix_ids)


# set beam_size for beam search; usually the larger the beam_size the
# higher the generation quality
beam_size = 128

# set the hmm_batch_size depending on the resource available;
# uses more memory with larger hmm_batch_size but attains best speed 
# when it is set to beam_size
constraint_logits_processor.hmm_batch_size = beam_size

# generate with beam search
input_ids = torch.tensor([prompt_ids], device=device)
outputs = base_model.generate(
        input_ids=input_ids, do_sample=False, length_penalty=0.2,
        num_beams=beam_size, num_return_sequences=beam_size,
        min_new_tokens=min_new_tokens, max_new_tokens=max_new_tokens,
        logits_processor=LogitsProcessorList([constraint_logits_processor]),
        pad_token_id=tokenizer.eos_token_id,
    )

In [4]:
# extract the generated ids; removing prompt ids; remove suffix ids that are (partially) generated
generated_ids = ctrlg.extract_generated_ids(outputs.tolist(), prompt_ids, suffix_ids, eos_token_id)

# rank the generated ids by the base_model probability
generated_ids = ctrlg.rank_generated_ids(base_model, generated_ids, prompt_ids, suffix_ids, length_penalty=0.2)

# print top 10 outputs
want = -1
for idx, generated in enumerate(generated_ids[:want]):
    print(f'{idx}. ' + tokenizer.decode(prefix_ids, skip_special_tokens=True) + \
          '\033[1m' + tokenizer.decode(generated, skip_special_tokens=True) + '\033[0m' + \
          tokenizer.decode(suffix_ids, skip_special_tokens=True))

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


0. [1m { "" : "" }[0m
1. [1m { ": " }[0m
2. [1m { ": "}[0m
3. [1m { " : " }[0m
4. [1m { " : "" }[0m
5. [1m{ ": " }[0m
6. [1m{ " : " }[0m
7. [1m{ ": "}[0m
8. [1m {" : "" }[0m
9. [1m { "" : " }[0m
10. [1m { " :" }[0m
11. [1m {" : " }[0m
12. [1m { ":..." }[0m
13. [1m { ":" }[0m
14. [1m { ":=" }[0m
15. [1m { ":-" }[0m
16. [1m { ":?" }[0m
17. [1m { ":":" }[0m
18. [1m { ":}" }[0m
19. [1m {": "}[0m
20. [1m { ":"; }[0m
21. [1m { ":"> }[0m
22. [1m { " : "+ }[0m
23. [1m { ": "" }[0m
24. [1m { "_: " }[0m
25. [1m { "::" }[0m
26. [1m {" : "}[0m
27. [1m { "$ : " }[0m
28. [1m { "/ : " }[0m
29. [1m { "$:?" }[0m
30. [1m { " : "}[0m
31. [1m { " : "# }[0m
32. [1m { " : "/ }[0m
33. [1m { "_: "}[0m
34. [1m { "# : " }[0m
35. [1m { " : "+}[0m
36. [1m { ":"," }[0m
37. [1m { ": \" }[0m
38. [1m { ":," }[0m
39. [1m { "$:=" }[0m
40. [1m { " : ". }[0m
41. [1m { ":\" }[0m
42. [1m { ":", }[0m
43. [1m { ":","}[0m
44. [1m { ":{" }