Error during constrained JSON generation. #612

Closed
willkurt opened this issue Feb 5, 2024 · 4 comments · Fixed by #619

@willkurt
Contributor

willkurt commented Feb 5, 2024

### Describe the issue as clearly as possible:

I'm trying to generate JSON from a prompt that is itself formatted as JSON, and this leads to strange behavior (and, most importantly, an error) during generation.

The model seems to be unable to generate valid JSON within the token limit, even though the Pydantic model I'm using constrains the string field (`reasoning`) to at most 10 characters and the other field is an int. This issue also happened when I attempted unconstrained generation, which is why I added the constraint.

Notes:

  • This issue is a bit finicky. For example, if I set max_length=100 and the sampler to greedy, I don't get the error, but if I change the sampler to multinomial, I get it again. Both samplers error at max_length=10.
  • Happens with both greedy and multinomial samplers.
  • The example code runs on Apple silicon using mps, but I was also able to reproduce this error on a Linux box with a 4090 using the cuda/torch.float16 configuration.
  • I've tried this with non-structured prompts and generation worked, so something about this prompt is leading the model to this outcome.

### Steps/code to reproduce the bug:

```python
# Note: if reproducing on a GPU, make sure to change the device to 'cuda'
from datasets import load_dataset
import outlines
from outlines.generate.samplers import greedy, multinomial
from pydantic import BaseModel, constr
import torch
import json

device = 'mps'

dataset = load_dataset("gsm8k", "main")

model_mistral_7b = outlines.models.transformers(
    "mistralai/Mistral-7B-v0.1",
    device=device
)

# This prompt is designed to match the structure of the output I expect
# from outlines.
def test_prompt(question):
    examples_dicts = [
        {
            "question": "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
            "response": {
                "reasoning": "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6.",
                "answer": 6,
            },
        },
        {"question": f"{question}", "response": {}},
    ]
    text_json = json.dumps(examples_dicts)
    # Strip the leading "[" and the trailing "}}]", then drop the final "{",
    # so the prompt ends right after '"response": ' for the model to complete.
    return text_json.strip("[]}")[0:-1]


class Answer(BaseModel):
    reasoning: constr(max_length=10)
    answer: int


# Seed a torch RNG on the same device as the model.
rng = torch.Generator(device=device)
rng.manual_seed(1337)

generator = outlines.generate.json(model_mistral_7b, Answer, sampler=greedy)
generator(test_prompt(dataset['test']['question'][19]), max_tokens=512, rng=rng)
```
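To see what the prompt trick produces, here is a dependency-free sketch of the same string manipulation. `build_prompt` is a hypothetical stand-in for `test_prompt`, with the few-shot example shortened to placeholder text. The prompt is cut off right after `"response": `, so the natural continuation is a JSON object, which is exactly what the Outlines generator then constrains.

```python
import json


def build_prompt(question):
    # Same logic as `test_prompt` above, with placeholder few-shot text.
    examples_dicts = [
        {"question": "Q1", "response": {"reasoning": "R1", "answer": 6}},
        {"question": question, "response": {}},
    ]
    text_json = json.dumps(examples_dicts)
    # Remove "[" at the start and "}}]" at the end, then the dangling "{".
    return text_json.strip("[]}")[0:-1]


prompt = build_prompt("Q2")
print(prompt)
print(prompt.endswith('"response": '))  # True
```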

### Expected result:

Here's an example output when the model *does* complete:

```
Answer(reasoning='Marissa walked 4 miles in 1 hour, so her speed was 4 miles per hour. She walked 2 miles in 1 hour, a', answer=4)
```

### Error message:

```shell
---------------------------------------------------------------------------
JSONDecodeError                           Traceback (most recent call last)
File ~/miniconda3/envs/outlines-dev/lib/python3.10/site-packages/pydantic/main.py:1052, in BaseModel.parse_raw(cls, b, content_type, encoding, proto, allow_pickle)
   1051 try:
-> 1052     obj = parse.load_str_bytes(
   1053         b,
   1054         proto=proto,
   1055         content_type=content_type,
   1056         encoding=encoding,
   1057         allow_pickle=allow_pickle,
   1058     )
   1059 except (ValueError, TypeError) as exc:

File ~/miniconda3/envs/outlines-dev/lib/python3.10/site-packages/typing_extensions.py:2499, in deprecated.__call__.<locals>.wrapper(*args, **kwargs)
   2498 warnings.warn(msg, category=category, stacklevel=stacklevel + 1)
-> 2499 return arg(*args, **kwargs)

File ~/miniconda3/envs/outlines-dev/lib/python3.10/site-packages/pydantic/deprecated/parse.py:49, in load_str_bytes(b, content_type, encoding, proto, allow_pickle, json_loads)
     48         b = b.decode(encoding)
---> 49     return json_loads(b)  # type: ignore
     50 elif proto == Protocol.pickle:

File ~/miniconda3/envs/outlines-dev/lib/python3.10/json/__init__.py:346, in loads(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)
    343 if (cls is None and object_hook is None and
    344         parse_int is None and parse_float is None and
    345         parse_constant is None and object_pairs_hook is None and not kw):
--> 346     return _default_decoder.decode(s)
    347 if cls is None:

File ~/miniconda3/envs/outlines-dev/lib/python3.10/json/decoder.py:337, in JSONDecoder.decode(self, s, _w)
    333 """Return the Python representation of ``s`` (a ``str`` instance
    334 containing a JSON document).
    335 
    336 """
--> 337 obj, end = self.raw_decode(s, idx=_w(s, 0).end())
    338 end = _w(s, end).end()

File ~/miniconda3/envs/outlines-dev/lib/python3.10/json/decoder.py:353, in JSONDecoder.raw_decode(self, s, idx)
    352 try:
--> 353     obj, end = self.scan_once(s, idx)
    354 except StopIteration as err:

JSONDecodeError: Expecting ',' delimiter: line 1 column 536 (char 535)

During handling of the above exception, another exception occurred:

ValidationError                           Traceback (most recent call last)
Cell In[19], line 40
     37 rng.manual_seed(1337)
     39 generator = outlines.generate.json(model_mistral_7b, Answer, sampler=greedy)
---> 40 generator(test_prompt(dataset['test']['question'][19]), max_tokens=512, rng=rng)

File ~/code/outlines/outlines/generate/api.py:236, in SequenceGenerator.__call__(self, prompts, max_tokens, stop_at, rng, kv_cache)
    231 stripped = [
    232     self.strip_stop_sequences(sequence, stop_sequences)
    233     for sequence in generated
    234 ]
    235 try:
--> 236     formatted = [self.format_sequence(sequence) for sequence in stripped]
    237 except pyjson.decoder.JSONDecodeError:
    238     raise TypeError(
    239         "Could not format the output of the model into a dictionary or a Pydantic model."
    240         + " The model has likely exceeded its context length. Please try again using `constr` (for Pydantic)"
    241         + " and `maxLength` (for JSON Schema) to limit the length of the string fields. If this exception"
    242         + " is raised nevertheless please open an issue: https://github.com/outlines-dev/outlines/issues"
    243     )

File ~/code/outlines/outlines/generate/api.py:236, in <listcomp>(.0)
    231 stripped = [
    232     self.strip_stop_sequences(sequence, stop_sequences)
    233     for sequence in generated
    234 ]
    235 try:
--> 236     formatted = [self.format_sequence(sequence) for sequence in stripped]
    237 except pyjson.decoder.JSONDecodeError:
    238     raise TypeError(
    239         "Could not format the output of the model into a dictionary or a Pydantic model."
    240         + " The model has likely exceeded its context length. Please try again using `constr` (for Pydantic)"
    241         + " and `maxLength` (for JSON Schema) to limit the length of the string fields. If this exception"
    242         + " is raised nevertheless please open an issue: https://github.com/outlines-dev/outlines/issues"
    243     )

File ~/code/outlines/outlines/generate/api.py:425, in json.<locals>.<lambda>(x)
    423     regex_str = build_regex_from_object(schema)
    424     generator = regex(model, regex_str, max_tokens, sampler)
--> 425     generator.format_sequence = lambda x: schema_object.parse_raw(x)
    426 elif callable(schema_object):
    427     schema = pyjson.dumps(get_schema_from_signature(schema_object))

File ~/miniconda3/envs/outlines-dev/lib/python3.10/site-packages/typing_extensions.py:2499, in deprecated.__call__.<locals>.wrapper(*args, **kwargs)
   2496 @functools.wraps(arg)
   2497 def wrapper(*args, **kwargs):
   2498     warnings.warn(msg, category=category, stacklevel=stacklevel + 1)
-> 2499     return arg(*args, **kwargs)

File ~/miniconda3/envs/outlines-dev/lib/python3.10/site-packages/pydantic/main.py:1079, in BaseModel.parse_raw(cls, b, content_type, encoding, proto, allow_pickle)
   1072     # ctx is missing here, but since we've added `input` to the error, we're not pretending it's the same
   1073     error: pydantic_core.InitErrorDetails = {
   1074         # The type: ignore on the next line is to ignore the requirement of LiteralString
   1075         'type': pydantic_core.PydanticCustomError(type_str, str(exc)),  # type: ignore
   1076         'loc': ('__root__',),
   1077         'input': b,
   1078     }
-> 1079     raise pydantic_core.ValidationError.from_exception_data(cls.__name__, [error])
   1080 return cls.model_validate(obj)

ValidationError: 1 validation error for Answer
__root__
  Expecting ',' delimiter: line 1 column 536 (char 535) [type=value_error.jsondecode, input_value='{"reasoning": "Marissa\'...00000000000000000000000', input_type=str]
```

### Outlines/Python version information:

```
0.0.25.dev7+g8a0bafc.d20240123
Python 3.10.0 | packaged by conda-forge | (default, Nov 20 2021, 02:27:15) [Clang 11.1.0 ]
accelerate==0.26.1
aiohttp @ file:///Users/runner/miniforge3/conda-bld/aiohttp_1701099674487/work
aiosignal @ file:///home/conda/feedstock_root/build_artifacts/aiosignal_1667935791922/work
annotated-types @ file:///home/conda/feedstock_root/build_artifacts/annotated-types_1696634205638/work
anyio==4.2.0
appnope==0.1.3
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout @ file:///home/conda/feedstock_root/build_artifacts/async-timeout_1691763562544/work
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1704011227531/work
Babel==2.14.0
beartype==0.15.0
beautifulsoup4==4.12.3
bleach==6.1.0
Brotli @ file:///Users/runner/miniforge3/conda-bld/brotli-split_1695989934239/work
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1700303426725/work/certifi
cffi @ file:///Users/runner/miniforge3/conda-bld/cffi_1696001737800/work
cfgv @ file:///home/conda/feedstock_root/build_artifacts/cfgv_1629909281805/work
chardet==5.2.0
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1698833585322/work
cloudpickle==3.0.0
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
comm==0.2.1
coverage==7.4.0
datasets @ file:///home/conda/feedstock_root/build_artifacts/datasets_1704319050587/work
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
diff_cover==8.0.2
dill @ file:///home/conda/feedstock_root/build_artifacts/dill_1690101045195/work
diskcache==5.6.3
distlib @ file:///home/conda/feedstock_root/build_artifacts/distlib_1702383208639/work
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1704921103267/work
executing==2.0.1
fastjsonschema==2.19.1
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1698714947081/work
fqdn==1.5.1
frozenlist @ file:///Users/runner/miniforge3/conda-bld/frozenlist_1702645565720/work
fsspec @ file:///home/conda/feedstock_root/build_artifacts/fsspec_1697919321618/work
huggingface_hub @ file:///home/conda/feedstock_root/build_artifacts/huggingface_hub_1704484084177/work
identify @ file:///home/conda/feedstock_root/build_artifacts/identify_1701927326014/work
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1701026962277/work
importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1703269254275/work
importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1699364556997/work
iniconfig @ file:///home/conda/feedstock_root/build_artifacts/iniconfig_1673103042956/work
interegular==0.3.3
ipykernel==6.29.0
ipython==8.20.0
ipython-genutils==0.2.0
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.1
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1704966972576/work
joblib==1.3.2
json5==0.9.14
jsonpointer==2.4
jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema-meta_1705441631682/work
jsonschema-specifications @ file:///tmp/tmpkv1z7p57/src
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.9.0
jupyter-highlight-selected-word==0.2.0
jupyter-lsp==2.2.2
jupyter-nbextensions-configurator==0.6.3
jupyter_client==8.6.0
jupyter_contrib_core==0.4.2
jupyter_contrib_nbextensions==0.7.0
jupyter_core==5.7.1
jupyter_server==2.12.5
jupyter_server_terminals==0.5.1
jupyterlab==4.0.10
jupyterlab-execute-time==3.1.1
jupyterlab-widgets==3.0.9
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.2
lark==1.1.9
llama_cpp_python==0.2.29
llvmlite==0.41.1
lxml==5.1.0
MarkupSafe @ file:///Users/runner/miniforge3/conda-bld/markupsafe_1695367646585/work
matplotlib-inline==0.1.6
mistune==3.0.2
mpmath==1.3.0
multidict @ file:///Users/runner/miniforge3/conda-bld/multidict_1696716121514/work
multiprocess @ file:///Users/runner/miniforge3/conda-bld/multiprocess_1695458915095/work
nbclient==0.9.0
nbconvert==7.14.2
nbformat==5.9.2
nest-asyncio==1.5.9
networkx==3.2.1
nodeenv @ file:///home/conda/feedstock_root/build_artifacts/nodeenv_1683892983968/work
notebook==7.0.6
notebook_shim==0.2.3
numba==0.58.1
numpy @ file:///Users/runner/miniforge3/conda-bld/numpy_1704280780572/work/dist/numpy-1.26.3-cp310-cp310-macosx_11_0_arm64.whl#sha256=f96d0b051b72345dbc317d793b2b34c7c4b7f41b0b791ffc93e820c45ba6a91c
-e git+ssh://git@github.com/outlines-dev/outlines.git@8a0bafc#egg=outlines
overrides==7.4.0
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1696202382185/work
pandas @ file:///Users/runner/miniforge3/conda-bld/pandas_1702057222502/work
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.9.0
pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1694617248815/work
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1701708255999/work
pluggy @ file:///home/conda/feedstock_root/build_artifacts/pluggy_1693086607691/work
pre-commit @ file:///home/conda/feedstock_root/build_artifacts/pre-commit_1702177249902/work
prometheus-client==0.19.0
prompt-toolkit==3.0.43
psutil==5.9.7
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==14.0.2
pyarrow-hotfix @ file:///home/conda/feedstock_root/build_artifacts/pyarrow-hotfix_1700596371886/work
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic @ file:///home/conda/feedstock_root/build_artifacts/pydantic_1703248379805/work
pydantic_core @ file:///Users/runner/miniforge3/conda-bld/pydantic-core_1703318578880/work
Pygments==2.17.2
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
pytest @ file:///home/conda/feedstock_root/build_artifacts/pytest_1704035161844/work
pytest-cov==4.1.0
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
python-json-logger==2.0.7
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1693930252784/work
PyYAML @ file:///Users/runner/miniforge3/conda-bld/pyyaml_1695373498369/work
pyzmq==25.1.2
qtconsole==5.5.1
QtPy==2.4.1
referencing @ file:///home/conda/feedstock_root/build_artifacts/referencing_1704489226496/work
regex @ file:///Users/runner/miniforge3/conda-bld/regex_1703393590908/work
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1684774241324/work
responses==0.24.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py @ file:///Users/runner/miniforge3/conda-bld/rpds-py_1705159950823/work
safetensors @ file:///Users/runner/miniforge3/conda-bld/safetensors_1695444684081/work
SciPy @ file:///Users/runner/miniforge3/conda-bld/scipy-split_1700812700233/work/dist/scipy-1.11.4-cp310-cp310-macosx_11_0_arm64.whl#sha256=375d32c2e30658f658c57cabef9cbbe6df2df8a14f5cb858d49fc66e910be7a5
Send2Trash==1.8.2
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.3
sympy==1.12
terminado==0.18.0
tinycss2==1.2.1
tokenizers @ file:///Users/runner/miniforge3/conda-bld/tokenizers_1702395225690/work/bindings/python
tomli @ file:///home/conda/feedstock_root/build_artifacts/tomli_1644342247877/work
torch==2.1.2
tornado==6.4
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1691671248568/work
traitlets==5.14.1
transformers @ file:///home/conda/feedstock_root/build_artifacts/transformers_1702954525852/work
types-python-dateutil==2.8.19.20240106
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1702176139754/work
tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1703878702368/work
ukkonen @ file:///Users/runner/miniforge3/conda-bld/ukkonen_1695549417166/work
uri-template==1.3.0
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1699933488691/work
virtualenv @ file:///home/conda/feedstock_root/build_artifacts/virtualenv_1701458794382/work
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
widgetsnbextension==4.0.9
xxhash @ file:///Users/runner/miniforge3/conda-bld/python-xxhash_1696486346782/work
yarl @ file:///Users/runner/miniforge3/conda-bld/yarl_1705508355339/work
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1695255097490/work
```


### Context for the issue:

Currently blocking (or at least severely impacting) some work I'm doing on evaluations.

willkurt added the bug label Feb 5, 2024
@lapp0
Collaborator

lapp0 commented Feb 5, 2024

Thank you for your well written and easy to reproduce issue!

Here is the raw output:

{"reasoning": "Marissa's ", "answer": 48000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000

Here is the regex pattern which is generated from your schema:

```
\{[\n ]*"reasoning"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.){,10}"[\n ]*,[\n ]*"answer"[\n ]*:[\n ]*(0|[1-9][0-9]*)[\n ]*\}
```

It looks like the length constraint on `reasoning` is respected, but `answer` isn't length-constrained, so the model falls into the well-known repetition problem of language models.
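A quick check with Python's `re` module makes this concrete: the generated pattern accepts an `answer` with any number of digits. `BOUNDED` is a hypothetical variant (not something Outlines emits for this schema) that caps the integer at 10 digits, which is one way the unbounded-int problem could be avoided.

```python
import re

# Pattern generated by Outlines for the `Answer` schema, copied from above.
ORIGINAL = (
    r'\{[\n ]*"reasoning"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.){,10}"'
    r'[\n ]*,[\n ]*"answer"[\n ]*:[\n ]*(0|[1-9][0-9]*)[\n ]*\}'
)

# Hypothetical variant: same pattern, but `answer` may have at most 10 digits.
BOUNDED = ORIGINAL.replace("(0|[1-9][0-9]*)", "(0|[1-9][0-9]{0,9})")

ok = '{"reasoning": "Marissa\'s ", "answer": 4}'
runaway = '{"reasoning": "Marissa\'s ", "answer": 48' + "0" * 500 + "}"

for name, pattern in [("original", ORIGINAL), ("bounded", BOUNDED)]:
    print(name, bool(re.fullmatch(pattern, ok)), bool(re.fullmatch(pattern, runaway)))
# original True True
# bounded True False
```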

Similar to #580 (comment), I recommend using `mistralai/Mistral-7B-Instruct-v0.2`, which gave me the output `{"reasoning": "Marissa's ","answer": 4}`.

@rlouf, I'm wondering whether something changed in the sampler that results in these repetition issues; I'm noticing this issue popping up in a few places. @willkurt mentioned he doesn't see this issue with greedy, so maybe the multinomial sampler has trouble generating EOS? Just giving you a heads-up that something may have changed. I'm considering writing a logits observation utility to help us with these kinds of issues.
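Such a utility could be quite small. Below is a minimal, dependency-free sketch; `describe_step` is a hypothetical name and this is not the implementation from #616. It turns one step's logits into a line similar to the debug log further down in this thread, assuming masked tokens carry a logit of `-inf`.

```python
import math


def describe_step(logits, vocab, selected_id, eos_id, k=8):
    """Format one generation step like the debug log (hypothetical helper).

    `logits` holds one float per vocabulary entry; masked tokens are -inf.
    Assumes at least one token is unmasked.
    """
    # Numerically stable softmax; masked (-inf) tokens get probability 0.
    peak = max(x for x in logits if x != -math.inf)
    weights = [0.0 if x == -math.inf else math.exp(x - peak) for x in logits]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Indices of the k most likely tokens, highest first.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    top_str = ", ".join(f"{vocab[i]!r}: {probs[i]:.3f}" for i in top)
    return (
        f"Selected: {vocab[selected_id]!r}\n"
        f"\tEOS Prob: {probs[eos_id]:.1f} (logit = {logits[eos_id]})\n"
        f"\tTop {k} Tokens: {top_str}"
    )


# Toy vocabulary and made-up logits; EOS is masked, e.g. by the regex FSM.
vocab = ["0", "}", "4", "</s>"]
logits = [3.0, 1.0, 0.5, -math.inf]
print(describe_step(logits, vocab, selected_id=0, eos_id=3, k=3))
```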

@rlouf
Member

rlouf commented Feb 5, 2024

Not that I know of. A tool to be able to quickly look at the logits would be super useful!

@willkurt
Contributor Author

willkurt commented Feb 5, 2024

@lapp0 Thanks for looking into this and the quick response!

To be clear, with greedy I only avoid the issue when the constraints are set higher. I definitely do see this issue with greedy in some cases, so it's not just a problem with multinomial.

Additionally, the problem doesn't seem to be that the constraints aren't honored, but that the integer is unbounded. Python integers have no maximum value, so this is technically not incorrect behavior, but it is certainly undesirable.
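A minimal sketch with the standard-library `json` module illustrates both halves of this. A very long integer is perfectly valid JSON. If generation hits `max_tokens` while still emitting digits (an assumption about what happened here), the closing brace is never produced, and parsing fails with the same `Expecting ',' delimiter` message as in the traceback above.

```python
import json

# Python ints are arbitrary precision, so a very long integer parses fine.
complete = '{"reasoning": "Marissa\'s ", "answer": 48' + "0" * 300 + "}"
parsed = json.loads(complete)
print(len(str(parsed["answer"])))  # 302

# Simulate running out of tokens mid-number: the final "}" is missing.
truncated = complete[:-1]
try:
    json.loads(truncated)
except json.JSONDecodeError as exc:
    print(exc.msg)  # Expecting ',' delimiter
```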

@lapp0
Collaborator

lapp0 commented Feb 6, 2024

It seems Mistral-7B-v0.1 is a bit bad at ending the JSON:

Logits debug log via #616

Selected: '{"' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '{"': 1.000, '{': 0.000, '{': 0.000, '\x04': 0.000, '\x00': 0.000, '': 0.000, '\x01': 0.000, '\x02': 0.000
Selected: 'reason' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'reason': 0.999, 're': 0.001, 'r': 0.000, 'rea': 0.000, 'r': 0.000, '\x00': 0.000, '\x04': 0.000, '': 0.000
Selected: 'ing' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'ing': 1.000, 'in': 0.000, 'i': 0.000, 'i': 0.000, '\x00': 0.000, '\x04': 0.000, '': 0.000, '\x01': 0.000
Selected: '":' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '":': 0.994, '":"': 0.005, '"': 0.001, '"': 0.000, '\x00': 0.000, '\x04': 0.000, '': 0.000, '\x01': 0.000
Selected: '"' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '"': 0.980, '"",': 0.008, '': 0.003, '"(': 0.002, '"\\': 0.001, '"+': 0.001, '"<': 0.001, '"\'': 0.001
Selected: 'Mar' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'Mar': 0.148, 'The': 0.148, 'She': 0.105, 'To': 0.046, 'We': 0.042, 'Let': 0.035, 'First': 0.032, 'If': 0.029
Selected: 'issa' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'issa': 0.996, 'isa': 0.002, 'is': 0.000, 'ris': 0.000, 'iss': 0.000, 'ie': 0.000, 'ri': 0.000, 'isha': 0.000
Selected: "'" for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: "'": 0.437, 'h': 0.268, 'is': 0.170, '’': 0.079, 'sp': 0.012, 'â': 0.006, "\\'": 0.006, '\\': 0.002
Selected: 's' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 's': 1.000, 's': 0.000, 'a': 0.000, 'll': 0.000, 'ss': 0.000, 'S': 0.000, 'sp': 0.000, 'd': 0.000
Selected: '' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '': 0.905, '"': 0.072, ',': 0.004, 'a': 0.003, '\\': 0.002, '\xa0': 0.002, ':': 0.002, "'": 0.001
Selected: '",' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '",': 0.688, '","': 0.233, '"': 0.079, '"': 0.000, '\x00': 0.000, '\x04': 0.000, '': 0.000, '\x01': 0.000
Selected: '"' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '"': 0.979, '\n': 0.017, '': 0.004, ' ': 0.000, '"': 0.000, '  ': 0.000, '               ': 0.000, '   ': 0.000
Selected: 'answer' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'answer': 0.999, 'ans': 0.000, 'a': 0.000, 'an': 0.000, 'a': 0.000, '\x00': 0.000, '\x04': 0.000, '': 0.000
Selected: '":' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '":': 0.999, '"': 0.001, '"': 0.000, '\x04': 0.000, '\x00': 0.000, '': 0.000, '\x01': 0.000, '\x02': 0.000
Selected: '' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '': 0.992, ' ': 0.003, '\n': 0.001, '4': 0.001, '2': 0.001, '1': 0.001, '3': 0.000, '6': 0.000
Selected: '4' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '4': 0.304, '2': 0.227, '3': 0.144, '6': 0.098, '1': 0.072, '8': 0.065, '5': 0.053, '0': 0.022
Selected: '8' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '8': 0.243, '0': 0.215, '}': 0.153, '}': 0.140, '\n': 0.035, '2': 0.035, '4': 0.033, '5': 0.030
Selected: '0' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '0': 0.750, '8': 0.049, '9': 0.033, '}': 0.027, '6': 0.019, '1': 0.016, '4': 0.015, '}': 0.015
Selected: '0' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '0': 0.968, '1': 0.005, '}': 0.004, '8': 0.003, '9': 0.003, '6': 0.003, '2': 0.002, '4': 0.002
Selected: '0' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '0': 0.979, '1': 0.004, '}': 0.002, '}': 0.002, '8': 0.002, '2': 0.002, '': 0.002, '3': 0.001
Selected: '0' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '0': 0.986, '1': 0.003, '2': 0.001, '3': 0.001, '8': 0.001, '}': 0.001, '4': 0.001, '9': 0.001
Selected: '0' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '0': 0.973, '1': 0.009, '2': 0.003, '3': 0.002, '4': 0.002, '5': 0.002, '8': 0.002, '9': 0.002
Selected: '0' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '0': 0.979, '1': 0.007, '2': 0.002, '3': 0.002, '}': 0.002, '4': 0.002, '5': 0.001, '7': 0.001
Selected: '0' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '0': 0.981, '1': 0.004, '3': 0.002, '2': 0.002, '4': 0.002, '}': 0.002, '5': 0.001, '7': 0.001
Selected: '0' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '0': 0.985, '1': 0.004, '2': 0.002, '3': 0.001, '4': 0.001, '8': 0.001, '7': 0.001, '5': 0.001
...

After the `4`, the chance of `}` is 15%; after `48` it's 2.7%; after `480` it's 0.4%, and it stays extremely low from there.
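The same numbers also help explain @willkurt's observation that greedy decoding fails too. The sketch below copies the first three logged steps (top 8 tokens only, so the `}` totals are lower bounds) and checks what an argmax sampler would pick. The log lists two entries that both print as `'}'` (presumably distinct token ids); summing them gives somewhat higher closing probabilities than the single-entry figures quoted above, but at no step is `}` the most likely token, so greedy decoding keeps emitting digits.

```python
# Top-8 next-token probabilities copied from the Mistral-7B-v0.1 log above.
# Two distinct entries print as '}', so both are kept.
STEPS = {
    "4": [("8", 0.243), ("0", 0.215), ("}", 0.153), ("}", 0.140),
          ("\n", 0.035), ("2", 0.035), ("4", 0.033), ("5", 0.030)],
    "48": [("0", 0.750), ("8", 0.049), ("9", 0.033), ("}", 0.027),
           ("6", 0.019), ("1", 0.016), ("4", 0.015), ("}", 0.015)],
    "480": [("0", 0.968), ("1", 0.005), ("}", 0.004), ("8", 0.003),
            ("9", 0.003), ("6", 0.003), ("2", 0.002), ("4", 0.002)],
}


def summarize(top8):
    """Return the greedy (argmax) token and the total logged mass on '}'."""
    greedy_token = max(top8, key=lambda pair: pair[1])[0]
    close_mass = sum(p for token, p in top8 if token == "}")
    return greedy_token, close_mass


for prefix, top8 in STEPS.items():
    token, mass = summarize(top8)
    print(f"after {prefix!r}: greedy picks {token!r}, P('}}') = {mass:.3f}")
# after '4': greedy picks '8', P('}') = 0.293
# after '48': greedy picks '0', P('}') = 0.042
# after '480': greedy picks '0', P('}') = 0.004
```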

Compare this to `mistralai/Mistral-7B-Instruct-v0.2`, which has much better numbers:

Selected: '{"' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '{"': 1.000, '{': 0.000, '{': 0.000, '\x04': 0.000, '\x00': 0.000, '': 0.000, '\x01': 0.000, '\x02': 0.000
Selected: 'reason' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'reason': 1.000, 're': 0.000, 'r': 0.000, 'rea': 0.000, 'r': 0.000, '\x00': 0.000, '\x04': 0.000, '': 0.000
Selected: 'ing' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'ing': 1.000, 'in': 0.000, 'i': 0.000, 'i': 0.000, '\x00': 0.000, '\x04': 0.000, '': 0.000, '\x01': 0.000
Selected: '":' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '":': 0.999, '":"': 0.001, '"': 0.000, '"': 0.000, '\x00': 0.000, '\x04': 0.000, '': 0.000, '\x01': 0.000
Selected: '"' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '"': 0.999, '"(': 0.001, '"\'': 0.000, '"[': 0.000, '': 0.000, '"<': 0.000, '":': 0.000, '"+': 0.000
Selected: 'Mar' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'Mar': 0.365, 'First': 0.266, 'Her': 0.169, 'To': 0.082, 'Let': 0.059, 'The': 0.031, 'She': 0.008, 'We': 0.008
Selected: 'issa' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'issa': 1.000, 'iss': 0.000, 'isa': 0.000, 'isse': 0.000, 'iss': 0.000, 'is': 0.000, 'essa': 0.000, 'Iss': 0.000
Selected: "'" for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: "'": 0.897, 'h': 0.079, '’': 0.012, "\\'": 0.007, 'is': 0.001, 'â': 0.001, '\\': 0.001, 'sp': 0.001
Selected: 's' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 's': 1.000, 'ss': 0.000, 's': 0.000, 'sd': 0.000, 'sb': 0.000, 'sp': 0.000, 'S': 0.000, 'a': 0.000
Selected: '' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '': 0.359, '"': 0.164, '速': 0.106, '总': 0.062, ',': 0.023, "'": 0.016, '-': 0.014, '_': 0.012
Selected: '"' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '"': 0.812, '",': 0.168, '","': 0.020, '"': 0.000, '\x00': 0.000, '\x04': 0.000, '': 0.000, '\x01': 0.000
Selected: ',"' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: ',"': 0.243, ',': 0.210, '': 0.182, ',': 0.056, '       ': 0.041, '           ': 0.028, '               ': 0.026, '            ': 0.026
Selected: 'a' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'a': 0.907, 'answer': 0.070, 'an': 0.023, 'ans': 0.001, 'a': 0.000, '\x00': 0.000, '\x04': 0.000, '': 0.000
Selected: 'ns' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'ns': 0.596, 'n': 0.332, 'n': 0.072, '\x04': 0.000, '\x00': 0.000, '': 0.000, '\x01': 0.000, '\x02': 0.000
Selected: 'w' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'w': 0.735, 'wer': 0.265, 'we': 0.000, 'w': 0.000, '\x00': 0.000, '\x04': 0.000, '': 0.000, '\x01': 0.000
Selected: 'er' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: 'er': 0.999, 'e': 0.001, 'e': 0.000, '\x04': 0.000, '\x00': 0.000, '': 0.000, '\x01': 0.000, '\x02': 0.000
Selected: '":' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '":': 0.996, '"': 0.004, '"': 0.000, '\x04': 0.000, '\x00': 0.000, '': 0.000, '\x01': 0.000, '\x02': 0.000
Selected: '' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '': 0.988, '4': 0.003, '2': 0.002, '\n': 0.002, '3': 0.001, '6': 0.001, '1': 0.001, ' ': 0.000
Selected: '4' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '4': 0.707, '8': 0.198, '6': 0.027, '2': 0.021, '1': 0.018, '5': 0.017, '3': 0.007, '0': 0.002
Selected: '}' for batch_item=0
	EOS Prob: 0.0(logit = -inf)
	Top 8 Tokens: '}': 0.607, '\n': 0.390, '}': 0.001, '': 0.001, '0': 0.000, ' ': 0.000, '  ': 0.000, '             ': 0.000
Selected: '' for batch_item=0
	EOS Prob: 1.0(logit = -8.42087173461914)
	Top 8 Tokens: '': 1.000, '': 0.000, '\x04': 0.000, '\x01': 0.000, '\x00': 0.000, '': 0.000, '\x02': 0.000, '\x03': 0.000

That's a 61% chance of `}` after the initial `4`, and a 99% chance of either `}` or `\n`.

Here are some thoughts on how to ensure high-quality generations for Outlines users. I can create issues and tackle some of these if you agree, @rlouf.

    1. Recommend high-quality models: Recommend models proficient in structured generation, and suggest specialized models for niche domains (e.g., Python, SQL, or math generation). (easy, strong benefit) Related: "Show the Power of Outlines in Documentation by Fixing or Not Doing Math" #586
    1. Quality test suite: A full integration test suite, enabled via `pytest --gpu-tests`, that ensures sane output for our examples. It will also help ensure we don't introduce "quality regressions". (fairly easy, moderate to strong benefit)
    1. Implement constraints on int "length", i.e. the number of digits. (probably easy, modest benefit)
    1. Incorporate repetition penalties: This is a common technique used in most, if not all, inference engines. (somewhat hard, strong benefit)
    1. Introduce a state repetition penalty: As shown in this example, the model struggles to exit the "int value" state and enter the closing `}` state. We can bias the logits to discourage excessive lingering. (hard, somewhat strong benefit)
    1. Introduce distance-from-termination penalties: After experimenting with the fifth suggestion, we might consider biasing logits towards paths that move the automaton closer to its final states.
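For the repetition-penalty idea (item 4), here is a minimal sketch of the CTRL-style penalty that Hugging Face `transformers` exposes as `repetition_penalty` (via `RepetitionPenaltyLogitsProcessor`). It is plain Python, the toy vocabulary and logits are made up, and none of it is Outlines code.

```python
import math


def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """CTRL-style repetition penalty (sketch; not part of Outlines).

    For each token id already generated, a positive logit is divided by
    `penalty` and a non-positive logit is multiplied by it, so the token
    becomes less likely in both cases.
    """
    adjusted = list(logits)
    for token_id in set(generated_ids):
        score = adjusted[token_id]
        adjusted[token_id] = score / penalty if score > 0 else score * penalty
    return adjusted


def softmax(logits):
    peak = max(logits)
    weights = [math.exp(x - peak) for x in logits]
    total = sum(weights)
    return [w / total for w in weights]


# Toy vocabulary and made-up logits for the "inside the integer" state.
vocab = ["0", "}", "1"]
logits = [4.0, 0.5, 0.0]
generated = [0, 0, 0]  # the model has already emitted several "0" tokens

before = softmax(logits)[vocab.index("}")]
after = softmax(apply_repetition_penalty(logits, generated))[vocab.index("}")]
print(f"P('}}') before: {before:.3f}, after: {after:.3f}")
```

In this toy case the penalty raises the probability of `}` from about 3% to about 5%. A real integration with structured generation would probably need to exempt tokens the grammar requires (quotes, braces), since those legitimately repeat.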
