Skip to content

Commit

Permalink
Merge branch 'feat/one-to-one' of https://github.com/zkemail/zk-regex into feat/one-to-one
Browse files (browse the repository at this point in the history)
  • Loading branch information
SoraSuegami committed May 12, 2024
2 parents 02a3797 + a954f28 commit 40778fd
Showing 1 changed file with 42 additions and 81 deletions.
123 changes: 42 additions & 81 deletions packages/compiler/src/circom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@ use std::path::PathBuf;
fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str) -> String {
let n = dfa_graph.states.len();
let mut rev_graph = BTreeMap::<usize, BTreeMap<usize, Vec<u8>>>::new();
let mut to_init_graph = vec![];
let mut init_going_state: Option<usize> = None;

for i in 0..n {
rev_graph.insert(i, BTreeMap::new());
to_init_graph.push(vec![]);
}

let mut accept_nodes = BTreeSet::<usize>::new();
Expand All @@ -29,49 +26,13 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
for (k, v) in &node.edges {
let chars: Vec<u8> = v.iter().cloned().collect();
rev_graph.get_mut(k).unwrap().insert(i, chars.clone());

if i == 0 {
for j in rev_graph.get(&k).unwrap().get(&i).unwrap() {
if *j == 255 {
continue;
}
to_init_graph[*k].push(*j);
}
}
}

if node.r#type == "accept" {
accept_nodes.insert(i);
}
}

if let Some(init_going_state) = init_going_state {
for (going_state, chars) in to_init_graph.iter().enumerate() {
if chars.is_empty() {
continue;
}

if rev_graph
.get_mut(&(going_state as usize))
.unwrap()
.get_mut(&init_going_state)
.is_none()
{
rev_graph
.get_mut(&(going_state as usize))
.unwrap()
.insert(init_going_state, vec![]);
}

rev_graph
.get_mut(&(going_state as usize))
.unwrap()
.get_mut(&init_going_state)
.unwrap()
.extend_from_slice(chars);
}
}

if accept_nodes.is_empty() {
panic!("Accept node must exist");
}
Expand All @@ -95,9 +56,21 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
let mut lines = vec![];

lines.push("\tfor (var i = 0; i < msg_bytes; i++) {".to_string());
lines.push("\t\tstates[i+1][0] <== 0;".to_string());
// Check if there is any transition to the 0th state from any state
let mut transition_to_0 = false;
for state in dfa_graph.states.iter() {
if let Some(transitions) = state.edges.get(&0) {
if !transitions.is_empty() {
transition_to_0 = true;
break;
}
}
}
if !transition_to_0 {
lines.push("\t\tstates[i+1][0] <== 0;".to_string());
}

