Skip to content

Commit

Permalink
Merge branch 'feat/one-to-one' of https://github.com/zkemail/zk-regex
Browse files Browse the repository at this point in the history
…into feat/one-to-one
  • Loading branch information
SoraSuegami committed May 12, 2024
2 parents cff12e7 + 650d6de commit dddf829
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 140 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,4 @@ packages/*/build
package-lock.json
yarn.lock

index.node
207 changes: 71 additions & 136 deletions packages/compiler/src/circom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
rev_graph.get_mut(k).unwrap().insert(i, chars.clone());

if i == 0 {
if let Some(index) = chars.iter().position(|&x| x == 94) {
init_going_state = Some(*k);
rev_graph.get_mut(&k).unwrap().get_mut(&i).unwrap()[index] = 255;
}

for j in rev_graph.get(&k).unwrap().get(&i).unwrap() {
if *j == 255 {
continue;
Expand Down Expand Up @@ -83,9 +78,9 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)

let accept_nodes_array: Vec<usize> = accept_nodes.into_iter().collect();

if accept_nodes_array.len() != 1 {
panic!("The size of accept nodes must be one");
}
// if accept_nodes_array.len() != 1 {
// panic!("The size of accept nodes must be one");
// }

let mut eq_i = 0;
let mut lt_i = 0;
Expand All @@ -96,25 +91,17 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
let mut eq_checks = vec![None; 256];
let mut multi_or_checks1 = BTreeMap::<String, usize>::new();
let mut multi_or_checks2 = BTreeMap::<String, usize>::new();
let mut zero_starting_states = vec![];
let mut zero_starting_and_idxes = BTreeMap::<usize, Vec<usize>>::new();

let mut lines = vec![];
// let mut zero_starting_lines = vec![];

lines.push("\tfor (var i = 0; i < num_bytes; i++) {".to_string());
lines.push(format!("\t\tstate_changed[i] = MultiOR({});", n - 1));
lines.push(format!("\t\tstates[i][0] <== 1;"));
lines.push("\tfor (var i = 0; i < msg_bytes; i++) {".to_string());
lines.push("\t\tstates[i+1][0] <== 0;".to_string());

for i in 1..n {
let mut outputs = vec![];
zero_starting_and_idxes.insert(i, vec![]);
// let mut state_change_lines = vec![];

for (prev_i, k) in rev_graph.get(&(i as usize)).unwrap().iter() {
let prev_i_num = *prev_i;
if prev_i_num == 0 {
zero_starting_states.push(i);
}
let mut k = k.clone();
k.sort();

Expand Down Expand Up @@ -196,6 +183,7 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
}
}
}

lines.push(format!("\t\tand[{}][i] = AND();", and_i));
lines.push(format!(
"\t\tand[{}][i].a <== states[i][{}];",
Expand All @@ -207,9 +195,6 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
"\t\tand[{}][i].b <== {}[{}][i].out;",
and_i, eq_outputs[0].0, eq_outputs[0].1
));
if prev_i_num == 0 {
zero_starting_and_idxes.get_mut(&i).unwrap().push(and_i);
}
} else if eq_outputs.len() > 1 {
let eq_outputs_key = serde_json::to_string(&eq_outputs).unwrap();

Expand All @@ -231,9 +216,6 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
"\t\tand[{}][i].b <== multi_or[{}][i].out;",
and_i, multi_or_i
));
if prev_i_num == 0 {
zero_starting_and_idxes.get_mut(&i).unwrap().push(and_i);
}
multi_or_checks1.insert(eq_outputs_key, multi_or_i);
multi_or_i += 1;
} else {
Expand All @@ -242,29 +224,19 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
"\t\tand[{}][i].b <== multi_or[{}][i].out;",
and_i, multi_or_i
));
if prev_i_num == 0 {
zero_starting_and_idxes.get_mut(&i).unwrap().push(and_i);
}
}
}
}
if prev_i_num != 0 {
outputs.push(and_i);
}

outputs.push(and_i);
and_i += 1;
}

if outputs.len() == 1 {
if zero_starting_states.contains(&i) {
lines.push(format!(
"\t\tstates_tmp[i+1][{}] <== and[{}][i].out;",
i, outputs[0]
));
} else {
lines.push(format!(
"\t\tstates[i+1][{}] <== and[{}][i].out;",
i, outputs[0]
));
}
lines.push(format!(
"\t\tstates[i+1][{}] <== and[{}][i].out;",
i, outputs[0]
));
} else if outputs.len() > 1 {
let outputs_key = serde_json::to_string(&outputs).unwrap();

Expand All @@ -281,87 +253,34 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
multi_or_i, output_i, and_i
));
}
if zero_starting_states.contains(&i) {
lines.push(format!(
"\t\tstates_tmp[i+1][{}] <== multi_or[{}][i].out;",
i, multi_or_i
));
} else {
lines.push(format!(
"\t\tstates[i+1][{}] <== multi_or[{}][i].out;",
i, multi_or_i
));
}

lines.push(format!(
"\t\tstates[i+1][{}] <== multi_or[{}][i].out;",
i, multi_or_i
));
multi_or_checks2.insert(outputs_key, multi_or_i);
multi_or_i += 1;
} else {
if let Some(multi_or_i) = multi_or_checks2.get(&outputs_key) {
if zero_starting_states.contains(&i) {
lines.push(format!(
"\t\tstates_tmp[i+1][{}] <== multi_or[{}][i].out;",
i, multi_or_i
));
} else {
lines.push(format!(
"\t\tstates[i+1][{}] <== multi_or[{}][i].out;",
i, multi_or_i
));
}
if let Some(multi_or_i_) = multi_or_checks2.get(&outputs_key) {
lines.push(format!(
"\t\tstates[i+1][{}] <== multi_or[{}][i].out;",
i, multi_or_i_
));
}
}
} else {
if zero_starting_states.contains(&i) {
lines.push(format!("\t\tstates_tmp[i+1][{}] <== 0;", i));
} else {
lines.push(format!("\t\tstates[i+1][{}] <== 0;", i));
}
}

// if zero_starting_states.contains(&i) {
// zero_starting_lines.append(&mut state_change_lines);
// } else {
// lines.append(&mut state_change_lines);
// }
}
// let not_zero_starting_states = (1..n)
// .filter(|i| !zero_starting_states.contains(&i))
// .collect_vec();
lines.push(format!(
"\t\tfrom_zero_enabled[i] <== MultiNOR({})([{}]);",
n - 1,
(1..n)
.map(|i| if zero_starting_states.contains(&i) {
format!("states_tmp[i+1][{}]", i)
} else {
format!("states[i+1][{}]", i)
})
.collect::<Vec<_>>()
.join(", ")
));
for (i, vec) in zero_starting_and_idxes.iter() {
if vec.len() == 0 {

let mut acc_transitions_update = "\t\tacc_transitions[i+1] <== acc_transitions[i]".to_string();
for i in 0..n {
if i == 0 {
continue;
}
lines.push(format!(
"\t\tstates[i+1][{}] <== MultiOR({})([states_tmp[i+1][{}], {}]);",
i,
vec.len() + 1,
i,
vec.iter()
.map(|and_i| format!("from_zero_enabled[i] * and[{}][i].out", and_i))
.collect::<Vec<_>>()
.join(", ")
));
}
for i in 1..n {
lines.push(format!(
"\t\tstate_changed[i].in[{}] <== states[i+1][{}];",
i - 1,
i
));
}

// lines.push("\t\tstates[i+1][0] <== 1 - state_changed[i].out;".to_string());
acc_transitions_update.push_str(&format!(" + states[i+1][{}]", i));
}
acc_transitions_update.push_str(";");
lines.push(acc_transitions_update);

let mut declarations = vec![];
declarations.push("pragma circom 2.1.5;\n".to_string());
Expand All @@ -374,40 +293,37 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
declarations.push(format!("template {}(msg_bytes) {{", template_name));
declarations.push("\tsignal input msg[msg_bytes];".to_string());
declarations.push("\tsignal output out;\n".to_string());
declarations.push("\tvar num_bytes = msg_bytes+1;".to_string());
declarations.push("\tsignal in[num_bytes];".to_string());
declarations.push("\tin[0]<==255;".to_string());
declarations.push("\tvar num_state_trace = msg_bytes+1;".to_string());
declarations.push("\tsignal in[msg_bytes];".to_string());
declarations.push("\tfor (var i = 0; i < msg_bytes; i++) {".to_string());
declarations.push("\t\tin[i+1] <== msg[i];".to_string());
declarations.push("\t\tin[i] <== msg[i];".to_string());
declarations.push("\t}\n".to_string());

if eq_i > 0 {
declarations.push(format!("\tcomponent eq[{}][num_bytes];", eq_i));
declarations.push(format!("\tcomponent eq[{}][msg_bytes];", eq_i));
}

if lt_i > 0 {
declarations.push(format!("\tcomponent lt[{}][num_bytes];", lt_i));
declarations.push(format!("\tcomponent lt[{}][msg_bytes];", lt_i));
}

if and_i > 0 {
declarations.push(format!("\tcomponent and[{}][num_bytes];", and_i));
declarations.push(format!("\tcomponent and[{}][msg_bytes];", and_i));
}

if multi_or_i > 0 {
declarations.push(format!("\tcomponent multi_or[{}][num_bytes];", multi_or_i));
declarations.push(format!("\tcomponent multi_or[{}][msg_bytes];", multi_or_i));
}

declarations.push(format!("\tsignal states[num_bytes+1][{}];", n));
declarations.push(format!("\tsignal states_tmp[num_bytes+1][{}];", n));
declarations.push(format!("\tsignal from_zero_enabled[num_bytes+1];"));
declarations.push(format!("\tfrom_zero_enabled[num_bytes] <== 0;"));
declarations.push("\tcomponent state_changed[num_bytes];\n".to_string());
declarations.push(format!("\tsignal states[num_state_trace][{}];", n));
declarations.push(format!("\tsignal acc_transitions[num_state_trace];\n"));

let mut init_code = vec![];
// init_code.push("\tstates[0][0] <== 1;".to_string());
init_code.push("\tstates[0][0] <== 1;".to_string());
init_code.push(format!("\tfor (var i = 1; i < {}; i++) {{", n));
init_code.push("\t\tstates[0][i] <== 0;".to_string());
init_code.push("\t}\n".to_string());
init_code.push("\t}".to_string());
init_code.push("\tacc_transitions[0] <== 0;\n".to_string());

let mut final_code = declarations
.into_iter()
Expand All @@ -416,18 +332,37 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
.collect::<Vec<String>>();
final_code.push("\t}".to_string());

let accept_node = accept_nodes_array[0];
let mut accept_lines = vec![];

accept_lines.push("".to_string());
accept_lines.push("\tcomponent final_state_result = MultiOR(num_bytes+1);".to_string());
accept_lines.push("\tfor (var i = 0; i <= num_bytes; i++) {".to_string());
accept_lines.push(format!(
"\t\tfinal_state_result.in[i] <== states[i][{}];",
accept_node
));
accept_lines.push("\tcomponent final_state_result = MultiOR(msg_bytes+1);".to_string());
accept_lines.push("\tfor (var i = 0; i <= msg_bytes; i++) {".to_string());
if accept_nodes_array.len() == 1 {
accept_lines.push(format!(
"\t\tfinal_state_result.in[i] <== states[i][{}];",
accept_nodes_array[0]
));
} else {
let mut accept_outputs = vec![];
let mut accept_outputs_str = String::new();
let mut accept_outputs_str = format!("MultiOR({})([", accept_nodes_array.len());
for accept_node in &accept_nodes_array {
accept_outputs.push(format!("states[i][{}]", accept_node));
}
accept_outputs_str.push_str(&accept_outputs.join(", "));
accept_outputs_str.push_str("])");
accept_lines.push(format!(
"\t\tfinal_state_result.in[i] <== {};",
accept_outputs_str
));
}
accept_lines.push("\t}".to_string());
accept_lines.push("\tout <== final_state_result.out;".to_string());
accept_lines.push(
"\tsignal is_acc_valid <== IsEqual()([acc_transitions[num_state_trace-1], msg_bytes]);"
.to_string(),
);
accept_lines.push("\tout <== final_state_result.out * is_acc_valid;".to_string());
accept_lines.push("}".to_string());

final_code.extend(accept_lines);

Expand Down
14 changes: 10 additions & 4 deletions packages/compiler/src/regex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,16 @@ fn parse_dfa_output(output: &str) -> DFAGraphInfo {
eoi_pointing_states.insert(eoi_target);
state.typ = String::from("accept");
state.edges.remove("EOI");
// Set the dst of all edges pointing to eoi_target to this state
for edge in &mut state.edges {
if *edge.1 == eoi_target {
*edge.1 = state.source;
}
}
}
}

let start_state_re = Regex::new(r"START-GROUP\(anchored\)[\s*\w*\=>]*Text => (\d+)").unwrap();
let start_state_re = Regex::new(r"START-GROUP\(unanchored\)[\s*\w*\=>]*Text => (\d+)").unwrap();
let start_state = start_state_re.captures_iter(output).next().unwrap()[1]
.parse::<usize>()
.unwrap();
Expand Down Expand Up @@ -251,17 +257,17 @@ fn add_dfa(net_dfa: &DFAGraph, graph: &DFAGraph) -> DFAGraph {

pub fn regex_and_dfa(decomposed_regex: &DecomposedRegexConfig) -> RegexAndDFA {
let mut config = DFA::config().minimize(true);
config = config.start_kind(StartKind::Anchored);
// config = config.start_kind(StartKind::Unanchored);
config = config.byte_classes(false);
config = config.accelerate(true);
// config = config.accelerate(true);

let mut net_dfa = DFAGraph { states: Vec::new() };
let mut substr_defs_array = Vec::new();

for regex in decomposed_regex.parts.iter() {
let re = DFA::builder()
.configure(config.clone())
.build(&format!(r"^{}$", regex.regex_def))
.build(&format!(r"{}", regex.regex_def))
.unwrap();
let re_str = format!("{:?}", re);
let mut graph = dfa_to_graph(&parse_dfa_output(&re_str));
Expand Down

0 comments on commit dddf829

Please sign in to comment.