Skip to content

Commit dddf829

Browse files
committed
Merge branch 'feat/one-to-one' of https://github.com/zkemail/zk-regex into feat/one-to-one
2 parents cff12e7 + 650d6de commit dddf829

File tree

3 files changed

+82
-140
lines changed

3 files changed

+82
-140
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,4 @@ packages/*/build
111111
package-lock.json
112112
yarn.lock
113113

114+
index.node

packages/compiler/src/circom.rs

Lines changed: 71 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,6 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
3131
rev_graph.get_mut(k).unwrap().insert(i, chars.clone());
3232

3333
if i == 0 {
34-
if let Some(index) = chars.iter().position(|&x| x == 94) {
35-
init_going_state = Some(*k);
36-
rev_graph.get_mut(&k).unwrap().get_mut(&i).unwrap()[index] = 255;
37-
}
38-
3934
for j in rev_graph.get(&k).unwrap().get(&i).unwrap() {
4035
if *j == 255 {
4136
continue;
@@ -83,9 +78,9 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
8378

8479
let accept_nodes_array: Vec<usize> = accept_nodes.into_iter().collect();
8580

86-
if accept_nodes_array.len() != 1 {
87-
panic!("The size of accept nodes must be one");
88-
}
81+
// if accept_nodes_array.len() != 1 {
82+
// panic!("The size of accept nodes must be one");
83+
// }
8984

9085
let mut eq_i = 0;
9186
let mut lt_i = 0;
@@ -96,25 +91,17 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
9691
let mut eq_checks = vec![None; 256];
9792
let mut multi_or_checks1 = BTreeMap::<String, usize>::new();
9893
let mut multi_or_checks2 = BTreeMap::<String, usize>::new();
99-
let mut zero_starting_states = vec![];
100-
let mut zero_starting_and_idxes = BTreeMap::<usize, Vec<usize>>::new();
10194

10295
let mut lines = vec![];
103-
// let mut zero_starting_lines = vec![];
10496

105-
lines.push("\tfor (var i = 0; i < num_bytes; i++) {".to_string());
106-
lines.push(format!("\t\tstate_changed[i] = MultiOR({});", n - 1));
107-
lines.push(format!("\t\tstates[i][0] <== 1;"));
97+
lines.push("\tfor (var i = 0; i < msg_bytes; i++) {".to_string());
98+
lines.push("\t\tstates[i+1][0] <== 0;".to_string());
99+
108100
for i in 1..n {
109101
let mut outputs = vec![];
110-
zero_starting_and_idxes.insert(i, vec![]);
111-
// let mut state_change_lines = vec![];
112102

113103
for (prev_i, k) in rev_graph.get(&(i as usize)).unwrap().iter() {
114104
let prev_i_num = *prev_i;
115-
if prev_i_num == 0 {
116-
zero_starting_states.push(i);
117-
}
118105
let mut k = k.clone();
119106
k.sort();
120107

@@ -196,6 +183,7 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
196183
}
197184
}
198185
}
186+
199187
lines.push(format!("\t\tand[{}][i] = AND();", and_i));
200188
lines.push(format!(
201189
"\t\tand[{}][i].a <== states[i][{}];",
@@ -207,9 +195,6 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
207195
"\t\tand[{}][i].b <== {}[{}][i].out;",
208196
and_i, eq_outputs[0].0, eq_outputs[0].1
209197
));
210-
if prev_i_num == 0 {
211-
zero_starting_and_idxes.get_mut(&i).unwrap().push(and_i);
212-
}
213198
} else if eq_outputs.len() > 1 {
214199
let eq_outputs_key = serde_json::to_string(&eq_outputs).unwrap();
215200

@@ -231,9 +216,6 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
231216
"\t\tand[{}][i].b <== multi_or[{}][i].out;",
232217
and_i, multi_or_i
233218
));
234-
if prev_i_num == 0 {
235-
zero_starting_and_idxes.get_mut(&i).unwrap().push(and_i);
236-
}
237219
multi_or_checks1.insert(eq_outputs_key, multi_or_i);
238220
multi_or_i += 1;
239221
} else {
@@ -242,29 +224,19 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
242224
"\t\tand[{}][i].b <== multi_or[{}][i].out;",
243225
and_i, multi_or_i
244226
));
245-
if prev_i_num == 0 {
246-
zero_starting_and_idxes.get_mut(&i).unwrap().push(and_i);
247-
}
248227
}
249228
}
250229
}
251-
if prev_i_num != 0 {
252-
outputs.push(and_i);
253-
}
230+
231+
outputs.push(and_i);
254232
and_i += 1;
255233
}
234+
256235
if outputs.len() == 1 {
257-
if zero_starting_states.contains(&i) {
258-
lines.push(format!(
259-
"\t\tstates_tmp[i+1][{}] <== and[{}][i].out;",
260-
i, outputs[0]
261-
));
262-
} else {
263-
lines.push(format!(
264-
"\t\tstates[i+1][{}] <== and[{}][i].out;",
265-
i, outputs[0]
266-
));
267-
}
236+
lines.push(format!(
237+
"\t\tstates[i+1][{}] <== and[{}][i].out;",
238+
i, outputs[0]
239+
));
268240
} else if outputs.len() > 1 {
269241
let outputs_key = serde_json::to_string(&outputs).unwrap();
270242

@@ -281,87 +253,34 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
281253
multi_or_i, output_i, and_i
282254
));
283255
}
284-
if zero_starting_states.contains(&i) {
285-
lines.push(format!(
286-
"\t\tstates_tmp[i+1][{}] <== multi_or[{}][i].out;",
287-
i, multi_or_i
288-
));
289-
} else {
290-
lines.push(format!(
291-
"\t\tstates[i+1][{}] <== multi_or[{}][i].out;",
292-
i, multi_or_i
293-
));
294-
}
256+
257+
lines.push(format!(
258+
"\t\tstates[i+1][{}] <== multi_or[{}][i].out;",
259+
i, multi_or_i
260+
));
295261
multi_or_checks2.insert(outputs_key, multi_or_i);
296262
multi_or_i += 1;
297263
} else {
298-
if let Some(multi_or_i) = multi_or_checks2.get(&outputs_key) {
299-
if zero_starting_states.contains(&i) {
300-
lines.push(format!(
301-
"\t\tstates_tmp[i+1][{}] <== multi_or[{}][i].out;",
302-
i, multi_or_i
303-
));
304-
} else {
305-
lines.push(format!(
306-
"\t\tstates[i+1][{}] <== multi_or[{}][i].out;",
307-
i, multi_or_i
308-
));
309-
}
264+
if let Some(multi_or_i_) = multi_or_checks2.get(&outputs_key) {
265+
lines.push(format!(
266+
"\t\tstates[i+1][{}] <== multi_or[{}][i].out;",
267+
i, multi_or_i_
268+
));
310269
}
311270
}
312-
} else {
313-
if zero_starting_states.contains(&i) {
314-
lines.push(format!("\t\tstates_tmp[i+1][{}] <== 0;", i));
315-
} else {
316-
lines.push(format!("\t\tstates[i+1][{}] <== 0;", i));
317-
}
318271
}
319-
320-
// if zero_starting_states.contains(&i) {
321-
// zero_starting_lines.append(&mut state_change_lines);
322-
// } else {
323-
// lines.append(&mut state_change_lines);
324-
// }
325272
}
326-
// let not_zero_starting_states = (1..n)
327-
// .filter(|i| !zero_starting_states.contains(&i))
328-
// .collect_vec();
329-
lines.push(format!(
330-
"\t\tfrom_zero_enabled[i] <== MultiNOR({})([{}]);",
331-
n - 1,
332-
(1..n)
333-
.map(|i| if zero_starting_states.contains(&i) {
334-
format!("states_tmp[i+1][{}]", i)
335-
} else {
336-
format!("states[i+1][{}]", i)
337-
})
338-
.collect::<Vec<_>>()
339-
.join(", ")
340-
));
341-
for (i, vec) in zero_starting_and_idxes.iter() {
342-
if vec.len() == 0 {
273+
274+
let mut acc_transitions_update = "\t\tacc_transitions[i+1] <== acc_transitions[i]".to_string();
275+
for i in 0..n {
276+
if i == 0 {
343277
continue;
344278
}
345-
lines.push(format!(
346-
"\t\tstates[i+1][{}] <== MultiOR({})([states_tmp[i+1][{}], {}]);",
347-
i,
348-
vec.len() + 1,
349-
i,
350-
vec.iter()
351-
.map(|and_i| format!("from_zero_enabled[i] * and[{}][i].out", and_i))
352-
.collect::<Vec<_>>()
353-
.join(", ")
354-
));
355-
}
356-
for i in 1..n {
357-
lines.push(format!(
358-
"\t\tstate_changed[i].in[{}] <== states[i+1][{}];",
359-
i - 1,
360-
i
361-
));
362-
}
363279

364-
// lines.push("\t\tstates[i+1][0] <== 1 - state_changed[i].out;".to_string());
280+
acc_transitions_update.push_str(&format!(" + states[i+1][{}]", i));
281+
}
282+
acc_transitions_update.push_str(";");
283+
lines.push(acc_transitions_update);
365284

366285
let mut declarations = vec![];
367286
declarations.push("pragma circom 2.1.5;\n".to_string());
@@ -374,40 +293,37 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
374293
declarations.push(format!("template {}(msg_bytes) {{", template_name));
375294
declarations.push("\tsignal input msg[msg_bytes];".to_string());
376295
declarations.push("\tsignal output out;\n".to_string());
377-
declarations.push("\tvar num_bytes = msg_bytes+1;".to_string());
378-
declarations.push("\tsignal in[num_bytes];".to_string());
379-
declarations.push("\tin[0]<==255;".to_string());
296+
declarations.push("\tvar num_state_trace = msg_bytes+1;".to_string());
297+
declarations.push("\tsignal in[msg_bytes];".to_string());
380298
declarations.push("\tfor (var i = 0; i < msg_bytes; i++) {".to_string());
381-
declarations.push("\t\tin[i+1] <== msg[i];".to_string());
299+
declarations.push("\t\tin[i] <== msg[i];".to_string());
382300
declarations.push("\t}\n".to_string());
383301

384302
if eq_i > 0 {
385-
declarations.push(format!("\tcomponent eq[{}][num_bytes];", eq_i));
303+
declarations.push(format!("\tcomponent eq[{}][msg_bytes];", eq_i));
386304
}
387305

388306
if lt_i > 0 {
389-
declarations.push(format!("\tcomponent lt[{}][num_bytes];", lt_i));
307+
declarations.push(format!("\tcomponent lt[{}][msg_bytes];", lt_i));
390308
}
391309

392310
if and_i > 0 {
393-
declarations.push(format!("\tcomponent and[{}][num_bytes];", and_i));
311+
declarations.push(format!("\tcomponent and[{}][msg_bytes];", and_i));
394312
}
395313

396314
if multi_or_i > 0 {
397-
declarations.push(format!("\tcomponent multi_or[{}][num_bytes];", multi_or_i));
315+
declarations.push(format!("\tcomponent multi_or[{}][msg_bytes];", multi_or_i));
398316
}
399317

400-
declarations.push(format!("\tsignal states[num_bytes+1][{}];", n));
401-
declarations.push(format!("\tsignal states_tmp[num_bytes+1][{}];", n));
402-
declarations.push(format!("\tsignal from_zero_enabled[num_bytes+1];"));
403-
declarations.push(format!("\tfrom_zero_enabled[num_bytes] <== 0;"));
404-
declarations.push("\tcomponent state_changed[num_bytes];\n".to_string());
318+
declarations.push(format!("\tsignal states[num_state_trace][{}];", n));
319+
declarations.push(format!("\tsignal acc_transitions[num_state_trace];\n"));
405320

406321
let mut init_code = vec![];
407-
// init_code.push("\tstates[0][0] <== 1;".to_string());
322+
init_code.push("\tstates[0][0] <== 1;".to_string());
408323
init_code.push(format!("\tfor (var i = 1; i < {}; i++) {{", n));
409324
init_code.push("\t\tstates[0][i] <== 0;".to_string());
410-
init_code.push("\t}\n".to_string());
325+
init_code.push("\t}".to_string());
326+
init_code.push("\tacc_transitions[0] <== 0;\n".to_string());
411327

412328
let mut final_code = declarations
413329
.into_iter()
@@ -416,18 +332,37 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
416332
.collect::<Vec<String>>();
417333
final_code.push("\t}".to_string());
418334

419-
let accept_node = accept_nodes_array[0];
420335
let mut accept_lines = vec![];
421336

422337
accept_lines.push("".to_string());
423-
accept_lines.push("\tcomponent final_state_result = MultiOR(num_bytes+1);".to_string());
424-
accept_lines.push("\tfor (var i = 0; i <= num_bytes; i++) {".to_string());
425-
accept_lines.push(format!(
426-
"\t\tfinal_state_result.in[i] <== states[i][{}];",
427-
accept_node
428-
));
338+
accept_lines.push("\tcomponent final_state_result = MultiOR(msg_bytes+1);".to_string());
339+
accept_lines.push("\tfor (var i = 0; i <= msg_bytes; i++) {".to_string());
340+
if accept_nodes_array.len() == 1 {
341+
accept_lines.push(format!(
342+
"\t\tfinal_state_result.in[i] <== states[i][{}];",
343+
accept_nodes_array[0]
344+
));
345+
} else {
346+
let mut accept_outputs = vec![];
347+
let mut accept_outputs_str = String::new();
348+
let mut accept_outputs_str = format!("MultiOR({})([", accept_nodes_array.len());
349+
for accept_node in &accept_nodes_array {
350+
accept_outputs.push(format!("states[i][{}]", accept_node));
351+
}
352+
accept_outputs_str.push_str(&accept_outputs.join(", "));
353+
accept_outputs_str.push_str("])");
354+
accept_lines.push(format!(
355+
"\t\tfinal_state_result.in[i] <== {};",
356+
accept_outputs_str
357+
));
358+
}
429359
accept_lines.push("\t}".to_string());
430-
accept_lines.push("\tout <== final_state_result.out;".to_string());
360+
accept_lines.push(
361+
"\tsignal is_acc_valid <== IsEqual()([acc_transitions[num_state_trace-1], msg_bytes]);"
362+
.to_string(),
363+
);
364+
accept_lines.push("\tout <== final_state_result.out * is_acc_valid;".to_string());
365+
accept_lines.push("}".to_string());
431366

432367
final_code.extend(accept_lines);
433368

packages/compiler/src/regex.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,16 @@ fn parse_dfa_output(output: &str) -> DFAGraphInfo {
5656
eoi_pointing_states.insert(eoi_target);
5757
state.typ = String::from("accept");
5858
state.edges.remove("EOI");
59+
// Set the dst of all edges pointing to eoi_target to this state
60+
for edge in &mut state.edges {
61+
if *edge.1 == eoi_target {
62+
*edge.1 = state.source;
63+
}
64+
}
5965
}
6066
}
6167

62-
let start_state_re = Regex::new(r"START-GROUP\(anchored\)[\s*\w*\=>]*Text => (\d+)").unwrap();
68+
let start_state_re = Regex::new(r"START-GROUP\(unanchored\)[\s*\w*\=>]*Text => (\d+)").unwrap();
6369
let start_state = start_state_re.captures_iter(output).next().unwrap()[1]
6470
.parse::<usize>()
6571
.unwrap();
@@ -251,17 +257,17 @@ fn add_dfa(net_dfa: &DFAGraph, graph: &DFAGraph) -> DFAGraph {
251257

252258
pub fn regex_and_dfa(decomposed_regex: &DecomposedRegexConfig) -> RegexAndDFA {
253259
let mut config = DFA::config().minimize(true);
254-
config = config.start_kind(StartKind::Anchored);
260+
// config = config.start_kind(StartKind::Unanchored);
255261
config = config.byte_classes(false);
256-
config = config.accelerate(true);
262+
// config = config.accelerate(true);
257263

258264
let mut net_dfa = DFAGraph { states: Vec::new() };
259265
let mut substr_defs_array = Vec::new();
260266

261267
for regex in decomposed_regex.parts.iter() {
262268
let re = DFA::builder()
263269
.configure(config.clone())
264-
.build(&format!(r"^{}$", regex.regex_def))
270+
.build(&format!(r"{}", regex.regex_def))
265271
.unwrap();
266272
let re_str = format!("{:?}", re);
267273
let mut graph = dfa_to_graph(&parse_dfa_output(&re_str));

0 commit comments

Comments
 (0)