## Setting up the Colab runtime (user action required)

<a target="_blank" href="https://colab.research.google.com/github/noamgat/lm-format-enforcer/blob/main/samples/colab_llama2_enforcer.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This Colab-friendly notebook demonstrates the enforcer on Llama 2. It can run on a free GPU on Google Colab.
Make sure that your runtime is set to GPU:

Menu Bar -> Runtime -> Change runtime type -> T4 GPU (at the time of writing this notebook). [Guide here](https://www.codesansar.com/deep-learning/using-free-gpu-tpu-google-colab.htm).
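
Before installing anything, it can help to confirm that a GPU is actually visible to the runtime. The snippet below is a minimal sketch that only uses the standard library and assumes the NVIDIA driver tools (`nvidia-smi`) are on the `PATH`, which is typically the case on Colab GPU runtimes. The helper name `gpu_visible` is made up for this example.

```python
import shutil
import subprocess


def gpu_visible() -> bool:
    """Return True if `nvidia-smi -L` runs successfully (hypothetical helper)."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        # No NVIDIA tooling on PATH: most likely a CPU-only runtime.
        return False
    try:
        result = subprocess.run([exe, "-L"], capture_output=True, text=True, timeout=10)
    except (OSError, subprocess.SubprocessError):
        return False
    return result.returncode == 0


print("GPU visible:", gpu_visible())
```

If this prints `False`, change the runtime type as described above before continuing.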

In [None]:
!pip install transformers torch lm-format-enforcer huggingface_hub optimum langchain langchain-experimental
!pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

# When running from source / developing the library, use this instead
# %load_ext autoreload
# %autoreload 2
# import sys
# import os
# sys.path.append(os.path.abspath('..'))
## os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

Now we can create the model. This may take a few minutes.

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = 'TheBloke/Llama-2-7b-Chat-GPTQ'
device = 'cuda'

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto')
else:
    raise Exception('GPU not available')
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
    # Required for batching example
    tokenizer.pad_token_id = tokenizer.eos_token_id




If the previous cell executed successfully, you have properly set up your Colab runtime and Hugging Face account!

## Setting up the prompt for the specific language model

We set up the prompting style according to the [Llama 2 chat demo](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/app.py). We simplify the implementation a bit because this demo does not need chat history.

In [None]:
def get_prompt(message: str, system_prompt: str) -> str:
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped
    do_strip = False
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)
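
To make the template concrete, here is a small sketch that renders one prompt. It uses a condensed copy of `get_prompt` so that it runs on its own; the system prompt and message are placeholder strings chosen for illustration.

```python
def get_prompt(message: str, system_prompt: str) -> str:
    # Condensed copy of the notebook function above, so this sketch is self-contained.
    return f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message} [/INST]'


example_prompt = get_prompt('Who is Michael Jordan?', 'You are a helpful assistant.')
print(example_prompt)
# <s>[INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# Who is Michael Jordan? [/INST]
```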

## Calling generate_enforced() instead of model.generate()

The main function is fairly straightforward, except for the optional parameters ```required_regex``` / ```required_str``` / ```required_json_schema``` / ```required_json_output```, which select the ```CharacterLevelParser``` to be used by the format enforcer.

The implementation includes ```output_scores=True``` when calling ```generate_enforced()```, which returns diagnostic information about the enforcer's actions. We will use this to understand the results later in this notebook.
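
Conceptually, each entry of that diagnostic information compares the token the model preferred with the token that was emitted once disallowed tokens were filtered out. The following is a toy sketch of that idea, not the library's implementation: the vocabulary, the scores, the `allowed` set and the `pick_token` helper are all made up for illustration, and it picks tokens greedily whereas the notebook samples.

```python
from typing import Dict, Set, Tuple


def pick_token(scores: Dict[str, float], allowed: Set[str]) -> Tuple[str, str]:
    """Return (leading_token, generated_token) for a single timestep."""
    # The model's favourite token, ignoring any format constraint.
    leading = max(scores, key=scores.get)
    # Keep only the tokens that the (hypothetical) parser allows at this point.
    candidates = {tok: score for tok, score in scores.items() if tok in allowed}
    if not candidates:
        raise ValueError("no allowed token has a score")
    # Greedy choice among allowed tokens, for simplicity.
    generated = max(candidates, key=candidates.get)
    return leading, generated


# Made-up scores: the model wants to start with prose, but JSON must start with "{".
toy_scores = {"Of": 0.95, "Sure": 0.04, "{": 0.01}
toy_leading, toy_generated = pick_token(toy_scores, allowed={"{"})
print(toy_leading, toy_generated)  # Of {
```

When the two tokens differ and the generated token has a very low score, the enforcer had to override the model strongly at that step.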

In [None]:
%load_ext autoreload
%autoreload 2
from typing import Tuple, Optional, Union, List
import pandas as pd
from lmformatenforcer import JsonSchemaParser, CharacterLevelParser, RegexParser, StringParser
from lmformatenforcer.integrations.transformers import generate_enforced, build_token_enforcer_tokenizer_data

StringOrManyStrings = Union[str, List[str]]

tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer)

def run(message: StringOrManyStrings,
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50,
        num_beams: int = 1,
        required_regex: Optional[str] = None,
        required_str: Optional[str] = None,
        required_json_schema: Optional[dict] = None,
        required_json_output: Optional[bool] = None) -> Tuple[StringOrManyStrings, Optional[pd.DataFrame]]:
    is_multi_message = isinstance(message, list)
    messages = message if is_multi_message else [message]
    prompts = [get_prompt(msg, system_prompt) for msg in messages]
    inputs = tokenizer(prompts, return_tensors='pt', add_special_tokens=False, return_token_type_ids=False, padding=is_multi_message).to(device)

    generate_kwargs = dict(
        inputs,
        # streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=num_beams,
        output_scores=True,
        return_dict_in_generate=True
    )

    parser: Optional[CharacterLevelParser] = None
    if required_regex:
        parser = RegexParser(required_regex)
    if required_str:
        parser = StringParser(required_str)
    if required_json_schema:
        parser = JsonSchemaParser(required_json_schema)
    if required_json_output:
        parser = JsonSchemaParser(None)

    if parser:
        output = generate_enforced(model, tokenizer_data, parser, **generate_kwargs)
    else:
        output = model.generate(**generate_kwargs)

    sequences = output['sequences']
    # skip_prompt=True doesn't work consistently, so we hack around it:
    # decode the full sequence, then remove the prompt text (minus its 3-character '<s>' prefix).
    string_outputs = [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in sequences]
    string_outputs = [string_output.replace(prompt[3:], '') for string_output, prompt in zip(string_outputs, prompts)]
    if parser and not is_multi_message:
        enforced_scores_dict = output.enforced_scores
        enforced_scores = pd.DataFrame(enforced_scores_dict)
        pd.set_option('display.width', 1000)
        pd.set_option('display.max_columns', 10)
        pd.set_option('display.max_rows', 999)
        pd.set_option('display.float_format', ' {:,.5f}'.format)
    else:
        enforced_scores = None
    return string_outputs if is_multi_message else string_outputs[0], enforced_scores
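
The prompt-stripping step near the end of `run()` is easy to misread, so here is a string-only sketch of what it does. It assumes that decoding with `skip_special_tokens=True` drops the leading `<s>` and reproduces the rest of the prompt verbatim; the decoded text below is invented for illustration.

```python
toy_prompt = '<s>[INST] <<SYS>>\nBe brief.\n<</SYS>>\n\nHi [/INST]'

# Hypothetical decoded sequence: the prompt without '<s>', followed by the model's answer.
toy_decoded = '[INST] <<SYS>>\nBe brief.\n<</SYS>>\n\nHi [/INST]  Hello!'

# prompt[3:] skips the 3 characters of '<s>', so the remainder matches the decoded prefix.
toy_answer = toy_decoded.replace(toy_prompt[3:], '')
print(repr(toy_answer))  # '  Hello!'
```

If the decoded text does not reproduce the prompt exactly, `replace` finds no match and the prompt stays in the output, which is why the comment in the code calls this a hack.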




## JSON Schema Use Case

Now we demonstrate using ```JsonSchemaParser```. We create a pydantic model, generate the schema from it, and use that to enforce the format.
The output will always be in a format that can be parsed by the parser.

In [None]:
from pydantic import BaseModel
from IPython.display import display, Markdown
from typing import List

def display_header(text):
    display(Markdown(f'**{text}**'))

def display_content(text):
    display(Markdown(f'```\n{text}\n```'))

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""
DEFAULT_MAX_NEW_TOKENS = 100

class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int

question = 'Please give me information about Michael Jordan. You MUST answer using the following json schema: '
question_with_schema = f'{question}{AnswerFormat.schema_json()}'

display_header("Question:")
display_content(question_with_schema)

display_header("Answer, With json schema enforcing:")
result, enforced_scores = run(question_with_schema, system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, required_json_schema=AnswerFormat.schema())
display_content(result)

display_header("Answer, Without json schema enforcing:")
result, _ = run(question_with_schema, system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS)
display_content(result)

display_header("Answer, With json mode enforcing (json output, schemaless):")
result, _ = run(question_with_schema, system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, required_json_output=True)
display_content(result)



**Question:**

```
Please give me information about Michael Jordan. You MUST answer using the following json schema: {"title": "AnswerFormat", "type": "object", "properties": {"first_name": {"title": "First Name", "type": "string"}, "last_name": {"title": "Last Name", "type": "string"}, "year_of_birth": {"title": "Year Of Birth", "type": "integer"}, "num_seasons_in_nba": {"title": "Num Seasons In Nba", "type": "integer"}}, "required": ["first_name", "last_name", "year_of_birth", "num_seasons_in_nba"]}
```

**Answer, With json schema enforcing:**



```
  {
"first_name": "Michael",
"last_name": "Jordan",
"year_of_birth": 1963,
"num_seasons_in_nba": 15
}
```

**Answer, Without json schema enforcing:**

```
  Of course! I'd be happy to help you with your question about Michael Jordan. Here's the information you requested in the format you specified:
{
"title": "AnswerFormat",
"type": "object",
"properties": {
"first_name": {
"title": "First Name",
"type": "string"
},
"last_name": {
"title": "Last Name",
"type": "string"
```

**Answer, With json mode enforcing (json output, schemaless):**

```
  {
"title": "AnswerFormat",
"type": "object",
"properties": {
"first_name": {
"title": "First Name",
"type": "string"
},
"last_name": {
"title": "Last Name",
"type": "string"
},
"year_of_birth": {
"title": "Year Of Birth",
"type": "integer"
},

```
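
The difference between these outputs can also be checked mechanically. The sketch below uses only the standard `json` module. The `REQUIRED_FIELDS` mapping and the `matches_answer_format` helper are hand-written stand-ins for the `AnswerFormat` model, and the second string is an abbreviated copy of the truncated schemaless output.

```python
import json

# Hand-written stand-in for the AnswerFormat model defined in the notebook.
REQUIRED_FIELDS = {
    "first_name": str,
    "last_name": str,
    "year_of_birth": int,
    "num_seasons_in_nba": int,
}


def matches_answer_format(text: str) -> bool:
    """True if text is a JSON object with all required fields of the right type."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return False
    return isinstance(data, dict) and all(
        isinstance(data.get(name), expected) for name, expected in REQUIRED_FIELDS.items()
    )


enforced_output = '''  {
"first_name": "Michael",
"last_name": "Jordan",
"year_of_birth": 1963,
"num_seasons_in_nba": 15
}'''
truncated_output = '{"title": "AnswerFormat", "type": "object", "properties": {"first_name": {"title": "First Name", "type": "string"},'

print(matches_answer_format(enforced_output))   # True
print(matches_answer_format(truncated_output))  # False
```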

## Understanding the results
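All three runs used the exact same prompt, but only the schema-enforced run produced a JSON object that matches the schema. The unenforced and schemaless runs both started echoing the schema itself and appear to have been cut off by the token limit.
How did the enforcer cause the output to conform to the format? Let's look at the enforcer intervention table: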


In [None]:
display_header("Enforcer intervention table")
display(enforced_scores)

**Enforcer intervention table**

| index | generated_token | generated_token_idx | generated_score | leading_token | leading_token_idx | leading_score |
|---|---|---|---|---|---|---|
| 0 | | 29871 | 0.99982 | | 29871 | 0.99982 |
| 1 | `{` | 426 | 2e-05 | `Of` | 4587 | 0.95649 |
| 2 | `\n` | 13 | 0.99705 | `\n` | 13 | 0.99705 |
| 3 | `"` | 29908 | 0.98198 | `"` | 29908 | 0.98198 |
| 4 | `first` | 4102 | 0.0103 | `title` | 3257 | 0.55388 |
| 5 | `_` | 29918 | 0.99879 | `_` | 29918 | 0.99879 |
| 6 | `name` | 978 | 1.0 | `name` | 978 | 1.0 |
| 7 | `":` | 1115 | 0.99994 | `":` | 1115 | 0.99994 |
| 8 | `"` | 376 | 0.99949 | `"` | 376 | 0.99949 |
| 9 | `Michael` | 24083 | 0.99028 | `Michael` | 24083 | 0.99028 |


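The step where `first` is generated instead of `title` shows where the enforcer had to be aggressive. In the table above this is index 4; the excerpt below is from a different run, so its index and scores differ (sampling is enabled, so runs are not identical):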


```
generated_token  generated_token_idx  generated_score leading_token  leading_token_idx  leading_score

3            first                 4102         0.00032         title               3257       0.94433
```

The language model was trying to generate the word "title" (post softmax score of 0.94433) but the enforcer made the "first" token (post softmax score of 0.00032) be the main candidate instead.
This can be used to further improve the prompt engineering: we generally want to avoid timesteps that force the enforcer to be this aggressive, because they increase the likelihood of hallucinations. For example, the [langchain project removes the "title" from the json schema](https://github.com/langchain-ai/langchain/blob/cfa2203c626a2287d60c1febeb3e3a68b77acd77/libs/langchain/langchain/output_parsers/pydantic.py#L40), probably for this reason.
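
One way to apply this idea is to remove the `title` annotations from the schema before putting it in the prompt. The sketch below is a hand-rolled illustration and not LangChain's actual code. `strip_titles` is a hypothetical helper that drops string-valued `title` keys recursively while keeping any property that happens to be named `title`.

```python
from typing import Any


def strip_titles(schema: Any) -> Any:
    """Return a copy of a JSON schema without 'title' annotations (hypothetical helper)."""
    if isinstance(schema, dict):
        cleaned = {}
        for key, value in schema.items():
            if key == "title" and isinstance(value, str):
                continue  # annotation, not needed to constrain the output
            if key == "properties" and isinstance(value, dict):
                # Keys here are property names: keep them all, clean their sub-schemas.
                cleaned[key] = {name: strip_titles(sub) for name, sub in value.items()}
            else:
                cleaned[key] = strip_titles(value)
        return cleaned
    if isinstance(schema, list):
        return [strip_titles(item) for item in schema]
    return schema


answer_schema = {
    "title": "AnswerFormat",
    "type": "object",
    "properties": {
        "first_name": {"title": "First Name", "type": "string"},
        "year_of_birth": {"title": "Year Of Birth", "type": "integer"},
    },
    "required": ["first_name", "year_of_birth"],
}
reduced_schema = strip_titles(answer_schema)
print(reduced_schema)
```

Whether a reduced schema in the prompt actually lowers the number of aggressive interventions would need to be checked in the intervention table for the model at hand.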

## Regular Expression Use Case

Note that the ```RegexParser``` does not support the full regex syntax, as it uses [interegular](https://pypi.org/project/interegular/) under the hood. We will also make use of ```StringParser``` to diagnose the results.

In [None]:
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""
MAX_MAX_NEW_TOKENS = 200
DEFAULT_MAX_NEW_TOKENS = 100
MAX_INPUT_TOKEN_LENGTH = 4000

date_regex = r'(0?[1-9]|1[0-2])\/(0?[1-9]|1\d|2\d|3[01])\/(19|20)\d{2}'
answer_regex = ' In mm/dd/yyyy format, Michael Jordan was born in ' + date_regex
question = 'When was Michael Jordan Born? Please answer in mm/dd/yyyy format.'
display_header("Question:")
display_content(question)

display_header("Without enforcing:")
result, _ = run(question, system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS)
display_content(result)

print('\n----------------------------------------\n')

display_header(f"With regex force. Regex: ```{answer_regex}```")
result, enforced_scores = run(question, system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, required_regex=answer_regex)
display_header("Language model output:")
display_content(result)
display_header("Enforcer intervention table:")
display(enforced_scores)

print('\n----------------------------------------\n')

display_header("With string force:")
answer = ' Michael Jordan was born in 02/17/1963'
result, enforced_scores = run(question, system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, required_str=answer)
display_header("Language model output:")
display_content(result)
display_header("Enforcer intervention table:")
display(enforced_scores)

**Question:**

```
When was Michael Jordan Born? Please answer in mm/dd/yyyy format.
```

**Without enforcing:**

```
  Thank you for reaching out! I'm happy to help you with your question. However, I must inform you that Michael Jordan was born on February 17, 1963, which is not possible in the mm/dd/yyyy format as it is a date that falls in the 20th century. I apologize for any confusion. Is there anything else I can help you with?
```


----------------------------------------



**With regex force. Regex: ``` In mm/dd/yyyy format, Michael Jordan was born in (0?[1-9]|1[0-2])\/(0?[1-9]|1\d|2\d|3[01])\/(19|20)\d{2}```**

**Language model output:**

```
 In mm/dd/yyyy format, Michael Jordan was born in 12/23/1963
```

**Enforcer intervention table:**

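| index | generated_token | generated_token_idx | generated_score | leading_token | leading_token_idx | leading_score |
|---|---|---|---|---|---|---|
| 0 | | 29871 | 0.99993 | | 29871 | 0.99993 |
| 1 | `I` | 29902 | 0.0 | `I` | 306 | 0.59019 |
| 2 | `n` | 29876 | 0.0 | `apolog` | 27746 | 0.63013 |
| 3 | | 29871 | 0.00091 | `st` | 303 | 0.67129 |
| 4 | `mm` | 4317 | 2e-05 | `1` | 29896 | 0.53647 |
| 5 | `/` | 29914 | 0.99731 | `/` | 29914 | 0.99731 |
| 6 | `dd` | 1289 | 0.99882 | `dd` | 1289 | 0.99882 |
| 7 | `/` | 29914 | 0.99897 | `/` | 29914 | 0.99897 |
| 8 | `yyyy` | 18855 | 0.99702 | `yyyy` | 18855 | 0.99702 |
| 9 | `format` | 3402 | 0.99611 | `format` | 3402 | 0.99611 |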



----------------------------------------



**With string force:**

**Language model output:**

```
 Michael Jordan was born in 02/17/1963
```

**Enforcer intervention table:**

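| index | generated_token | generated_token_idx | generated_score | leading_token | leading_token_idx | leading_score |
|---|---|---|---|---|---|---|
| 0 | | 29871 | 0.99993 | | 29871 | 0.99993 |
| 1 | `Michael` | 24083 | 0.0 | `I` | 306 | 0.5897 |
| 2 | `Jordan` | 18284 | 0.99998 | `Jordan` | 18284 | 0.99998 |
| 3 | `was` | 471 | 0.99966 | `was` | 471 | 0.99966 |
| 4 | `born` | 6345 | 0.99996 | `born` | 6345 | 0.99996 |
| 5 | `in` | 297 | 3e-05 | `on` | 373 | 0.99997 |
| 6 | | 29871 | 0.07737 | `Fort` | 7236 | 0.47024 |
| 7 | `0` | 29900 | 0.00822 | `1` | 29896 | 0.9727 |
| 8 | `2` | 29906 | 0.99314 | `2` | 29906 | 0.99314 |
| 9 | `/` | 29914 | 0.99944 | `/` | 29914 | 0.99944 |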


The unenforced example does not adhere to the format requirements at all.

Note that in the string force section, the first digit that the LM wants to generate is a 1, not a 0 (which we would expect in 02/17/1963). The excerpt below is from a different run than the table above, so the scores differ slightly:

```
	generated_token	generated_token_idx	generated_score	leading_token	leading_token_idx	leading_score
7	0	29900	0.03939	1	29896	0.95441
```

This is likely due to the prompt not emphasizing the format requirements enough, causing the LM to want to output the year first (1963 starts with 1). In the regex example, we add a ```In mm/dd/yyyy format, ``` prefix to the regex, making 0 a pretty confident leader in that case:

```
	generated_token	generated_token_idx	generated_score	leading_token	leading_token_idx	leading_score
16	0	29900	0.91775	0	29900	0.91775
```
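
Keep in mind that the regex constrains only the format, not the facts: the regex-enforced output above, 12/23/1963, matches the pattern even though the unenforced answer names February 17, 1963. The sketch below checks a few candidate strings against the same `date_regex`. It uses Python's `re` module rather than the `interegular` engine behind ```RegexParser```, on the assumption that both agree for a pattern this simple.

```python
import re

# Same pattern as date_regex in the notebook cell above.
date_regex = r'(0?[1-9]|1[0-2])\/(0?[1-9]|1\d|2\d|3[01])\/(19|20)\d{2}'


def has_date_format(text: str) -> bool:
    """True if the whole string matches the mm/dd/yyyy pattern."""
    return re.fullmatch(date_regex, text) is not None


for candidate in ['02/17/1963', '12/23/1963', '2/30/1999', '13/01/1963', 'February 17, 1963']:
    print(candidate, has_date_format(candidate))
# 02/17/1963 True
# 12/23/1963 True
# 2/30/1999 True   (format-valid, although February 30 is not a real date)
# 13/01/1963 False
# February 17, 1963 False
```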

## Batching example

This is a simple example of using batching to run multiple queries in parallel. All outputs will be in the correct format. At every timestep, the enforcer can filter different tokens for each sequence in the batch.

In [None]:
PLAYER_NAMES = ['Michael Jordan', 'Tim Duncan', 'Kobe Bryant', 'Kareem Abdul Jabbar']
question = 'Please give me information about {0}. You MUST answer using the following json schema: '
questions_with_schema = [f'{question.format(player_name)}{AnswerFormat.schema_json()}' for player_name in PLAYER_NAMES]

display_header("Batched Answers, With json schema enforcing:")
results, _ = run(questions_with_schema, system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, required_json_schema=AnswerFormat.schema())
for result in results:
    display_content(result)

**Batched Answers, With json schema enforcing:**

```
  { "first_name": "Michael", "last_name": "Jordan", "year_of_birth": 1963, "num_seasons_in_nba": 15 }
```

```
  { "first_name": "Timothy", "last_name": "Duncan", "year_of_birth": 1976, "num_seasons_in_nba": 19 }
```

```
  { "first_name": "Kobe", "last_name": "Bryant", "year_of_birth": 1978, "num_seasons_in_nba": 20 }
```

```
  { "first_name": "Kareem Abdul-Jabbar", "last_name": "Abdul-Jabbar", "year_of_birth": 1947, "num_seasons_in_nba": 20 }
```

## LangChain Integration

This is a simple example of how to integrate the enforcer into the LangChain project.


For a more complete explanation, see the [LM Format Enforcer page in the LangChain documentation](https://python.langchain.com/docs/integrations/llms/lmformatenforcer_experimental).

This demo shows the JSON use case; the regex use case is also supported.

In [None]:
from transformers import pipeline
from langchain_experimental.llms import LMFormatEnforcer

# We create a transformers pipeline to avoid loading the model twice, but we could also use LMFormatEnforcer.from_model_id()
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=100)
langchain_pipeline = LMFormatEnforcer(pipeline=pipe, json_schema=AnswerFormat.schema())

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""

players = ['Michael Jordan', 'Larry Bird', 'Tim Duncan']
question = 'Please give me information about {}.'
prompts = [get_prompt(question.format(player), DEFAULT_SYSTEM_PROMPT) for player in players]

display_header('Call mode')
result = langchain_pipeline(prompts[0])
display_content(result)

display_header('Batched mode')
results = langchain_pipeline.generate(prompts)
for generation in results.generations:
    display_content(generation[0].text)

**Call mode**

```
   {

"last_name" : "Jordan",
"first_name" : "Michael",
"year_of_birth" : 1963,
"num_seasons_in_nba" : 15
}












```

**Batched mode**

```
   {

"last_name" : "Jordan",
"first_name" : "Michael",
"year_of_birth" : 1963,
"num_seasons_in_nba" : 15
}












```

```
 










{











"last_name": "Larry Bird",
"num_seasons_in_nba": 13,
"year_of_birth": 1956,
"first_name": "Larry"
}












```

```
   {

"first_name": "Tim",
"last_name": "Duncan",
"num_seasons_in_nba": 19,
"year_of_birth": 1976
}





```