for i in 1..n {
for i in 0..n {
let mut outputs = vec![];

for (prev_i, k) in rev_graph.get(&(i as usize)).unwrap().iter() {
Expand Down Expand Up @@ -273,10 +246,6 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)

let mut acc_transitions_update = "\t\tacc_transitions[i+1] <== acc_transitions[i]".to_string();
for i in 0..n {
if i == 0 {
continue;
}

acc_transitions_update.push_str(&format!(" + states[i+1][{}]", i));
}
acc_transitions_update.push_str(";");
Expand Down Expand Up @@ -335,34 +304,26 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str)
let mut accept_lines = vec![];

accept_lines.push("".to_string());
accept_lines.push("\tcomponent final_state_result = MultiOR(msg_bytes+1);".to_string());
accept_lines.push("\tfor (var i = 0; i <= msg_bytes; i++) {".to_string());
if accept_nodes_array.len() == 1 {
accept_lines.push(format!(
"\t\tfinal_state_result.in[i] <== states[i][{}];",
accept_nodes_array[0]
));
} else {
let mut accept_outputs = vec![];
let mut accept_outputs_str = String::new();
let mut accept_outputs_str = format!("MultiOR({})([", accept_nodes_array.len());
for accept_node in &accept_nodes_array {
accept_outputs.push(format!("states[i][{}]", accept_node));
}
accept_outputs_str.push_str(&accept_outputs.join(", "));
accept_outputs_str.push_str("])");
accept_lines.push(format!(
"\t\tfinal_state_result.in[i] <== {};",
accept_outputs_str
));
let mut final_state_result = String::new();
final_state_result.push_str(
format!(
"\tsignal final_state_result <== MultiOR({})([",
accept_nodes_array.len()
)
.as_str(),
);
let mut accept_outputs = vec![];
for accept_node in &accept_nodes_array {
accept_outputs.push(format!("states[msg_bytes][{}]", accept_node));
}
accept_lines.push("\t}".to_string());
final_state_result.push_str(&accept_outputs.join(", "));
final_state_result.push_str("]);");
accept_lines.push(final_state_result);
accept_lines.push(
"\tsignal is_acc_valid <== IsEqual()([acc_transitions[num_state_trace-1], msg_bytes]);"
.to_string(),
);
accept_lines.push("\tout <== final_state_result.out * is_acc_valid;".to_string());
accept_lines.push("}".to_string());
accept_lines.push("\tout <== final_state_result * is_acc_valid;".to_string());

final_code.extend(accept_lines);

Expand Down Expand Up @@ -398,13 +359,13 @@ impl RegexAndDFA {
let accepted_state = get_accepted_state(&self.dfa_val).unwrap();
let mut circom: String = "".to_string();
circom += "\n";
circom += "\tsignal is_consecutive[msg_bytes+1][3];\n";
circom += "\tis_consecutive[msg_bytes][2] <== 1;\n";
circom += "\tfor (var i = 0; i < msg_bytes; i++) {\n";
circom += &format!("\t\tis_consecutive[msg_bytes-1-i][0] <== states[num_bytes-i][{}] * (1 - is_consecutive[msg_bytes-i][2]) + is_consecutive[msg_bytes-i][2];\n", accepted_state);
circom += "\t\tis_consecutive[msg_bytes-1-i][1] <== state_changed[msg_bytes-i].out * is_consecutive[msg_bytes-1-i][0];\n";
circom += &format!("\t\tis_consecutive[msg_bytes-1-i][2] <== ORAnd()([(1 - from_zero_enabled[msg_bytes-i+1]), states[num_bytes-i][{}], is_consecutive[msg_bytes-1-i][1]]);\n", accepted_state);
circom += "\t}\n";
// circom += "\tsignal is_consecutive[msg_bytes+1][3];\n";
// circom += "\tis_consecutive[msg_bytes][2] <== 1;\n";
// circom += "\tfor (var i = 0; i < msg_bytes; i++) {\n";
// circom += &format!("\t\tis_consecutive[msg_bytes-1-i][0] <== states[num_bytes-i][{}] * (1 - is_consecutive[msg_bytes-i][2]) + is_consecutive[msg_bytes-i][2];\n", accepted_state);
// circom += "\t\tis_consecutive[msg_bytes-1-i][1] <== state_changed[msg_bytes-i].out * is_consecutive[msg_bytes-1-i][0];\n";
// circom += &format!("\t\tis_consecutive[msg_bytes-1-i][2] <== ORAnd()([(1 - from_zero_enabled[msg_bytes-i+1]), states[num_bytes-i][{}], is_consecutive[msg_bytes-1-i][1]]);\n", accepted_state);
// circom += "\t}\n";

let substr_defs_array = &self.substrs_defs.substr_defs_array;
circom += &format!(
Expand All @@ -414,7 +375,7 @@ impl RegexAndDFA {
for (idx, defs) in substr_defs_array.into_iter().enumerate() {
let num_defs = defs.len();
circom += &format!("\tsignal is_substr{}[msg_bytes];\n", idx);
circom += &format!("\tsignal is_reveal{}[msg_bytes];\n", idx);
// circom += &format!("\tsignal is_reveal{}[msg_bytes];\n", idx);
circom += &format!("\tsignal output reveal{}[msg_bytes];\n", idx);
circom += "\tfor (var i = 0; i < msg_bytes; i++) {\n";
// circom += &format!("\t\tis_substr{}[i][0] <== 0;\n", idx);
Expand Down Expand Up @@ -453,11 +414,11 @@ impl RegexAndDFA {
// // circom += ";\n";
// // }
// }
circom += &format!(
"\t\tis_reveal{}[i] <== is_substr{}[i] * is_consecutive[i][2];\n",
idx, idx
);
circom += &format!("\t\treveal{}[i] <== in[i+1] * is_reveal{}[i];\n", idx, idx);
// circom += &format!(
// "\t\tis_reveal{}[i] <== is_substr{}[i] * is_consecutive[i][2];\n",
// idx, idx
// );
circom += &format!("\t\treveal{}[i] <== in[i+1] * is_substr{}[i];\n", idx, idx);
circom += "\t}\n";
}
circom += "}";
Expand Down

0 comments on commit 40778fd

Please sign in to comment.