Commit b70e513

Make grammar strict mode the default

shubhamugare committed May 6, 2024
1 parent 20e1960 commit b70e513
Showing 5 changed files with 66 additions and 21 deletions.
20 changes: 14 additions & 6 deletions notebooks/example_date.ipynb
@@ -2,24 +2,32 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 13,
    "id": "ffd1e5da",
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 3.80it/s]\n",
+      "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 4.00it/s]\n",
       "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
-      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 3.87it/s]\n",
+      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 4.08it/s]\n",
       "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
      ]
     }
    ],
    "source": [
     "from syncode import Syncode\n",
     "\n",
-    "grammar = \"\"\" start: month day \n",
+    "grammar = \"\"\" start: month \" \" day \n",
     " \n",
     " day: /[1-9]/ | /[1-2][0-9]/ | /3[0-1]/\n",
     " \n",
@@ -32,12 +40,12 @@
    "llm = Syncode(model = model_name, mode='original', max_new_tokens=20)\n",
    "\n",
    "# Load the Syncode augmented model\n",
-   "syn_llm = Syncode(model = model_name, mode='grammar_mask', grammar=grammar, parse_output_only=True, max_new_tokens=20)"
+   "syn_llm = Syncode(model = model_name, grammar=grammar, parse_output_only=True, max_new_tokens=20)"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 17,
+  "execution_count": 14,
   "id": "f3f9f3b8",
   "metadata": {},
   "outputs": [
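Putting the two notebook hunks together: the cell drops mode='grammar_mask' (picking up the new grammar_strict default), and the grammar now spells out the separator between month and day as an explicit " " terminal. A sketch of the resulting cell; the month rule and model_name are assumptions here, since the cells defining them are collapsed in this diff:

from syncode import Syncode

# Lark-style grammar from the notebook. The " " terminal between month and
# day is the change in this commit: the separator space is now part of the
# grammar itself.
grammar = """ start: month " " day

        day: /[1-9]/ | /[1-2][0-9]/ | /3[0-1]/

        month: "January" | "February" | "March"   // assumed: the collapsed cell lists all twelve months
"""

model_name = "microsoft/phi-2"  # placeholder; the notebook's real model id is collapsed above

# No mode argument: with this commit the default is 'grammar_strict'.
syn_llm = Syncode(model=model_name, grammar=grammar, parse_output_only=True, max_new_tokens=20)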
29 changes: 21 additions & 8 deletions notebooks/example_json.ipynb
@@ -2,26 +2,39 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-      "  from .autonotebook import tqdm as notebook_tqdm\n",
-      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 3.95it/s]\n",
+      "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 3.79it/s]\n",
       "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
-      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 4.03it/s]\n",
+      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 3.84it/s]\n",
       "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Loading Lark base parser from cache: cache/parsers/json_lalr_100321822995641546967503503968826521991246758102402215142209671679584378119681_parser.pkl\n"
+      "Creating DFA mask store for CodeGenTokenizerFast and json, may take more than 10 minutes. Caching at /home/shubham/syncode/cache/mask_stores/CodeGenTokenizerFast/grammar_strict_1003218229_50257.pkl.\n",
+      "Ignore whitespace tokens is True\n"
      ]
     },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 58/58 [00:22<00:00, 2.54it/s]\n"
+     ]
+    }
   ],
@@ -42,7 +55,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 2,
+  "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
@@ -73,7 +86,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 3,
+  "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
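Note the new stdout above: the mask-store cache file is now named after the mode (grammar_strict_...), so strict and mask modes cache separately, and the new "Ignore whitespace tokens is True" line reports the whitespace check introduced in grammar_decoder.py below. A minimal sketch of the cell that produced this output; the model id is a placeholder, not taken from the diff:

from syncode import Syncode

# First run under the new default builds and caches a grammar_strict DFA mask
# store for this tokenizer/grammar pair, the one-off step logged above as
# "may take more than 10 minutes"; subsequent runs load it from the cache.
syn_llm = Syncode(model="microsoft/phi-2", grammar="json", parse_output_only=True, max_new_tokens=200)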
26 changes: 23 additions & 3 deletions syncode/grammar_decoder.py
@@ -5,7 +5,7 @@
 from transformers import LogitsProcessor, PreTrainedTokenizer
 from syncode.parse_result import RemainderState
 from syncode.parsers.incremental_parser import IncrementalParser, ParseResult
-from syncode.parsers import create_parser
+from syncode.parsers import create_parser, create_base_parser
 from syncode.dfa_mask_store import DFAMaskStore
 from syncode.parsers.grammars import Grammar

@@ -48,6 +48,9 @@ def __init__(self,
         self.parse_output_only = parse_output_only
         self.start_from = None

+        # Ignore whitespace tokens
+        self._ignore_whitespace = self._get_ignore_whitespace(self.grammar)
+
         # Load dfa mask store
         self.dfa_mask_store = DFAMaskStore.load_dfa_mask_store(
             grammar=self.grammar,
@@ -58,16 +61,33 @@
         )

         # Create parsers
-        self.inc_parsers: Iterator[IncrementalParser] = [create_parser(self.grammar, logger=self.logger, parser=parser) for _ in range(self.batch_size)]
+        self.inc_parsers: Iterator[IncrementalParser] = [create_parser(self.grammar, logger=self.logger, parser=parser, ignore_whitespace=self._ignore_whitespace) for _ in range(self.batch_size)]

         # For profiling
         self.debug = True
         self.logger.log_time(f"Time taken for preprocessing: {time.time() - time_start:.2f}s")

     def _log_current_status(self, partial_code, r: ParseResult):
         self.logger.log_code('Partial code', partial_code)
         self.logger.log(repr(r))

+    def _get_ignore_whitespace(self, grammar):
+        """
+        Check if the grammar allows whitespace tokens to be ignored.
+        """
+        base_parser = create_base_parser(grammar)
+        terminals = base_parser.terminals
+        ignore_terminals = base_parser.ignore_tokens
+
+        import regex
+        ignore_whitespace = False
+        for ig_name in ignore_terminals:
+            for terminal in terminals:
+                if terminal.name == ig_name:
+                    if regex.match(terminal.pattern.to_regexp(), ' ') is not None:
+                        ignore_whitespace = True  # an ignored terminal matches a space, so whitespace can be skipped
+        return ignore_whitespace
+
     def reset(self, prompt: str):
         """
         Resets the decoder state on every new prompt.
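The new _get_ignore_whitespace helper decides whether the lexer may skip whitespace by testing each ignored terminal's regex against a single space. A self-contained sketch of the same check, with illustrative stand-ins for the terminal definitions that create_base_parser actually supplies:

import re

# Stand-in terminal table, name -> regex (the real entries come from
# base_parser.terminals); WS is declared as an ignored terminal.
terminals = {"WORD": r"[a-z]+", "WS": r"[ \t]+"}
ignore_terminals = ["WS"]  # stand-in for base_parser.ignore_tokens

# Whitespace may be skipped iff some ignored terminal matches a space.
ignore_whitespace = any(
    re.match(terminals[name], " ") is not None for name in ignore_terminals
)
print(ignore_whitespace)  # True: WS matches ' '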
4 changes: 2 additions & 2 deletions syncode/infer.py
@@ -14,7 +14,7 @@
 from syncode.evaluation.fol_eval import FOLEval


-def compile_and_run(model, mode="grammar_mask", quantize=True, device="cuda", num_samples=1, grammar=None, dataset="input", num_few_shot=0, chat_mode=False, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", task_id=None, json_eval_type='schema', **kwargs):
+def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", num_samples=1, grammar=None, dataset="input", num_few_shot=0, chat_mode=False, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", task_id=None, json_eval_type='schema', **kwargs):
     sc = Syncode(model, mode=mode, quantize=quantize, device=device, num_samples=num_samples, grammar=grammar, dataset=dataset, num_few_shot=num_few_shot, chat_mode=chat_mode, dev_mode=dev_mode, log_level=log_level, new_mask_store=new_mask_store, parser=parser, task_id=task_id, json_eval_type=json_eval_type, **kwargs)
     sc.infer(task_id=task_id)

@@ -47,7 +47,7 @@ class Syncode:
     def __init__(
         self,
         model: str,
-        mode: Literal["original", "grammar_mask", "grammar_strict"] = "grammar_mask",
+        mode: Literal["original", "grammar_mask", "grammar_strict"] = "grammar_strict",
         quantize: bool = True,
         device: str = "cuda",
         num_samples: int = 1,
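The user-visible effect of the infer.py change is that constructing Syncode without a mode now gives strict grammar decoding. A short sketch, with a placeholder model id:

from syncode import Syncode

# After this commit these two constructions differ only in mode: the first
# now defaults to 'grammar_strict'; the old default must be named explicitly.
strict_llm = Syncode(model="microsoft/phi-2", grammar="json")
masked_llm = Syncode(model="microsoft/phi-2", grammar="json", mode="grammar_mask")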
8 changes: 6 additions & 2 deletions syncode/parsers/incremental_parser.py
@@ -10,10 +10,11 @@ class IncrementalParser:
     """
     This is the base class for all incremental parsers.
     """
-    def __init__(self, base_parser, logger: Optional[common.Logger]=None) -> None:
+    def __init__(self, base_parser, logger: Optional[common.Logger]=None, ignore_whitespace=False) -> None:
         self.cur_pos = 0  # Current cursor position in the lexer tokens list
         self.lexer_pos = 0  # Current lexer position in the code
         self.dedent_queue: list = []
+        self._ignore_whitespace = ignore_whitespace

         # Initialize the parser
         time_start = time.time()
@@ -171,7 +172,10 @@ def _get_remainder(self, code, lexing_incomplete=False, parse_incomplete=False):
         final_terminal = None
         if lexing_incomplete:  # Lexing is incomplete
             current_term_str = code[self.lexer_pos:]
-            current_term_str = current_term_str.lstrip(' ')  # Remove space from the beginning
+
+            if self._ignore_whitespace:
+                current_term_str = current_term_str.lstrip(' ')  # Remove space from the beginning
+
             if current_term_str == '':
                 remainder_state = RemainderState.COMPLETE
             else:
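The guard matters for grammars where a space is a real terminal rather than ignorable, such as the date grammar above: stripping it would discard a character the parser still needs. A toy illustration of the guarded lstrip, with names simplified from _get_remainder:

code = "July 4"
lexer_pos = 4                          # suppose lexing stopped right after "July"
current_term_str = code[lexer_pos:]    # " 4"

ignore_whitespace = False              # e.g. the date grammar matches " " explicitly
if ignore_whitespace:
    current_term_str = current_term_str.lstrip(' ')

print(repr(current_term_str))  # ' 4': the leading space survives for the " " terminal to match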
