## Ctrl-G Tutorial

### **Part A**. Ctrl-G on GPT2-large (less computation required)

**Step 1. load pretrained models**

In [6]:
import os
device = 'cuda'
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # set your cuda device
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
import ctrlg
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

# load the pretrained base_model and hmm_model; see README.md for a complete list of 
# released checkpoints. note that the hmm_model and base_model must share the same 
# vocabulary of tokens: i.e., one cannot apply hmm_gpt2-large_common-gen_4096 to 
# tulu2-7b_writing-prompts. To apply Ctrl-G to a custom base_model or to achieve 
# best performance on a specific domain, users would need to distill an hmm_model
# from the base_model. Please refer to tutorial_distillation.ipynb for details.
BASE_MODEL_PATH = 'ctrlg/gpt2-large_common-gen' # a gpt2-large checkpoint domain adapted to the common-gen corpus
HMM_MODEL_PATH = 'ctrlg/hmm_gpt2-large_common-gen_4096' # alternatively 'ctrlg/hmm_gpt2-large_common-gen_32768' for better quality

base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH).to(device)
base_model.eval()
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
hmm_model = ctrlg.HMM.from_pretrained(HMM_MODEL_PATH).to(device)



**Step 2. specify logical constraints as DFAs (example constraint 1)**

In [2]:
vocab_size = hmm_model.vocab_size
a = ctrlg.populate_edge(["Step"], vocab_size, tokenizer)
[i for i, x in enumerate(a) if x]

[8600]
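The cell above lists the indices where the result of `populate_edge(["Step"], ...)` is truthy. That suggests an edge label is a vocabulary-sized mask marking which token ids may fire that edge. Below is a minimal pure-Python sketch of the idea. `edge_mask` and `toy_vocab` are hypothetical stand-ins, not ctrlg's implementation, and the sketch assumes every word maps to a single token.

```python
from typing import Dict, List

def edge_mask(words: List[str], vocab: Dict[str, int], vocab_size: int) -> List[bool]:
    """Return a vocab-sized mask that is True at the token id of each word.

    Hypothetical helper: assumes every word is exactly one token in `vocab`.
    """
    mask = [False] * vocab_size
    for word in words:
        if word not in vocab:
            raise KeyError(f"{word!r} is not a single token in this vocabulary")
        mask[vocab[word]] = True
    return mask

# toy vocabulary; the ids are made up for illustration
toy_vocab = {"Step": 3, "1)": 5, "->": 7}
mask = edge_mask(["Step", "->"], toy_vocab, vocab_size=10)
print([i for i, allowed in enumerate(mask) if allowed])  # [3, 7]
```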

In [42]:
tokenizer.encode("V5")

[1, 478, 29945]
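The output above shows that "V5" is encoded as a leading BOS id (1) followed by two tokens (478, 29945). A DFA consumes one token per transition, so a multi-token name needs a chain of states. This is why the hand-built graph later in this notebook uses separate `'V'` and digit edges. The sketch below builds such chains as a trie, so that names sharing a prefix also share states and the result stays deterministic. `chain_edges` is a hypothetical helper, not part of ctrlg, and its edge labels are sets of token ids rather than ctrlg masks.

```python
from typing import Dict, FrozenSet, List, Sequence, Tuple

Edge = Tuple[int, int, FrozenSet[int]]

def chain_edges(sequences: List[Sequence[int]], start: int, end: int,
                next_state: int) -> Tuple[List[Edge], int]:
    """Build trie-shaped edges so every token-id sequence leads from `start` to `end`.

    Sequences sharing a prefix share intermediate states. Raises ValueError if one
    sequence is a prefix of another, because the shared token would then need two
    different targets. Returns the edges and the next unused state id.
    """
    trans: Dict[Tuple[int, int], int] = {}              # (src, token) -> dst
    prefix_state: Dict[Tuple[int, ...], int] = {(): start}
    for seq in sequences:
        if not seq:
            raise ValueError("empty token sequence")
        for i, tok in enumerate(seq):
            src = prefix_state[tuple(seq[:i])]
            key = tuple(seq[:i + 1])
            last = i == len(seq) - 1
            if last:
                dst = end
            elif key in prefix_state:
                dst = prefix_state[key]
            else:
                dst = next_state
                next_state += 1
            if trans.get((src, tok), dst) != dst:
                raise ValueError(f"{list(seq)} clashes with another sequence at token {tok}")
            trans[(src, tok)] = dst
            if not last:
                prefix_state[key] = dst
    # merge tokens that share (src, dst) into one edge label
    grouped: Dict[Tuple[int, int], set] = {}
    for (src, tok), dst in trans.items():
        grouped.setdefault((src, dst), set()).add(tok)
    edges = [(src, dst, frozenset(toks)) for (src, dst), toks in sorted(grouped.items())]
    return edges, next_state

# 478 and 29945 are the ids tokenizer.encode("V5") returned above (BOS id 1 dropped);
# 29906 is a made-up id standing in for another digit
edges, nxt = chain_edges([[478, 29945], [478, 29906]], start=0, end=1, next_state=2)
print(edges)
print(nxt)  # 3
```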

In [17]:
vocab_size = hmm_model.vocab_size
eos_token_id = hmm_model.eos_token_id


##################################### prefix, suffix, prompt #####################################
prefix = "The overall probability of alarm set by husband is 77%. For husbands that don't set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in 'var1 -> var2' format, separated by commas:"

#"The overall probability of alarm set by husband is 77%. For husbands that don't set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%.Is ringing alarm more likely than silent alarm overall?Guidance: Address the question by following the steps below:\nStep 1) Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in 'var1 -> var2' format, separated by commas. Step 2) Determine the query type: Identify the type of query implied by the main question. Choices include 'marginal probability', 'conditional probability', 'explaining away effect', 'backdoor adjustment set', 'average treatment effect', 'collider bias', 'normal counterfactual question', 'average treatment effect on treated', 'natural direct effect' or 'natural indirect effect'. Your answer should only be a term from the list above, enclosed in quotation marks. Step 3) Formalize the query: Translate the query into its formal mathematical expression based on its type, utilizing the 'do(·)' notation or counterfactual notations as needed. Step 4) Extract all the available data. Your answer should contain nothing but marginal probabilities and conditional probabilities in the form 'P(...)=...' or 'P(...|...)=...', each probability being separated by a semicolon. Stick to the previously mentioned denotations for the variables. Step 5) Given all the information above, deduce the estimand using skills such as do-calculus, counterfactual prediction, and the basics of probabilities. Answer step by step. Step 6) Insert the relevant data in Step 4 into the estimand, perform basic arithmetic calculations, and derive the final answer. There is an identifiable answer. Answer step by step. 
#\nBased on all the reasoning above, output one word to answer the initial question."
suffix = '<|endoftext|>' # generate text ending with '<|endoftext|>'; a suffix must end with the eos token
#prompt = "The overall probability of alarm set by husband is 77%. For husbands that don't set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%.Is ringing alarm more likely than silent alarm overall?Guidance: Address the question by following the steps below:\nStep 1) Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in 'var1 -> var2' format, separated by commas. Step 2) Determine the query type: Identify the type of query implied by the main question. Choices include 'marginal probability', 'conditional probability', 'explaining away effect', 'backdoor adjustment set', 'average treatment effect', 'collider bias', 'normal counterfactual question', 'average treatment effect on treated', 'natural direct effect' or 'natural indirect effect'. Your answer should only be a term from the list above, enclosed in quotation marks. Step 3) Formalize the query: Translate the query into its formal mathematical expression based on its type, utilizing the 'do(·)' notation or counterfactual notations as needed. Step 4) Extract all the available data. Your answer should contain nothing but marginal probabilities and conditional probabilities in the form 'P(...)=...' or 'P(...|...)=...', each probability being separated by a semicolon. Stick to the previously mentioned denotations for the variables. Step 5) Given all the information above, deduce the estimand using skills such as do-calculus, counterfactual prediction, and the basics of probabilities. Answer step by step. Step 6) Insert the relevant data in Step 4 into the estimand, perform basic arithmetic calculations, and derive the final answer. There is an identifiable answer. Answer step by step:<|endoftext|>"
prompt = "Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in 'var1 -> var2' format, separated by commas.<|endoftext|>"
# the prompt fed to the base model; here it ends with the '<|endoftext|>' token

prefix_ids = tokenizer.encode(prefix)
suffix_ids = tokenizer.encode(suffix)
prompt_ids = tokenizer.encode(prompt)
##################################### prefix, suffix, prompt #####################################
##################################### DFA Construction #####################################
# ac_builder constructs a DFA representing the constraint that (at least) 
# one of the patterns must appear; a pattern is a sequence of token ids
ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)
# word_count_builder constructs a DFA representing the constraint that 
# the generated text consists of a to b words; refer to the source code of
# WordCountBuilder for the definition of a word.
word_count_builder = ctrlg.WordCountBuilder(tokenizer, vocab_size)

dfa_graphs = []

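# constraint 1: the output must contain (i) one of the variable names
# ' X ', ' Y ', ' V2 ', (ii) the arrow ' ->', and (iii) one of the variable
# names; intersecting the groups does not enforce this order, and since
# groups (i) and (iii) are identical, (iii) adds no new requirement.
# (check with tokenizer.encode: the trailing space in ' X ' may become its own token)
keyphrases = [[' X ', ' Y ', ' V2 '],
              [' ->'],
              [' X ', ' Y ', ' V2 ']]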


for keyphrase in keyphrases:
    patterns = [tokenizer.encode(x) for x in keyphrase]
    dfa_graphs.append(ac_builder.build(patterns))

# constraint 2: generate between 3 and 20 words
# word_count_builder constructs a DFA representing the constraint that 
# the generated text must contain a to b words
a, b = 3, 20
dfa_graphs.append(word_count_builder.build(a, b))

# take the intersection of the DFAs, i.e., the "logical and" of the constraints.
# This function also minimizes the resulting DFA, which is mostly CPU-bound;
# because it is implemented in pure Python, DFA minimization can be slow for complex constraints
dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='intersection')

# compile the dfa_graph for efficient GPU execution
dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)
##################################### DFA Construction #####################################


##################################### token length #####################################
# specify min_new_tokens and max_new_tokens, the number of tokens to generate
# (excluding the prefix and suffix). Make sure these numbers do not conflict
# with the constraint: e.g., asking for 10 words with max_new_tokens = 8
# cannot succeed, since each word takes at least one token
min_new_tokens = 5
max_new_tokens = 32
##################################### token length #####################################
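`DFA_prod(..., mode='intersection')` combines constraints with a logical AND. The textbook way to intersect DFAs is the product construction: run every automaton in lockstep on the same tokens and accept only when all of them accept. Below is a small pure-Python illustration over string tokens. It is not ctrlg's implementation: product states are built on the fly rather than materialized and minimized. The tuple-based DFA format and the rule that "a token starting with a space opens a new word" are assumptions; the comments above point to WordCountBuilder's source for ctrlg's actual definition of a word.

```python
from typing import Callable, Hashable, Iterable, Tuple

# toy DFA: (start_state, step(state, token) -> state, accept(state) -> bool)
DFA = Tuple[Hashable, Callable[[Hashable, str], Hashable], Callable[[Hashable], bool]]

def contains_dfa(token: str) -> DFA:
    """Accepts once `token` has appeared at least once (state: seen yet?)."""
    return (False, lambda seen, tok: seen or tok == token, lambda seen: seen)

def word_count_dfa(a: int, b: int) -> DFA:
    """Accepts a to b words; assumes a token starting with a space opens a new word."""
    def step(count, tok):
        if count is None:              # dead state: already more than b words
            return None
        count += tok.startswith(" ")
        return count if count <= b else None
    return (0, step, lambda count: count is not None and a <= count <= b)

def intersect(*dfas: DFA) -> DFA:
    """Product construction: state is a tuple of component states; accept iff all accept."""
    start = tuple(d[0] for d in dfas)
    def step(state, tok):
        return tuple(d[1](s, tok) for d, s in zip(dfas, state))
    def accept(state):
        return all(d[2](s) for d, s in zip(dfas, state))
    return (start, step, accept)

def accepts(dfa: DFA, tokens: Iterable[str]) -> bool:
    state, step, accept = dfa
    for tok in tokens:
        state = step(state, tok)
    return accept(state)

both = intersect(contains_dfa(" ->"), word_count_dfa(3, 20))
print(accepts(both, [" X", " ->", " Y"]))  # True
print(accepts(both, [" X", " Y", " V2"]))  # False: no arrow
print(accepts(both, [" X", " ->"]))        # False: only 2 words
```

Each factor only remembers whether its pattern has appeared. As a result, intersecting two identical groups (like the first and third entries of `keyphrases` above) accepts exactly the same sequences as one group. Nothing in the product forces the arrow to sit between the two variable names either: `[" ->", " X", " Y"]` is accepted too.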

In [4]:
vocab_size = hmm_model.vocab_size
eos_token_id = hmm_model.eos_token_id


##################################### prefix, suffix, prompt #####################################
prefix = '' # generate text starting with nothing
#"The overall probability of alarm set by husband is 77%. For husbands that don't set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%.Is ringing alarm more likely than silent alarm overall?Guidance: Address the question by following the steps below:\nStep 1) Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in 'var1 -> var2' format, separated by commas. Step 2) Determine the query type: Identify the type of query implied by the main question. Choices include 'marginal probability', 'conditional probability', 'explaining away effect', 'backdoor adjustment set', 'average treatment effect', 'collider bias', 'normal counterfactual question', 'average treatment effect on treated', 'natural direct effect' or 'natural indirect effect'. Your answer should only be a term from the list above, enclosed in quotation marks. Step 3) Formalize the query: Translate the query into its formal mathematical expression based on its type, utilizing the 'do(·)' notation or counterfactual notations as needed. Step 4) Extract all the available data. Your answer should contain nothing but marginal probabilities and conditional probabilities in the form 'P(...)=...' or 'P(...|...)=...', each probability being separated by a semicolon. Stick to the previously mentioned denotations for the variables. Step 5) Given all the information above, deduce the estimand using skills such as do-calculus, counterfactual prediction, and the basics of probabilities. Answer step by step. Step 6) Insert the relevant data in Step 4 into the estimand, perform basic arithmetic calculations, and derive the final answer. There is an identifiable answer. Answer step by step. 
#\nBased on all the reasoning above, output one word to answer the initial question."
suffix = '<|endoftext|>' # generate text ending with '<|endoftext|>'; a suffix must end with the eos token
#prompt = "The overall probability of alarm set by husband is 77%. For husbands that don't set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%.Is ringing alarm more likely than silent alarm overall?Guidance: Address the question by following the steps below:\nStep 1) Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in 'var1 -> var2' format, separated by commas. Step 2) Determine the query type: Identify the type of query implied by the main question. Choices include 'marginal probability', 'conditional probability', 'explaining away effect', 'backdoor adjustment set', 'average treatment effect', 'collider bias', 'normal counterfactual question', 'average treatment effect on treated', 'natural direct effect' or 'natural indirect effect'. Your answer should only be a term from the list above, enclosed in quotation marks. Step 3) Formalize the query: Translate the query into its formal mathematical expression based on its type, utilizing the 'do(·)' notation or counterfactual notations as needed. Step 4) Extract all the available data. Your answer should contain nothing but marginal probabilities and conditional probabilities in the form 'P(...)=...' or 'P(...|...)=...', each probability being separated by a semicolon. Stick to the previously mentioned denotations for the variables. Step 5) Given all the information above, deduce the estimand using skills such as do-calculus, counterfactual prediction, and the basics of probabilities. Answer step by step. Step 6) Insert the relevant data in Step 4 into the estimand, perform basic arithmetic calculations, and derive the final answer. There is an identifiable answer. Answer step by step:<|endoftext|>"
prompt = "The overall probability of alarm set by husband is 77%. For husbands that don't set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in 'var1 -> var2' format, separated by commas.<|endoftext|>"
# the prompt fed to the base model; here it ends with the '<|endoftext|>' token

prefix_ids = tokenizer.encode(prefix)
suffix_ids = tokenizer.encode(suffix)
prompt_ids = tokenizer.encode(prompt)
##################################### prefix, suffix, prompt #####################################


##################################### DFA Construction #####################################
# ac_builder constructs a DFA representing the constraint that (at least) 
# one of the patterns must appear; a pattern is a sequence of token ids
#ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)

# word_count_builder constructs a DFA representing the constraint that 
# the generated text consists of a to b words; refer to the source code of
# WordCountBuilder for the definition of a word.
#word_count_builder = ctrlg.WordCountBuilder(tokenizer, vocab_size)

#dfa_graphs = []

# constraint 1:
# one of ' riding a bike', ' ride bikes', ' rides a bike', ' biking', ' bikes' has to appear
# AND one of ' park', ' beach' has to appear
#keyphrases = [[' riding a bike', ' ride bikes', ' rides a bike', ' biking', ' bikes'],
#            [' park', ' beach']]
#for keyphrase in keyphrases:
#    patterns = [tokenizer.encode(x) for x in keyphrase]
#    dfa_graphs.append(ac_builder.build(patterns))

# constraint 2: generate exactly 10 words
# word_count_builder constructs a DFA representing the constraint that 
# the generated text must contain a to b words
#a, b = 10, 10
#dfa_graphs.append(word_count_builder.build(a, b))

# take the intersection of the DFAs, i.e., the "logical and" of the constraints.
# This function also minimizes the resulting DFA, which is mostly CPU-bound;
# because it is implemented in pure Python, DFA minimization can be slow for complex constraints
#dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='intersection')

# compile the dfa_graph for efficient GPU execution
#dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)
numbers = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']

dfa_graph = {
    "edges": [
    (0, 1, ctrlg.populate_edge(["Step"], vocab_size, tokenizer)),
    (1, 2, ctrlg.populate_edge(["1)"], vocab_size, tokenizer)),
    (2, 3, ctrlg.populate_edge(letters, vocab_size, tokenizer)),
    (3, 3, ctrlg.populate_edge(numbers, vocab_size, tokenizer)),
    (3, 2, ctrlg.populate_edge(["->"], vocab_size, tokenizer)),
    (3, 4, ctrlg.populate_edge(["",".",",",";"], vocab_size, tokenizer)),
    (4, 5, ctrlg.populate_edge(["Step"], vocab_size, tokenizer)),
    (5, 6, ctrlg.populate_edge(["2)"], vocab_size, tokenizer)),
    (6, 7, ctrlg.populate_edge(['"'], vocab_size, tokenizer)),
    (7, 8, ctrlg.populate_edge(['marginal', 'conditional'], vocab_size, tokenizer)),
    (8, 9, ctrlg.populate_edge(['probability'], vocab_size, tokenizer)),
    (7, 11, ctrlg.populate_edge(['explaining'], vocab_size, tokenizer)),
    (11, 12, ctrlg.populate_edge(['away'], vocab_size, tokenizer)),
    (12, 9, ctrlg.populate_edge(['effect'], vocab_size, tokenizer)),
    (7, 13, ctrlg.populate_edge(['backdoor'], vocab_size, tokenizer)),
    (13, 14, ctrlg.populate_edge(['adjustment'], vocab_size, tokenizer)),
    (14, 9, ctrlg.populate_edge(['set'], vocab_size, tokenizer)),
    # a single 'average' edge keeps state 7 deterministic; "average treatment effect"
    # and "average treatment effect on treated" share the path up to state 20
    (7, 15, ctrlg.populate_edge(['average'], vocab_size, tokenizer)),
    (15, 19, ctrlg.populate_edge(['treatment'], vocab_size, tokenizer)),
    (19, 20, ctrlg.populate_edge(['effect'], vocab_size, tokenizer)),
    (20, 10, ctrlg.populate_edge(['"'], vocab_size, tokenizer)),
    (20, 21, ctrlg.populate_edge(['on'], vocab_size, tokenizer)),
    (21, 22, ctrlg.populate_edge(['treated'], vocab_size, tokenizer)),
    (22, 10, ctrlg.populate_edge(['"'], vocab_size, tokenizer)),
    (7, 16, ctrlg.populate_edge(['collider'], vocab_size, tokenizer)),
    (16, 9, ctrlg.populate_edge(['bias'], vocab_size, tokenizer)),
    (7, 17, ctrlg.populate_edge(['normal'], vocab_size, tokenizer)),
    (17, 18, ctrlg.populate_edge(['counterfactual'], vocab_size, tokenizer)),
    (18, 9, ctrlg.populate_edge(['question'], vocab_size, tokenizer)),
    (7, 23, ctrlg.populate_edge(['natural'], vocab_size, tokenizer)),
    (23, 12, ctrlg.populate_edge(['direct','indirect'], vocab_size, tokenizer)),
    (9, 10, ctrlg.populate_edge(['"'], vocab_size, tokenizer)),
    (10, 24, ctrlg.populate_edge(["Step"], vocab_size, tokenizer)),
    (24, 25, ctrlg.populate_edge(["3)"], vocab_size, tokenizer)),
    (25, 25, ctrlg.populate_edge(vocab_size=vocab_size, ALL=True)), # note: overlaps with the "E"/"P" edge (25 -> 26) below
    (25, 26, ctrlg.populate_edge(["E","P"], vocab_size, tokenizer)),
    (26, 27, ctrlg.populate_edge(["(","["], vocab_size, tokenizer)),
    (27, 28, ctrlg.populate_edge(letters, vocab_size, tokenizer)),
    (28, 29, ctrlg.populate_edge(numbers+[""], vocab_size, tokenizer)),
    (29, 27, ctrlg.populate_edge([",",", "], vocab_size, tokenizer)),
    (29, 30, ctrlg.populate_edge(["|"], vocab_size, tokenizer)),
    (30, 31, ctrlg.populate_edge(letters, vocab_size, tokenizer)),
    (31, 32, ctrlg.populate_edge(numbers+[""], vocab_size, tokenizer)),
    (32, 30, ctrlg.populate_edge([",",", "], vocab_size, tokenizer)),
    (32, 33, ctrlg.populate_edge([")"], vocab_size, tokenizer)),
    (29, 33, ctrlg.populate_edge([")"], vocab_size, tokenizer)),
    (29, 34, ctrlg.populate_edge(["do"], vocab_size, tokenizer)),
    (34, 35, ctrlg.populate_edge(["("], vocab_size, tokenizer)),
    (35, 36, ctrlg.populate_edge(letters, vocab_size, tokenizer)),
    (36, 37, ctrlg.populate_edge(numbers+[""], vocab_size, tokenizer)),
    (37, 38, ctrlg.populate_edge(["="], vocab_size, tokenizer)),
    (38, 39, ctrlg.populate_edge(["0","1"], vocab_size, tokenizer)),
    (39, 40, ctrlg.populate_edge([")"], vocab_size, tokenizer)),
    (40, 41, ctrlg.populate_edge([","], vocab_size, tokenizer)),
    (41, 42, ctrlg.populate_edge(["do"], vocab_size, tokenizer)),
    (42, 35, ctrlg.populate_edge(["("], vocab_size, tokenizer)),
    (39, 35, ctrlg.populate_edge([","], vocab_size, tokenizer)),
    (42, 37, ctrlg.populate_edge([")"], vocab_size, tokenizer)),
    (37, 43, ctrlg.populate_edge(["]",""], vocab_size, tokenizer)),
    (43, 29, ctrlg.populate_edge([".", "+", "-", "*", "/", ""], vocab_size, tokenizer)),
    ],
    "initial_state": 0,
    "accept_states": set([43]),
}

#dfa_graphs.append(dfa_graph1)
#dfa_graphs.append(dfa_graph2)

#dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='concatenation')

# compile the dfa_graph for efficient GPU execution
dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)

##################################### DFA Construction #####################################


##################################### token length #####################################
# specify min_new_tokens and max_new_tokens, the number of tokens to generate
# (excluding the prefix and suffix). Make sure these numbers do not conflict
# with the constraint: e.g., asking for 10 words with max_new_tokens = 8
# cannot succeed, since each word takes at least one token
min_new_tokens = 5
max_new_tokens = 100
##################################### token length #####################################
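Because `dfa_graph` above is written by hand, it is easy to give one state two outgoing edges that share a token. That makes the automaton nondeterministic. For example, the `ALL` self-loop on state 25 overlaps with its edge on `"E"`/`"P"`, and routing the same word from one state to two different states has the same effect. The sketch below checks an edge list for such clashes before it is compiled. It assumes edge labels can be expressed as sets of token ids. `find_conflicts` and the token ids in the example are hypothetical.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Set, Tuple

def find_conflicts(edges: Iterable[Tuple[int, int, Set[int]]]) -> List[Tuple[int, int, Tuple[int, ...]]]:
    """Return (state, token, destinations) for each token that leaves a state
    toward more than one distinct destination, i.e. each source of nondeterminism."""
    targets: Dict[Tuple[int, int], Set[int]] = defaultdict(set)
    for src, dst, tokens in edges:
        for tok in tokens:
            targets[(src, tok)].add(dst)
    return sorted((src, tok, tuple(sorted(dsts)))
                  for (src, tok), dsts in targets.items() if len(dsts) > 1)

AVERAGE, TREATMENT = 101, 102          # hypothetical token ids
edges = [
    (7, 15, {AVERAGE}),
    (7, 19, {AVERAGE}),                # same token, different target
    (15, 12, {TREATMENT}),
    (25, 25, set(range(200))),         # "ALL"-style self-loop over a toy 200-token vocab
    (25, 26, {42}),                    # token 42 is also covered by the self-loop
]
print(find_conflicts(edges))  # [(7, 101, (15, 19)), (25, 42, (25, 26))]
```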

**Step 3. generate with constraints.**

Because the code uses `@torch.compile`, the first run of the following functions may be significantly slower than later runs.

In [18]:
# initialize the constraint logits processor
# Note: this step pre-computes and caches certain conditional probability tables;
# a simple optimization is to reuse the same constraint_logits_processor across
# base_model.generate calls if the constraints do not change.
constraint_logits_processor = ctrlg.ConstraintLogitsProcessor(
    hmm_model, dfa_model,
    min_new_tokens, max_new_tokens,
    prompt_ids, prefix_ids=prefix_ids, suffix_ids=suffix_ids)


# set beam_size for beam search; usually the larger the beam_size the
# higher the generation quality
beam_size = 64

# set hmm_batch_size according to the available resources; a larger
# hmm_batch_size uses more memory, and speed is best when it equals beam_size
constraint_logits_processor.hmm_batch_size = beam_size

# generate with beam search
input_ids = torch.tensor([prompt_ids], device=device)
outputs = base_model.generate(
        input_ids=input_ids, do_sample=False, length_penalty=0.2,
        num_beams=beam_size, num_return_sequences=beam_size,
        min_new_tokens=min_new_tokens, max_new_tokens=max_new_tokens,
        logits_processor=LogitsProcessorList([constraint_logits_processor]),
        pad_token_id=tokenizer.eos_token_id,
    )

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
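`ConstraintLogitsProcessor` plugs into `base_model.generate` through the standard `logits_processor` hook. As we understand Ctrl-G, it uses the HMM to estimate how likely each next token is to lead to an eventually satisfied constraint, and it reweights the base model's distribution accordingly. The sketch below is deliberately simpler. It uses plain hard masking: any token whose DFA transition leads to a dead state gets a score of -inf. It only shows where a DFA enters decoding and does not reproduce Ctrl-G's algorithm. The `transitions` table and `dead_state` are hypothetical, and plain lists stand in for tensors so the sketch runs without torch.

```python
import math
from typing import Dict, List, Tuple

def mask_scores(scores: List[float], state: int,
                transitions: Dict[Tuple[int, int], int], dead_state: int) -> List[float]:
    """Set the score of every token that would move the DFA into dead_state to -inf.

    Tokens with no entry in `transitions` are treated as leading to dead_state.
    """
    masked = []
    for tok, score in enumerate(scores):
        nxt = transitions.get((state, tok), dead_state)
        masked.append(score if nxt != dead_state else -math.inf)
    return masked

# toy setup: from state 0, only tokens 1 and 3 keep the constraint satisfiable
transitions = {(0, 1): 1, (0, 3): 2}
scores = [0.5, 1.2, 2.0, -0.3]
masked = mask_scores(scores, state=0, transitions=transitions, dead_state=-1)
best = max(range(len(masked)), key=masked.__getitem__)
print(masked, best)  # [-inf, 1.2, -inf, -0.3] 1
```

Masking alone only rules out tokens that break the constraint immediately. It does not favor tokens that make the constraint likely to be satisfied later within `max_new_tokens`; that lookahead is the role the HMM plays in Ctrl-G.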


**Step 4. extract & rank outputs via the base model.**

In [5]:
tokenizer.encode('A')

[1, 319]

In [3]:
tokenizer.decode([2])

'</s>'

In [19]:
# extract the generated ids: remove the prompt ids and any (partially) generated suffix ids
generated_ids = ctrlg.extract_generated_ids(outputs.tolist(), prompt_ids, suffix_ids, eos_token_id)

# rank the generated ids by the base_model probability
generated_ids = ctrlg.rank_generated_ids(base_model, generated_ids, prompt_ids, suffix_ids)

# print top 10 outputs
for idx, generated in enumerate(generated_ids[:10]):
    print(f'{idx}. ' + tokenizer.decode(prefix_ids, skip_special_tokens=True) + \
          '\033[1m' + tokenizer.decode(generated, skip_special_tokens=True) + '\033[0m' + \
          tokenizer.decode(suffix_ids, skip_special_tokens=True))

0. The overall probability of alarm set by husband is 77%. For husbands that don't set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in 'var1 -> var2' format, separated by commas:[Previous Chapter] [Table of Contents] [Next Chapter]

Transmigrator Meets Reincarnator

Chapter -> X
1. The overall probability of alarm set by husband is 77%. For husbands that don't set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm c
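Inside ctrlg, `extract_generated_ids` and `rank_generated_ids` handle this post-processing, and their exact rules are in its source. The sketch below is a rough, hypothetical approximation. It drops the prompt, cuts at the first EOS, strips any trailing piece of the suffix, removes duplicate beams, and then sorts by total log-probability. The names `extract` and `rank` are illustrative. The `logprob` callable stands in for scoring each candidate with `base_model`.

```python
from typing import Callable, List, Sequence

def extract(outputs: List[List[int]], prompt_ids: Sequence[int],
            suffix_ids: Sequence[int], eos_token_id: int) -> List[List[int]]:
    results: List[List[int]] = []
    for seq in outputs:
        gen = list(seq[len(prompt_ids):])            # drop the prompt
        if eos_token_id in gen:                      # cut at the first EOS
            gen = gen[:gen.index(eos_token_id)]
        # drop the longest tail that matches the start of the suffix
        for k in range(min(len(gen), len(suffix_ids)), 0, -1):
            if gen[-k:] == list(suffix_ids[:k]):
                gen = gen[:-k]
                break
        if gen not in results:                       # beams can collide after trimming
            results.append(gen)
    return results

def rank(candidates: List[List[int]],
         logprob: Callable[[List[int]], List[float]]) -> List[List[int]]:
    """Sort candidates by the sum of their per-token log-probabilities, best first."""
    return sorted(candidates, key=lambda ids: sum(logprob(ids)), reverse=True)

prompt, suffix, EOS = [5, 6], [9, 2], 2              # toy ids; the suffix ends with EOS
outputs = [[5, 6, 11, 12, 9, 2], [5, 6, 13, 9], [5, 6, 11, 12, 9, 2]]
cands = extract(outputs, prompt, suffix, EOS)
print(cands)  # [[11, 12], [13]]
print(rank(cands, lambda ids: [-0.5] * len(ids)))  # [[13], [11, 12]]
```

When every token has the same log-probability, the shorter candidate wins. This is one reason summed log-probabilities are often length-normalized.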

**Step 5. try some other constraints! (example constraint 2)**

In [16]:
vocab_size = hmm_model.vocab_size
eos_token_id = hmm_model.eos_token_id


prefix = ' on a fine sunny' # generate text starting with ' on a fine sunny'
suffix = ' in the park.<|endoftext|>' # generate text ending with ' in the park.<|endoftext|>'
prompt = '<|endoftext|> on a fine sunny' # prompt the base model with the '<|endoftext|>' token and the prefix

prefix_ids = tokenizer.encode(prefix)
suffix_ids = tokenizer.encode(suffix)
prompt_ids = tokenizer.encode(prompt)


ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)
word_count_builder = ctrlg.WordCountBuilder(tokenizer, vocab_size)

dfa_graphs = []
# constraint 1:
# one of ' girl', ' boy', ' girls', ' boys', ' children' AND
# one of ' dogs', ' cats', ' dog', ' cat' have to appear
# in the GIVEN ORDER.

keyphrases = [[' girl', ' boy', ' girls', ' boys', ' children'],
              [' dogs', ' cats', ' dog', ' cat']]
for keyphrase in keyphrases:
    patterns = [tokenizer.encode(x) for x in keyphrase]
    dfa_graphs.append(ac_builder.build(patterns))
# concatenate the patterns so they appear in the given order
dfa_graphs = [ctrlg.DFA_concatenate(dfa_graphs)]

# constraint 2: generate 7 - 12 words
a, b = 7, 12
dfa_graphs.append(word_count_builder.build(a, b))

dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='intersection')
dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)


min_new_tokens = 5
max_new_tokens = 32


# initialize the constraint logits processor
constraint_logits_processor = ctrlg.ConstraintLogitsProcessor(
    hmm_model, dfa_model,
    min_new_tokens, max_new_tokens,
    prompt_ids, prefix_ids=prefix_ids, suffix_ids=suffix_ids)


beam_size = 128
constraint_logits_processor.hmm_batch_size = beam_size
input_ids = torch.tensor([prompt_ids], device=device)
# generate with beam search
outputs = base_model.generate(
        input_ids=input_ids, do_sample=False,
        num_beams=beam_size, num_return_sequences=beam_size,
        min_new_tokens=min_new_tokens, max_new_tokens=max_new_tokens,
        logits_processor=LogitsProcessorList([constraint_logits_processor]),
        pad_token_id=tokenizer.eos_token_id,
    )

# extract the generated ids: remove the prompt ids and any (partially) generated suffix ids
generated_ids = ctrlg.extract_generated_ids(outputs.tolist(), prompt_ids, suffix_ids, eos_token_id)

# rank the generated ids by the base_model probability
generated_ids = ctrlg.rank_generated_ids(base_model, generated_ids, prompt_ids, suffix_ids)

# print top 10 outputs
for idx, generated in enumerate(generated_ids[:10]):
    print(f'{idx}. ' + tokenizer.decode(prefix_ids, skip_special_tokens=True) + \
          '\033[1m' + tokenizer.decode(generated, skip_special_tokens=True) + '\033[0m' + \
          tokenizer.decode(suffix_ids, skip_special_tokens=True))

0.  on a fine sunny[1m day a young girl is walking her dog[0m in the park.
1.  on a fine sunny[1m day a young boy and his dog are playing[0m in the park.
2.  on a fine sunny[1m day a young boy is playing with his dog[0m in the park.
3.  on a fine sunny[1m day a boy and his dog are playing[0m in the park.
4.  on a fine sunny[1m day a young boy and his dog are walking[0m in the park.
5.  on a fine sunny[1m day a young girl and her dog are playing[0m in the park.
6.  on a fine sunny[1m day a young girl and her dog are walking[0m in the park.
7.  on a fine sunny[1m day a young girl is playing with her dog[0m in the park.
8.  on a fine sunny[1m day a girl is walking her dog[0m in the park.
9.  on a fine sunny[1m day a young boy and his dog are relaxing[0m in the park.


### **Part B**. Ctrl-G on TULU2-7B (more computation required)

**Step 1. load pretrained models.**

In [2]:
!pip3 install protobuf

Collecting protobuf
  Downloading protobuf-5.28.0-cp38-abi3-manylinux2014_x86_64.whl (316 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m316.6/316.6 KB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: protobuf
Successfully installed protobuf-5.28.0


In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # set your cuda device
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
import ctrlg
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

device = 'cuda'


# load the pretrained base_model and hmm_model;
BASE_MODEL_PATH = 'ctrlg/tulu2-7b_writing-prompts'
HMM_MODEL_PATH = 'ctrlg/hmm_tulu2-7b_writing-prompts_32768'

base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH).to(device)
base_model.eval()
base_model.half() # fp16 inference
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
hmm_model = ctrlg.HMM.from_pretrained(HMM_MODEL_PATH).to(device)




You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


**Step 2. specify logical constraints as DFAs.**

In [2]:
## ORIGINAL VERSION ##
vocab_size = hmm_model.vocab_size
eos_token_id = hmm_model.eos_token_id # eos = end of sequence

prefix = 'Once upon a time, in a land far, far away, there was a kingdom. The kingdom was'
suffix = 'beautiful buildings. The people of this kingdom were known for their kindness and generosity, always ready to lend a helping hand.</s>'
soft_constraint = ' in fairytale style' # use empty string for no soft constraint
prompt = f'<|user|>\nContinue the given text{soft_constraint}:\n{prefix}\n<|assistant|>\n'

prefix_ids = tokenizer.encode(prefix)[1:]
suffix_ids = tokenizer.encode(suffix)[1:]
prompt_ids = tokenizer.encode(prompt)

ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)
eos_builder = ctrlg.EOSBuilder(vocab_size, eos_token_id)

dfa_graphs = []
keyphrases = [['towering'], ['reach the sky'], ['reflected'], ['lake']]
for keyphrase in keyphrases:
    patterns = [tokenizer.encode(x)[1:] for x in keyphrase]
    dfa_graphs.append(ac_builder.build(patterns))
dfa_graphs.append(eos_builder.build())

dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='intersection')
dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)

min_new_tokens = 16
max_new_tokens = 32
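Part B slices `[1:]` off each `tokenizer.encode(...)` result because the TULU2 tokenizer prepends a BOS token, as the `tokenizer.encode("V5")` output `[1, 478, 29945]` earlier in this notebook shows. An unconditional slice would silently drop a real token if a tokenizer did not add BOS. A slightly safer sketch strips the id only when it is present. `bos_token_id=1` matches that output but is an assumption for other tokenizers. With Hugging Face tokenizers, `tokenizer.encode(text, add_special_tokens=False)` avoids adding it in the first place. The second helper checks the tutorial's rule that a suffix must end with the EOS token (id 2, which decodes to `'</s>'` above). Both helpers are hypothetical.

```python
from typing import List, Optional

def strip_bos(ids: List[int], bos_token_id: Optional[int] = 1) -> List[int]:
    """Drop a leading BOS id if, and only if, it is present."""
    if ids and bos_token_id is not None and ids[0] == bos_token_id:
        return list(ids[1:])
    return list(ids)

def ends_with_eos(suffix_ids: List[int], eos_token_id: int) -> bool:
    """True if the suffix ends with the EOS token, as the tutorial requires."""
    return bool(suffix_ids) and suffix_ids[-1] == eos_token_id

print(strip_bos([1, 478, 29945]))  # [478, 29945]
print(strip_bos([478, 29945]))     # [478, 29945]
print(ends_with_eos([3, 2], 2))    # True (toy suffix ids)
```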

In [28]:
import pandas as pd
data = pd.read_csv('/media/data/bazaluk/ctrlg_tulu2/data.csv')
data['pred_graphs'] = ''

## MY VERSION ##
vocab_size = hmm_model.vocab_size
eos_token_id = hmm_model.eos_token_id

#prefix = "The overall probability of alarm set by husband is 77%. For husbands that don't set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in 'var1 -> var2' format, separated by commas:"
#suffix = '</s>'
#soft_constraint = '' # use empty string for no soft constraint
#prompt = f'<|user|>\n"Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in "var1 -> var2" format, separated by commas.<|endoftext|>"{soft_constraint}:\n{prefix}\n<|assistant|>\n'

prefix = data['prefix'].iloc[2000]
d_prompt = data['prompt'].iloc[2000]
suffix = '</s>'
soft_constraint = '' # use empty string for no soft constraint
prompt = f'<|user|>\n"{d_prompt}<|endoftext|>"{soft_constraint}:\n{prefix}\n<|assistant|>\n'


prefix_ids = tokenizer.encode(prefix)[1:]
suffix_ids = tokenizer.encode(suffix)[1:]
prompt_ids = tokenizer.encode(prompt)

#ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)
ac_builder = ctrlg.NewAhoCorasickBuilder(vocab_size)
eos_builder = ctrlg.EOSBuilder(vocab_size, eos_token_id)
trivial = ctrlg.TrivialBuilder(vocab_size, eos_token_id)

dfa_graphs = []
#keyphrases = [['towering'], ['reach the sky'], ['reflected'], ['lake']]
keyphrases = [[' X ', ' Y ', ' V2 ', ' V3 ', ' V4 ', ' V5 '],
                [' ->'],
                 [' X ', ' Y ', ' V2 ', ' V3 ', ' V4 ', ' V5 ']
             ]

for keyphrase in keyphrases:
    patterns = [tokenizer.encode(x)[1:] for x in keyphrase]
    dfa_graphs.append(ac_builder.build(patterns))
dfa_graphs.append(eos_builder.build())


dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='intersection')

#######################################
# when building the graph by hand, I think it mixes up the tokens in each list; this still needs to be checked
# (the manual dfa_graph below overwrites the DFA_prod result built above)

dfa_graph = {
    "edges": [
    (0, 2, ctrlg.populate_edge(['X', 'Y'], vocab_size, tokenizer)),
    (0, 1, ctrlg.populate_edge(['V'], vocab_size, tokenizer)),
    (1, 2, ctrlg.populate_edge(['1', '2', '3', '4', '5'], vocab_size, tokenizer)),
    (2, 3, ctrlg.populate_edge(['->'], vocab_size, tokenizer)),
    (3, 4, ctrlg.populate_edge(['V'], vocab_size, tokenizer)),
    (3, 5, ctrlg.populate_edge(['X','Y'], vocab_size, tokenizer)),
    (4, 5, ctrlg.populate_edge(['1', '2', '3', '4', '5'], vocab_size, tokenizer)),
    (5, 0, ctrlg.populate_edge([','], vocab_size, tokenizer)),
    ],
    "initial_state": 0,
    "accept_states": set([5]),
}
########################################

dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)

min_new_tokens = 1
max_new_tokens = 32
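The hand-written graph above uses the `{"edges": [(src, dst, allowed)], "initial_state", "accept_states"}` layout that is passed to `ctrlg.DFAModel`. To sanity-check which sequences it accepts before building the model, the sketch below walks a toy copy of it in pure Python.

It is a hypothetical stand-in with two assumptions:
- Each edge carries a set of token *strings* rather than the vocab-sized vector that `ctrlg.populate_edge` returns.
- `V2` is split into `V` followed by `2`. This is consistent with `tokenizer.encode("V5")` returning two ids after the BOS token.

It does not model whitespace tokens, which is one place where the real tokenizer may behave differently.

```python
# Hypothetical toy copy of the hand-built DFA: sets of token strings instead of vocab masks.
DIGITS = {'1', '2', '3', '4', '5'}
toy_dfa = {
    "edges": [
        (0, 2, {'X', 'Y'}),
        (0, 1, {'V'}),
        (1, 2, DIGITS),
        (2, 3, {'->'}),
        (3, 4, {'V'}),
        (3, 5, {'X', 'Y'}),
        (4, 5, DIGITS),
        (5, 0, {','}),
    ],
    "initial_state": 0,
    "accept_states": {5},
}

def run_dfa(dfa, tokens):
    """Return True if the DFA ends in an accepting state after reading all tokens."""
    state = dfa["initial_state"]
    for tok in tokens:
        for src, dst, allowed in dfa["edges"]:
            if src == state and tok in allowed:
                state = dst
                break
        else:
            return False  # no outgoing edge accepts this token: reject
    return state in dfa["accept_states"]

print(run_dfa(toy_dfa, ['X', '->', 'V', '2']))                            # True
print(run_dfa(toy_dfa, ['X', '->', 'V', '2', ',', 'V', '2', '->', 'Y']))  # True
print(run_dfa(toy_dfa, ['X', '->']))                                      # False
```

Sequences such as `['X', '->']` stop in a non-accepting state. This is why, under this graph, a generation has to finish on a complete variable.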

In [25]:
import pandas as pd
data = pd.read_csv('/media/data/bazaluk/ctrlg_tulu2/data.csv')
data['pred_graphs'] = ''

## MY VERSION ##
vocab_size = hmm_model.vocab_size
eos_token_id = hmm_model.eos_token_id

#prefix = "The overall probability of alarm set by husband is 77%. For husbands that don't set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in 'var1 -> var2' format, separated by commas:"
#suffix = '</s>'
#soft_constraint = '' # use empty string for no soft constraint
#prompt = f'<|user|>\n"Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in "var1 -> var2" format, separated by commas.<|endoftext|>"{soft_constraint}:\n{prefix}\n<|assistant|>\n'

prefix = data['prefix'].iloc[2000]
d_prompt = data['prompt'].iloc[2000]
suffix = '</s>'
soft_constraint = '' # use empty string for no soft constraint
prompt = f'<|user|>\n"{d_prompt}<|endoftext|>"{soft_constraint}:\n{prefix}\n<|assistant|>\n'


prefix_ids = tokenizer.encode(prefix)[1:]
suffix_ids = tokenizer.encode(suffix)[1:]
prompt_ids = tokenizer.encode(prompt)

#ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)
ac_builder = ctrlg.NewAhoCorasickBuilder(vocab_size)
eos_builder = ctrlg.EOSBuilder(vocab_size, eos_token_id)
trivial = ctrlg.TrivialBuilder(tokenizer,vocab_size, eos_token_id)

dfa_graph = trivial.build()


dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)

min_new_tokens = 1
max_new_tokens = 40

In [8]:
tokenizer.encode(',')

[1, 1919]

**Step 3. generate with constraints**

Due to the use of @torch.compile, the first run of the following functions could be significantly slower than the later runs.

In [29]:
dic = {} #the keys are the temperatures and the values are a list of the 10 predicted graphs
temp = [1,10,50,100]

# initialize the constraints logits processor
constraint_logits_processor = ctrlg.ConstraintLogitsProcessor(
    hmm_model, dfa_model,
    min_new_tokens, max_new_tokens,
    prompt_ids, prefix_ids=prefix_ids, suffix_ids=suffix_ids)


# set the hmm_batch_size & temperature
beam_size = 32 # sample 32 sequences
temperature = 0.7
constraint_logits_processor.hmm_batch_size = beam_size
constraint_logits_processor.temperature = temperature



# generate with sampling, temperature=0.7
input_ids = torch.tensor([prompt_ids], device=device)
outputs = base_model.generate(
        input_ids=input_ids, do_sample=True,
        num_return_sequences=beam_size, 
        min_new_tokens=min_new_tokens, max_new_tokens=max_new_tokens,
        logits_processor=LogitsProcessorList([constraint_logits_processor]),
        pad_token_id=tokenizer.eos_token_id,
    )


# extract the generated ids; removing prompt ids; remove suffix ids that are (partially) generated
generated_ids = ctrlg.extract_generated_ids(outputs.tolist(), prompt_ids, suffix_ids, eos_token_id)

# keep the top 32 generated ids by how well they connect with the suffix (with beam_size = 32, nothing is filtered out)
generated_ids = ctrlg.rank_generated_ids(base_model, generated_ids, prompt_ids, suffix_ids,
                                            suffix_logits_only=True, suffix_length_cap=5)[:32]
# rank the generated ids by the base_model for higher quality
generated_ids = ctrlg.rank_generated_ids(base_model, generated_ids, prompt_ids, suffix_ids)

dic[temperature] = []
# print top 5 outputs
for idx, generated in enumerate(generated_ids[:5]):
    print(f'{idx}. ' + tokenizer.decode(generated, skip_special_tokens=True) + \
          tokenizer.decode(suffix_ids, skip_special_tokens=True))
    #dic[temperature].append(tokenizer.decode(generated, skip_special_tokens=True) + \
    #      tokenizer.decode(suffix_ids, skip_special_tokens=True))
    #print(f'{idx}. ' + tokenizer.decode(prefix_ids, skip_special_tokens=True) + \
    #      ' ' + '\033[1m' + tokenizer.decode(generated, skip_special_tokens=True) + '\033[0m' + ' ' + \
    #      tokenizer.decode(suffix_ids, skip_special_tokens=True))

#data.at[i, 'pred_graphs'] = dic

0. V1 -> X
1. V1 -> X
2. V1 -> X
3. V1 -> X
4. V1 -> X
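The two `rank_generated_ids` calls first shortlist candidates by how well they connect with the suffix, then re-rank the shortlist with the base model. The sketch below mimics only that selection logic, using made-up score functions. `suffix_score`, `full_score` and `two_stage_select` are hypothetical; ctrlg computes its own scores from the base model. It makes the effect of the slice visible: when the shortlist size equals the number of samples, the first stage keeps everything. Filtering out 75% would instead need a shortlist of `beam_size // 4`.

```python
def two_stage_select(candidates, suffix_score, full_score, keep):
    """Keep the `keep` best candidates by suffix_score, then sort them by full_score.

    Both score functions are hypothetical stand-ins where higher is better.
    """
    shortlist = sorted(candidates, key=suffix_score, reverse=True)[:keep]
    return sorted(shortlist, key=full_score, reverse=True)

candidates = list(range(32))   # stand-ins for 32 sampled id sequences
beam_size = len(candidates)

kept_all = two_stage_select(candidates, lambda c: -c, lambda c: c % 5, keep=32)
kept_25 = two_stage_select(candidates, lambda c: -c, lambda c: c % 5, keep=beam_size // 4)
print(len(kept_all), len(kept_25))  # 32 8
```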


In [20]:
pd.set_option('display.max_colwidth', None)
data.iloc[1001]

Unnamed: 0     1001
prompt         Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use V1 to represent gender. Use X to represent smoking. Use V3 to represent tar deposit. Use Y to represent lung cancer. The diagram should simply consist of edges denoted in "var1 -> var2" format, separated by commas.
prefix         The overall probabili

In [6]:
a = {}
a['x']=[]
a['x'].append(2)
a['x'].append(3)
a

{'x': [2, 3]}

In [None]:
df.at['row', 'col']

In [88]:
import pandas as pd

data = pd.read_json('/home/bazaluk/master_project/cladder.json')
data = data.dropna().reset_index(drop=True)

# map each new column to (source column, key inside that column's dict);
# note that the step numbers here do not match the keys of 'reasoning' one-to-one
columns = {
    'step0': ('reasoning', 'step0'),  # step0) variable definitions, e.g. "Let X = husband; ..."
    'step1': ('reasoning', 'step1'),  # step1) causal graph edges, e.g. "X->V2,X->Y,V2->Y"
    'step2': ('meta', 'query_type'),  # step2) query type
    'step3': ('reasoning', 'step2'),  # step3) formalize query
    'step4': ('reasoning', 'step4'),  # step4) extract all available data
    'step5': ('reasoning', 'step3'),  # step5) deduce estimand
    'step6': ('reasoning', 'step5'),  # step6) insert the relevant data from step 4 into the estimand, do the arithmetic
    'end':   ('reasoning', 'end'),    # end) derive the final answer
}
for col, (source, key) in columns.items():
    data[col] = data[source].apply(lambda d: d.get(key))

data

Unnamed: 0,question_id,desc_id,given_info,question,answer,meta,reasoning,step0,step1,step2,step3,step4,step5,step6,end
0,4,alarm-mediation-nde-model0-spec0-q0,For husbands that don't set the alarm and wive...,If we disregard the mediation effect through w...,yes,"{'story_id': 'alarm', 'graph_id': 'mediation',...",{'step0': 'Let X = husband; V2 = wife; Y = ala...,Let X = husband; V2 = wife; Y = alarm clock.,"X->V2,X->Y,V2->Y",nde,"E[Y_{X=1, V2=0} - Y_{X=0, V2=0}]","P(Y=1 | X=0, V2=0) = 0.08\nP(Y=1 | X=0, V2=1) ...","\sum_{V2=v} P(V2=v|X=0)*[P(Y=1|X=1,V2=v) - P(Y...",0.74 * (0.86 - 0.41) + 0.24 * (0.54 - 0.08) = ...,0.32 > 0
1,7,alarm-mediation-ate-model1-spec1-q1,"For husbands that don't set the alarm, the pro...",Will alarm set by husband decrease the chance ...,no,"{'story_id': 'alarm', 'graph_id': 'mediation',...",{'step0': 'Let X = husband; V2 = wife; Y = ala...,Let X = husband; V2 = wife; Y = alarm clock.,"X->V2,X->Y,V2->Y",ate,E[Y | do(X = 1)] - E[Y | do(X = 0)],P(Y=1 | X=0) = 0.26\nP(Y=1 | X=1) = 0.76,P(Y=1|X=1) - P(Y=1|X=0),0.76 - 0.26 = 0.50,0.50 > 0
2,8,alarm-mediation-marginal-model1-spec1-q0,The overall probability of alarm set by husban...,Is ringing alarm more likely than silent alarm...,yes,"{'story_id': 'alarm', 'graph_id': 'mediation',...",{'step0': 'Let X = husband; V2 = wife; Y = ala...,Let X = husband; V2 = wife; Y = alarm clock.,"X->V2,X->Y,V2->Y",marginal,P(Y),P(X=1) = 0.77\nP(Y=1 | X=0) = 0.26\nP(Y=1 | X=...,P(Y | X=1)*P(X=1) + P(Y | X=0)*P(X=0),0.77*0.76 - 0.23*0.26 = 0.64,0.64 > 0
3,15,alarm-mediation-ate-model3-spec3-q1,"For husbands that don't set the alarm, the pro...",Will alarm set by husband decrease the chance ...,no,"{'story_id': 'alarm', 'graph_id': 'mediation',...",{'step0': 'Let X = husband; V2 = wife; Y = ala...,Let X = husband; V2 = wife; Y = alarm clock.,"X->V2,X->Y,V2->Y",ate,E[Y | do(X = 1)] - E[Y | do(X = 0)],P(Y=1 | X=0) = 0.20\nP(Y=1 | X=1) = 0.68,P(Y=1|X=1) - P(Y=1|X=0),0.68 - 0.20 = 0.49,0.49 > 0
4,21,alarm-mediation-nie-model4-spec4-q0,For husbands that don't set the alarm and wive...,Does husband positively affect alarm clock thr...,no,"{'story_id': 'alarm', 'graph_id': 'mediation',...",{'step0': 'Let X = husband; V2 = wife; Y = ala...,Let X = husband; V2 = wife; Y = alarm clock.,"X->V2,X->Y,V2->Y",nie,"E[Y_{X=0, V2=1} - Y_{X=0, V2=0}]","P(Y=1 | X=0, V2=0) = 0.11\nP(Y=1 | X=0, V2=1) ...","\sum_{V2 = v} P(Y=1|X =0,V2 = v)*[P(V2 = v | X...",0.01 * (0.60 - 0.11)+ 0.61 * (0.92 - 0.46)= -0.29,-0.29 < 0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8527,31012,nonsense9-fork-det-counterfactual-model3842-sp...,We know that zuph or jyka causes glimx. We obs...,Would an individual is glimx if not zuph inste...,yes,"{'story_id': 'nonsense9', 'graph_id': 'fork', ...",{'step0': 'Let V2 = jyka; X = zuph; Y = glimx....,Let V2 = jyka; X = zuph; Y = glimx.,"X->Y,V2->Y",det-counterfactual,Y_{X=0} = 1 | V2=1,V2 = 1\nY = X or V2,"Solve for Y, given the evidence and the action",Y = 1 = 0 or 1,1
8528,31014,nonsense9-fork-det-counterfactual-model3842-sp...,We know that zuph or jyka causes glimx. We obs...,Would an individual is glimx if zuph instead o...,yes,"{'story_id': 'nonsense9', 'graph_id': 'fork', ...",{'step0': 'Let V2 = jyka; X = zuph; Y = glimx....,Let V2 = jyka; X = zuph; Y = glimx.,"X->Y,V2->Y",det-counterfactual,Y_{X=1} = 1 | V2=1,V2 = 1\nY = X or V2,"Solve for Y, given the evidence and the action",Y = 1 = 1 or 1,1
8529,31015,nonsense9-fork-det-counterfactual-model3842-sp...,We know that zuph or jyka causes glimx. We obs...,Would an individual is not glimx if zuph inste...,no,"{'story_id': 'nonsense9', 'graph_id': 'fork', ...",{'step0': 'Let V2 = jyka; X = zuph; Y = glimx....,Let V2 = jyka; X = zuph; Y = glimx.,"X->Y,V2->Y",det-counterfactual,Y_{X=1} = 0 | V2=1,V2 = 1\nY = X or V2,"Solve for Y, given the evidence and the action",Y = 1 = 1 or 1,0
8530,31016,nonsense9-fork-det-counterfactual-model3843-sp...,We know that zuph and jyka causes glimx. We ob...,Would an individual is glimx if not zuph inste...,no,"{'story_id': 'nonsense9', 'graph_id': 'fork', ...",{'step0': 'Let V2 = jyka; X = zuph; Y = glimx....,Let V2 = jyka; X = zuph; Y = glimx.,"X->Y,V2->Y",det-counterfactual,Y_{X=0} = 1 | V2=0,V2 = 0\nY = X and V2,"Solve for Y, given the evidence and the action",Y = 0 = 0 and 0,0


In [89]:
# create a list of variable symbols and the real names they represent,
# e.g. [['X', 'husband'], ['V2', 'wife'], ['Y', 'alarm clock']]
def list_representation(x):
    x = x.get('step0')
    x = x[4:-1]  # erase the leading "Let " and the trailing "."
    x = x.split('; ')
    return [pair.split(" = ") for pair in x]

# receives a 'reasoning' dict and returns the "Use X to represent ..." part of the prompt
def prompt_repres(x):
    prompt = ""
    for pair in list_representation(x):
        prompt += "Use " + pair[0] + " to represent " + pair[1] + ". "
    return prompt

# concatenate a list of strings into a single string
def create_string(str_list):
    return "".join(str_list)
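The fixed-position slicing in `list_representation` assumes every `step0` string starts with exactly `Let ` and ends with `.`. A slightly more tolerant sketch uses a regular expression instead. It assumes each pair has the form `SYMBOL = name` and that pairs are separated by `;`, which matches the `step0` values shown above (for example `Let X = husband; V2 = wife; Y = alarm clock.`). The function names are hypothetical.

```python
import re

def parse_step0(step0):
    """Parse 'Let X = husband; V2 = wife; Y = alarm clock.' into (symbol, name) pairs."""
    body = step0.strip()
    if body.startswith('Let '):
        body = body[len('Let '):]
    body = body.rstrip('.')
    return re.findall(r'(\w+) = ([^;]+)', body)

def prompt_from_step0(step0):
    """Build the 'Use X to represent husband. ...' block from a step0 string."""
    return ''.join(f'Use {sym} to represent {name.strip()}. ' for sym, name in parse_step0(step0))

print(parse_step0('Let V2 = jyka; X = zuph; Y = glimx.'))
# [('V2', 'jyka'), ('X', 'zuph'), ('Y', 'glimx')]
```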

In [90]:
# create a column with the complete prompt to be given to the LLM
paraphrases = [
            "Imagine a self-contained, hypothetical world with only the following conditions, and without any unmentioned factors or causal relationships. ",
            "Think of a self-contained, hypothetical setting with just the specified conditions, and devoid of any unknown factors or causal connections. ",
            "Consider a self-contained, hypothetical world with solely the mentioned conditions, and is free from any hidden factors or cause-and-effect relationships. ",
            "Imagine a self-contained, hypothetical setting with merely the stated conditions, and absent any unmentioned factors or causative links. ",
            "Think of a self-contained, hypothetical world with only the given conditions, and is void of any unknown factors or causative relationships. ",
        ]

prompt_end0 = "Extract the causal graph: Identify the causal graph that \
depicts the relationships in the scenario. "
prompt_end1 = 'The diagram should simply consist of edges denoted in "var1 -> var2" format, separated by commas.'
prefix_end1 = 'The diagram should simply consist of edges denoted in "var1 -> var2" format, separated by commas:'
prefix_end0 = " Extract the causal graph: Identify the causal graph that \
depicts the relationships in the scenario. "

data['prompt'] = data.apply(lambda x: (prompt_end0+prompt_repres(x['reasoning'])+prompt_end1),axis=1)
data['prefix'] = data.apply(lambda x: (x['given_info']+prefix_end0+prompt_repres(x['reasoning'])+prefix_end1),axis=1)

#df = data[['prompt', 'formal_form', 'graph', 'query_type']]
#data['prefix'].iloc[0]
#data['prompt'].iloc[0]

#/media/data/bazaluk/ctrlg_tulu2/data.csv
data[['prompt','prefix','step1']].to_csv('/media/data/bazaluk/ctrlg_tulu2/data.csv')
#data['prompt'].iloc[0]
data['prefix'].iloc[1]

'For husbands that don\'t set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in "var1 -> var2" format, separated by commas:'

In [15]:
import pandas as pd
#######################################################################################################################
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # set your cuda device
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
import ctrlg
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

device = 'cuda'

# load the pretrained base_model and hmm_model;
BASE_MODEL_PATH = f'ctrlg/tulu2-7b_writing-prompts'
HMM_MODEL_PATH = f'ctrlg/hmm_tulu2-7b_writing-prompts_32768'

base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH).to(device)
base_model.eval()
base_model.half() # fp16 inference
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
hmm_model = ctrlg.HMM.from_pretrained(HMM_MODEL_PATH).to(device)

#######################################################################################################################
## MY VERSION ##
vocab_size = hmm_model.vocab_size
eos_token_id = hmm_model.eos_token_id

data = pd.read_csv('/media/data/bazaluk/ctrlg_tulu2/data.csv')
data['pred_graphs'] = ''

for i in range(3):
#for i in range(len(data)):
    vocab_size = hmm_model.vocab_size
    eos_token_id = hmm_model.eos_token_id

    prefix = data['prefix'].iloc[i]
    d_prompt = data['prompt'].iloc[i]
    suffix = '</s>'
    soft_constraint = '' # use empty string for no soft constraint
    prompt = f'<|user|>\n"{d_prompt}<|endoftext|>"{soft_constraint}:\n{prefix}\n<|assistant|>\n'


    prefix_ids = tokenizer.encode(prefix)[1:]
    suffix_ids = tokenizer.encode(suffix)[1:]
    prompt_ids = tokenizer.encode(prompt)

    ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)
    eos_builder = ctrlg.EOSBuilder(vocab_size, eos_token_id)


    dfa_graphs = []
    #keyphrases = [['towering'], ['reach the sky'], ['reflected'], ['lake']]
    keyphrases = [[' X ', ' Y ', ' V2 ', ' V3 ', ' V4 ', ' V5 '],
                    [' ->'],
                     [' X ', ' Y ', ' V2 ', ' V3 ', ' V4 ', ' V5 ']
                 ]

    for keyphrase in keyphrases:
        patterns = [tokenizer.encode(x)[1:] for x in keyphrase]
        dfa_graphs.append(ac_builder.build(patterns))
    dfa_graphs.append(eos_builder.build())


    dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='intersection')
    dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)

    min_new_tokens = 1
    max_new_tokens = 32
    
    ############################################################################
    # initialize the constraints logits processor
    constraint_logits_processor = ctrlg.ConstraintLogitsProcessor(
        hmm_model, dfa_model,
        min_new_tokens, max_new_tokens,
        prompt_ids, prefix_ids=prefix_ids, suffix_ids=suffix_ids)


    # set the hmm_batch_size & temperature
    beam_size = 32 # sample 32 sequences
    temperature = 0.9
    constraint_logits_processor.hmm_batch_size = beam_size
    constraint_logits_processor.temperature = temperature


    # generate with sampling, temperature=0.9
    input_ids = torch.tensor([prompt_ids], device=device)
    outputs = base_model.generate(
            input_ids=input_ids, do_sample=True,
            num_return_sequences=beam_size, 
            min_new_tokens=min_new_tokens, max_new_tokens=max_new_tokens,
            logits_processor=LogitsProcessorList([constraint_logits_processor]),
            pad_token_id=tokenizer.eos_token_id,
        )


    # extract the generated ids; removing prompt ids; remove suffix ids that are (partially) generated
    generated_ids = ctrlg.extract_generated_ids(outputs.tolist(), prompt_ids, suffix_ids, eos_token_id)

    # keep the top 32 generated ids by how well they connect with the suffix (with beam_size = 32, nothing is filtered out)
    generated_ids = ctrlg.rank_generated_ids(base_model, generated_ids, prompt_ids, suffix_ids,
                                                suffix_logits_only=True, suffix_length_cap=5)[:32]
    # rank the generated ids by the base_model for higher quality
    generated_ids = ctrlg.rank_generated_ids(base_model, generated_ids, prompt_ids, suffix_ids)

    # save top 10 outputs
    pred_graphs = []
    for idx, generated in enumerate(generated_ids[:10]):
        pred_graphs.append(tokenizer.decode(generated, skip_special_tokens=True) + \
              tokenizer.decode(suffix_ids, skip_special_tokens=True))
    data.at[i, 'pred_graphs'] = pred_graphs

data.to_csv('teste.csv')




In [8]:
import pandas as pd
pd.set_option('display.max_colwidth', None)
pd.read_csv("teste.csv")

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,prompt,prefix,step1,pred_graphs
0,0,0,"Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas.","For husbands that don't set the alarm and wives that don't set the alarm, the probability of ringing alarm is 8%. For husbands that don't set the alarm and wives that set the alarm, the probability of ringing alarm is 54%. For husbands that set the alarm and wives that don't set the alarm, the probability of ringing alarm is 41%. For husbands that set the alarm and wives that set the alarm, the probability of ringing alarm is 86%. For husbands that don't set the alarm, the probability of alarm set by wife is 74%. For husbands that set the alarm, the probability of alarm set by wife is 24%. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas:","X->V2,X->Y,V2->Y","{1: ['X -> V2', 'X -> V2', 'X -> V2', 'X -> V2', 'X -> V2', 'X -> V2', 'X -> V2', 'X -> V2', 'X -> Y', 'X -> Y'], 10: ['V2 -> Y', 'V2 -> Y', 'V2 -> Y', 'V2 -> Y', 'X -> V2', 'X -> V2', 'X -> V2', 'V2 -> X', 'V2 -> X', 'X -> Y'], 50: ['V2 -> X', 'X -> Y', 'X -> Y', 'X -> Y', 'X -> Y', 'X -> V1', 'X -> V1', 'Y -> X', 'Y -> X', 'Y -> X'], 100: ['V2 -> Y', 'V2 -> Y', 'V2 -> X', 'V2 -> Y , V1 -> X', 'X -> Y', 'X -> Y', 'V1 -> Y', 'V5 -> X , V4 -> X , V3 -> Y , V2 -> V5 , X -> X', 'Y -> X', 'Y -> X']}"
1,1,1,"Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas.","For husbands that don't set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas:","X->V2,X->Y,V2->Y","{1: ['X -> V2', 'X -> V2', 'X -> V2', 'X -> V2', 'X -> V2', 'X -> V2', 'X -> V2', 'X -> V2', 'X -> V2', 'X -> Y'], 10: ['V1 -> V2', 'X -> V2', 'V2 -> X', 'Y -> V2', 'Y -> V2', 'Y -> V2', 'V1 -> X', 'V1 -> X', 'V2 -> V2', 'X -> Y'], 50: ['X -> V2', 'V2 -> X', 'V2 -> X', 'Y -> V2', 'V2 -> X , V1 -> Y', 'V2 -> V2', 'X -> Y , X -> Y', 'X -> Y', 'X -> Y', 'X -> Y'], 100: ['V2 -> Y', 'X -> V2', 'Y -> V2', 'Y -> V2', 'Y -> V2', 'V1 -> X', 'X -> Y', 'X -> Y', 'Y -> X , X -> Y', 'Y -> V1']}"
2,2,2,"Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas.","The overall probability of alarm set by husband is 77%. For husbands that don't set the alarm, the probability of ringing alarm is 26%. For husbands that set the alarm, the probability of ringing alarm is 76%. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas:","X->V2,X->Y,V2->Y",
3,3,3,"Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas.","For husbands that don't set the alarm, the probability of ringing alarm is 20%. For husbands that set the alarm, the probability of ringing alarm is 68%. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas:","X->V2,X->Y,V2->Y",
4,4,4,"Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas.","For husbands that don't set the alarm and wives that don't set the alarm, the probability of ringing alarm is 11%. For husbands that don't set the alarm and wives that set the alarm, the probability of ringing alarm is 60%. For husbands that set the alarm and wives that don't set the alarm, the probability of ringing alarm is 46%. For husbands that set the alarm and wives that set the alarm, the probability of ringing alarm is 92%. For husbands that don't set the alarm, the probability of alarm set by wife is 61%. For husbands that set the alarm, the probability of alarm set by wife is 1%. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use X to represent husband. Use V2 to represent wife. Use Y to represent alarm clock. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas:","X->V2,X->Y,V2->Y",
...,...,...,...,...,...,...
8527,8527,8527,"Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use V2 to represent jyka. Use X to represent zuph. Use Y to represent glimx. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas.","We know that zuph or jyka causes glimx. We observed an individual is jyka. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use V2 to represent jyka. Use X to represent zuph. Use Y to represent glimx. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas:","X->Y,V2->Y",
8528,8528,8528,"Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use V2 to represent jyka. Use X to represent zuph. Use Y to represent glimx. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas.","We know that zuph or jyka causes glimx. We observed an individual is jyka. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use V2 to represent jyka. Use X to represent zuph. Use Y to represent glimx. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas:","X->Y,V2->Y",
8529,8529,8529,"Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use V2 to represent jyka. Use X to represent zuph. Use Y to represent glimx. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas.","We know that zuph or jyka causes glimx. We observed an individual is jyka. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use V2 to represent jyka. Use X to represent zuph. Use Y to represent glimx. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas:","X->Y,V2->Y",
8530,8530,8530,"Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use V2 to represent jyka. Use X to represent zuph. Use Y to represent glimx. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas.","We know that zuph and jyka causes glimx. We observed an individual is not jyka. Extract the causal graph: Identify the causal graph that depicts the relationships in the scenario. Use V2 to represent jyka. Use X to represent zuph. Use Y to represent glimx. The diagram should simply consist of edges denoted in ""var1 -> var2"" format, separated by commas:","X->Y,V2->Y",
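The saved file pairs each row's reference graph (`step1`, e.g. `X->V2,X->Y,V2->Y`) with the sampled `pred_graphs`. One hedged way to compare them is edge-level F1 over `(source, target)` pairs.

The sketch rests on two assumptions:
- `to_csv` stored the Python repr of each prediction, so `ast.literal_eval` can recover it.
- Rows that were never generated come back from `read_csv` as NaN.

The helper names are illustrative and are not part of ctrlg or the dataset. For cells holding a dict keyed by temperature (as in the first two rows above), score each list in `.values()` separately.

```python
import ast

def load_preds(cell):
    """Recover a stored prediction list (or dict) from a CSV cell; NaN/empty -> []."""
    if not isinstance(cell, str) or not cell.strip():
        return []
    return ast.literal_eval(cell)

def parse_edges(graph_str):
    """Parse 'X -> V2 , X->Y' style strings into a set of (source, target) pairs."""
    edges = set()
    for part in graph_str.split(','):
        if '->' not in part:
            continue  # skip empty or malformed fragments
        src, dst = (s.strip() for s in part.split('->', 1))
        if src and dst:
            edges.add((src, dst))
    return edges

def edge_f1(pred_str, gold_str):
    """Edge-level F1 between a predicted and a reference graph string."""
    pred, gold = parse_edges(pred_str), parse_edges(gold_str)
    tp = len(pred & gold)
    if tp == 0:
        return 0.0
    precision, recall = tp / len(pred), tp / len(gold)
    return 2 * precision * recall / (precision + recall)

gold = 'X->V2,X->Y,V2->Y'  # format of the step1 column
preds = load_preds("['X -> V2', 'V2 -> Y , V1 -> X']")
print([round(edge_f1(p, gold), 2) for p in preds])  # [0.5, 0.4]
```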


In [6]:
#[(k[0], k[1], v) for k, v in trans.items() if k[0] != tuple() or k[1] != tuple()]
a=[1,2,3]
b=[4,5,6]
[(a[i],b[i]) for i in range(3) if a[i] != 2 and b[i] != 6]

[(1, 4)]

In [10]:
a= {
  "Germany": "Berlin", 
  "Canada": "Ottawa", 
  "England": "London"
}
list(a.items())[0]

('Germany', 'Berlin')

In [7]:
a = [(((), (29871,)), 2)]
a[0][0][0]

()