Skip to content

Commit

Permalink
Implement const array sizes in fn params
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed May 28, 2024
1 parent b388e38 commit 147582c
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 28 deletions.
67 changes: 41 additions & 26 deletions src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ impl TypedProgram {
let mut const_sizes = HashMap::new();
let mut consts_unsigned = HashMap::new();
let mut consts_signed = HashMap::new();

for (party, deps) in self.const_deps.iter() {
for (c, ty) in deps {
let Some(party_deps) = consts.get(party) else {
Expand All @@ -86,38 +87,12 @@ impl TypedProgram {
_ => {}
}
if literal.is_of_type(self, ty) {
let bits = literal
.as_bits(self, &const_sizes)
.iter()
.map(|b| *b as usize)
.collect();
env.let_in_current_scope(identifier.clone(), bits);
if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal {
const_sizes.insert(identifier, *size as usize);
}
} else {
return Err(CompilerError::InvalidLiteralType(
literal.clone(),
ty.clone(),
));
}
}
}
let mut input_gates = vec![];
let mut wire = 2;
let Some(fn_def) = self.fn_defs.get(fn_name) else {
return Err(CompilerError::FnNotFound(fn_name.to_string()));
};
for param in fn_def.params.iter() {
let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes);
let mut wires = Vec::with_capacity(type_size);
for _ in 0..type_size {
wires.push(wire);
wire += 1;
}
input_gates.push(type_size);
env.let_in_current_scope(param.name.clone(), wires);
}
fn resolve_const_expr_unsigned(
expr: &ConstExpr,
consts_unsigned: &HashMap<String, u64>,
Expand Down Expand Up @@ -180,6 +155,46 @@ impl TypedProgram {
const_sizes.insert(const_name.clone(), n as usize);
}
}

for (party, deps) in self.const_deps.iter() {
for (c, ty) in deps {
let Some(party_deps) = consts.get(party) else {
return Err(CompilerError::MissingConstant(party.clone(), c.clone()));
};
let Some(literal) = party_deps.get(c) else {
return Err(CompilerError::MissingConstant(party.clone(), c.clone()));
};
let identifier = format!("{party}::{c}");
if literal.is_of_type(self, ty) {
let bits = literal
.as_bits(self, &const_sizes)
.iter()
.map(|b| *b as usize)
.collect();
env.let_in_current_scope(identifier.clone(), bits);
} else {
return Err(CompilerError::InvalidLiteralType(
literal.clone(),
ty.clone(),
));
}
}
}
let mut input_gates = vec![];
let mut wire = 2;
let Some(fn_def) = self.fn_defs.get(fn_name) else {
return Err(CompilerError::FnNotFound(fn_name.to_string()));
};
for param in fn_def.params.iter() {
let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes);
let mut wires = Vec::with_capacity(type_size);
for _ in 0..type_size {
wires.push(wire);
wire += 1;
}
input_gates.push(type_size);
env.let_in_current_scope(param.name.clone(), wires);
}
let mut circuit = CircuitBuilder::new(input_gates, const_sizes.clone());
for (const_name, const_def) in self.const_defs.iter() {
match &const_def.value {
Expand Down
32 changes: 30 additions & 2 deletions src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ impl<'a> Evaluator<'a> {
pub fn set_literal(&mut self, literal: Literal) -> Result<(), EvalError> {
if self.inputs.len() < self.main_fn.params.len() {
let ty = &self.main_fn.params[self.inputs.len()].ty;
if literal.is_of_type(self.program, ty) {
let ty = resolve_const_type(&ty, self.const_sizes);
if literal.is_of_type(self.program, &ty) {
self.inputs.push(vec![]);
self.inputs
.last_mut()
Expand All @@ -210,8 +211,9 @@ impl<'a> Evaluator<'a> {
pub fn parse_literal(&mut self, literal: &str) -> Result<(), EvalError> {
if self.inputs.len() < self.main_fn.params.len() {
let ty = &self.main_fn.params[self.inputs.len()].ty;
let ty = resolve_const_type(&ty, self.const_sizes);
let parsed =
Literal::parse(self.program, ty, literal).map_err(EvalError::LiteralParseError)?;
Literal::parse(self.program, &ty, literal).map_err(EvalError::LiteralParseError)?;
self.set_literal(parsed)?;
Ok(())
} else {
Expand All @@ -220,6 +222,32 @@ impl<'a> Evaluator<'a> {
}
}

fn resolve_const_type(ty: &Type, const_sizes: &HashMap<String, usize>) -> Type {
match ty {
Type::Fn(params, ret_ty) => Type::Fn(
params
.iter()
.map(|ty| resolve_const_type(ty, const_sizes))
.collect(),
Box::new(resolve_const_type(ret_ty, const_sizes)),
),
Type::Array(elem_ty, size) => {
Type::Array(Box::new(resolve_const_type(elem_ty, const_sizes)), *size)
}
Type::ArrayConst(elem_ty, size) => Type::Array(
Box::new(resolve_const_type(elem_ty, const_sizes)),
*const_sizes.get(size).unwrap(),
),
Type::Tuple(elems) => Type::Tuple(
elems
.iter()
.map(|ty| resolve_const_type(ty, const_sizes))
.collect(),
),
ty => ty.clone(),
}
}

/// The encoded result of a circuit evaluation.
#[derive(Debug, Clone)]
pub struct EvalOutput<'a> {
Expand Down
32 changes: 32 additions & 0 deletions tests/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1983,3 +1983,35 @@ pub fn main(x: u16) -> u16 {
);
Ok(())
}

#[test]
fn compile_const_size_in_fn_param() -> Result<(), Error> {
let prg = "
const MY_CONST: usize = max(PARTY_0::MY_CONST, PARTY_1::MY_CONST);
pub fn main(array: [u16; MY_CONST]) -> u16 {
array[1]
}
";
let consts = HashMap::from_iter(vec![
(
"PARTY_0".to_string(),
HashMap::from_iter(vec![(
"MY_CONST".to_string(),
Literal::NumUnsigned(1, UnsignedNumType::Usize),
)]),
),
(
"PARTY_1".to_string(),
HashMap::from_iter(vec![(
"MY_CONST".to_string(),
Literal::NumUnsigned(2, UnsignedNumType::Usize),
)]),
),
]);
let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?;
let mut eval = compiled.evaluator();
eval.parse_literal("[7u16, 8u16]").unwrap();
let output = eval.run().map_err(|e| pretty_print(e, prg))?;
assert_eq!(u16::try_from(output).map_err(|e| pretty_print(e, prg))?, 8);
Ok(())
}

0 comments on commit 147582c

Please sign in to comment.