diff --git a/.github/workflows/snarkjs.sh b/.github/workflows/snarkjs.sh index 7031f163b..a3bc6ddf1 100755 --- a/.github/workflows/snarkjs.sh +++ b/.github/workflows/snarkjs.sh @@ -9,6 +9,10 @@ fi DIR_PATH=$1 CURVE=$2 +# Init stdlib in .noname/release/src/stdlib instead of downloading +echo "Overriding stdlib in .noname/release/src/stdlib..." +mkdir -p ~/.noname/release/src/stdlib/ && cp -r /app/noname/src/stdlib/* ~/.noname/release/src/stdlib/ + # Ensure the circuit directory exists and is initialized echo "Initializing a new Noname package..." noname new --path circuit_noname diff --git a/examples/fixture/asm/kimchi/generic_builtin_bits.asm b/examples/fixture/asm/kimchi/generic_builtin_bits.asm index 671fcfc5e..87ceb2cce 100644 --- a/examples/fixture/asm/kimchi/generic_builtin_bits.asm +++ b/examples/fixture/asm/kimchi/generic_builtin_bits.asm @@ -2,70 +2,118 @@ @ public inputs: 1 DoubleGeneric<1> -DoubleGeneric<1,0,-1,0,-1> +DoubleGeneric<1,0,0,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,1,-1> DoubleGeneric<0,0,-1,1> -DoubleGeneric<1> -DoubleGeneric<1,0,-1> -DoubleGeneric<1,0,-1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,-1> DoubleGeneric<0,0,-1,1> DoubleGeneric<1> -DoubleGeneric<2,0,-1> +DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,1> DoubleGeneric<1,1,-1> -DoubleGeneric<1,0,-1,0,-1> +DoubleGeneric<0,0,-1,1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,-1> DoubleGeneric<0,0,-1,1> DoubleGeneric<1> -DoubleGeneric<4,0,-1> +DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,1> DoubleGeneric<1,1,-1> -DoubleGeneric<1,-1> +DoubleGeneric<0,0,-1,1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> -DoubleGeneric<1,0,0,0,-1> -DoubleGeneric<1,0,0,0,-1> +DoubleGeneric<1,-1> +DoubleGeneric<0,0,-1,1> +DoubleGeneric<1> +DoubleGeneric<1,0,-1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> -DoubleGeneric<1,0,0,0,-1> DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<2,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<4,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<1,-1> -DoubleGeneric<1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,0,0,0,-1> +DoubleGeneric<1,0,0,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,0,0,0,-1> DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<2,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<4,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<1,-1> -(0,0) -> (15,0) -> (28,1) -> (36,1) -(1,0) -> (2,0) -> (4,0) -> (16,0) -> (23,0) -(1,2) -> (2,1) -(2,2) -> (3,0) -(4,2) -> (9,0) -(5,0) -> (6,0) -> (8,0) -> (19,0) -> (24,0) -(5,2) -> (6,1) -(6,2) -> (7,0) -(8,2) -> (9,1) -(9,2) -> (14,0) -(10,0) -> (11,0) -> (13,0) -> (20,0) -> (26,0) -(10,2) -> (11,1) -(11,2) -> (12,0) -(13,2) -> (14,1) -(14,2) -> (15,1) +DoubleGeneric<1,0,0,0,-2> +(0,0) -> (46,1) -> (65,1) -> (66,0) +(1,0) -> (3,0) -> (14,0) -> (25,0) +(2,1) -> (3,1) +(3,2) -> (4,1) -> (8,1) +(4,2) -> (7,0) +(5,0) -> (8,0) -> (10,0) -> (11,0) +(5,1) -> (6,0) +(6,2) -> (7,1) +(8,2) -> (9,0) +(10,2) -> (35,0) -> (36,0) -> (47,0) -> (54,0) -> (55,0) +(11,1) -> (12,0) +(13,1) -> (14,1) +(14,2) -> (15,1) -> (19,1) +(15,2) -> (18,0) +(16,0) -> (19,0) -> (21,0) -> (22,0) (16,1) -> (17,0) -(17,2) -> (18,0) -(20,1) -> (21,0) -(21,2) -> (22,0) -(23,2) -> (25,0) -(24,2) -> (25,1) -(25,2) -> (27,0) -(26,2) -> (27,1) -(27,2) -> (28,0) -(29,0) -> (31,0) -> (34,0) -(30,0) -> (32,0) -(31,2) -> (33,0) -(32,2) -> (33,1) -(33,2) -> (35,0) -(34,2) -> (35,1) -(35,2) -> (36,0) +(17,2) -> (18,1) +(19,2) -> (20,0) +(21,2) -> (38,0) -> (39,0) -> (50,0) -> (57,0) -> (58,0) +(22,1) -> (23,0) +(24,1) -> (25,1) +(25,2) -> (26,1) -> (30,1) +(26,2) -> (29,0) +(27,0) -> (30,0) -> (32,0) -> (33,0) +(27,1) -> (28,0) +(28,2) -> (29,1) +(30,2) -> (31,0) +(32,2) -> (42,0) -> (43,0) -> (51,0) -> (61,0) -> (62,0) +(33,1) -> (34,0) +(35,2) -> (41,0) +(36,1) -> (37,0) +(38,2) -> (41,1) +(39,1) -> (40,0) +(41,2) -> (45,0) +(42,2) -> (45,1) +(43,1) -> (44,0) +(45,2) -> (46,0) +(47,1) -> (48,0) +(48,2) -> (49,0) +(51,1) -> (52,0) +(52,2) -> (53,0) +(54,2) -> (60,0) +(55,1) -> (56,0) +(57,2) -> (60,1) +(58,1) -> (59,0) +(60,2) -> (64,0) +(61,2) -> (64,1) +(62,1) -> (63,0) +(64,2) -> (65,0) diff --git a/examples/fixture/asm/r1cs/generic_builtin_bits.asm b/examples/fixture/asm/r1cs/generic_builtin_bits.asm index c216dc416..0988a4dbe 100644 --- a/examples/fixture/asm/r1cs/generic_builtin_bits.asm +++ b/examples/fixture/asm/r1cs/generic_builtin_bits.asm @@ -1,18 +1,24 @@ @ noname.0.7.0 @ public inputs: 1 -v_3 == (v_2) * (v_2 + -1) -0 == (v_3) * (1) -v_5 == (v_4) * (v_4 + -1) -0 == (v_5) * (1) -v_7 == (v_6) * (v_6 + -1) +1 == (v_3) * (1) +v_5 == (v_4) * (-1 * v_2 + v_3) +-1 * v_6 + 1 == (v_5) * (1) +v_7 == (v_6) * (-1 * v_2 + v_3) 0 == (v_7) * (1) -v_2 + 2 * v_4 + 4 * v_6 == (v_1) * (1) -1 == (-1 * v_2 + 1) * (1) -1 == (v_4) * (1) -1 == (-1 * v_6 + 1) * (1) -v_1 == (v_2 + 2 * v_4 + 4 * v_6) * (1) -0 == (v_8) * (1) 1 == (v_9) * (1) -0 == (v_10) * (1) -v_1 == (v_8 + 2 * v_9 + 4 * v_10) * (1) +v_11 == (v_10) * (-1 * v_8 + v_9) +-1 * v_12 + 1 == (v_11) * (1) +v_13 == (v_12) * (-1 * v_8 + v_9) +0 == (v_13) * (1) +1 == (v_15) * (1) +v_17 == (v_16) * (-1 * v_14 + v_15) +-1 * v_18 + 1 == (v_17) * (1) +v_19 == (v_18) * (-1 * v_14 + v_15) +0 == (v_19) * (1) +v_1 == (v_6 + 2 * v_12 + 4 * v_18) * (1) +1 == (-1 * v_6 + 1) * (1) +1 == (v_12) * (1) +1 == (-1 * v_18 + 1) * (1) +v_1 == (v_6 + 2 * v_12 + 4 * v_18) * (1) +2 == (v_1) * (1) diff --git a/examples/generic_builtin_bits.no b/examples/generic_builtin_bits.no index db4830d07..c15b2b178 100644 --- a/examples/generic_builtin_bits.no +++ b/examples/generic_builtin_bits.no @@ -1,8 +1,7 @@ use std::bits; -// 010 = xx, where xx = 2 fn main(pub xx: Field) { - // var + // calculate on a cell var let bits = bits::to_bits(3, xx); assert(!bits[0]); assert(bits[1]); @@ -11,7 +10,7 @@ fn main(pub xx: Field) { let val = bits::from_bits(bits); assert_eq(val, xx); - // constant + // calculate on a constant let cst_bits = bits::to_bits(3, 2); assert(!cst_bits[0]); assert(cst_bits[1]); diff --git a/src/cli/cmd_build_and_check.rs b/src/cli/cmd_build_and_check.rs index ffa456756..20ff93dba 100644 --- a/src/cli/cmd_build_and_check.rs +++ b/src/cli/cmd_build_and_check.rs @@ -11,14 +11,16 @@ use crate::{ r1cs::{snarkjs::SnarkjsExporter, R1CS}, Backend, BackendField, BackendKind, }, - cli::packages::path_to_package, + cli::packages::{path_to_package, path_to_stdlib}, compiler::{compile, generate_witness, typecheck_next_file, Sources}, inputs::{parse_inputs, JsonInputs}, + stdlib::init_stdlib_dep, type_checker::TypeChecker, }; use super::packages::{ - get_deps_of_package, is_lib, validate_package_and_get_manifest, DependencyGraph, UserRepo, + download_stdlib, get_deps_of_package, is_lib, validate_package_and_get_manifest, + DependencyGraph, UserRepo, }; const COMPILED_DIR: &str = "compiled"; @@ -137,6 +139,26 @@ pub fn cmd_check(args: CmdCheck) -> miette::Result<()> { Ok(()) } +fn add_stdlib( + sources: &mut Sources, + tast: &mut TypeChecker, + node_id: usize, +) -> miette::Result { + let mut node_id = node_id; + + // check if the release folder exists, otherwise download the latest release + // todo: check the latest version and compare it with the current version, to decide if download is needed + let stdlib_dir = path_to_stdlib(); + + if !stdlib_dir.exists() { + download_stdlib()?; + } + + node_id = init_stdlib_dep(sources, tast, node_id, stdlib_dir.as_ref()); + + Ok(node_id) +} + fn produce_all_asts(path: &PathBuf) -> miette::Result<(Sources, TypeChecker)> { // find manifest let manifest = validate_package_and_get_manifest(&path, false)?; @@ -161,6 +183,9 @@ fn produce_all_asts(path: &PathBuf) -> miette::Result<(Sources, Type let mut tast = TypeChecker::new(); + // adding stdlib + add_stdlib(&mut sources, &mut tast, node_id)?; + for dep in dep_graph.from_leaves_to_roots() { let path = path_to_package(&dep); diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 84fb80ee8..6e470dd29 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -15,3 +15,6 @@ pub const NONAME_DIRECTORY: &str = ".noname"; /// The directory under [NONAME_DIRECTORY] containing all package-related files. pub const PACKAGE_DIRECTORY: &str = "packages"; + +/// The directory under [NONAME_DIRECTORY] containing all the latest noname release. +pub const RELEASE_DIRECTORY: &str = "release"; diff --git a/src/cli/packages.rs b/src/cli/packages.rs index 03060c6d2..263c12401 100644 --- a/src/cli/packages.rs +++ b/src/cli/packages.rs @@ -7,9 +7,11 @@ use camino::Utf8PathBuf as PathBuf; use miette::{Context, IntoDiagnostic, Result}; use serde::{Deserialize, Serialize}; +use crate::stdlib::STDLIB_DIRECTORY; + use super::{ manifest::{read_manifest, Manifest}, - NONAME_DIRECTORY, PACKAGE_DIRECTORY, + NONAME_DIRECTORY, PACKAGE_DIRECTORY, RELEASE_DIRECTORY, }; /// A dependency is a Github `user/repo` pair. @@ -241,6 +243,16 @@ pub(crate) fn path_to_package(dep: &UserRepo) -> PathBuf { package_dir.join(&dep.user).join(&dep.repo) } +pub(crate) fn path_to_stdlib() -> PathBuf { + let home_dir: PathBuf = dirs::home_dir() + .expect("could not find home directory of current user") + .try_into() + .expect("invalid UTF8 path"); + let noname_dir = home_dir.join(NONAME_DIRECTORY); + + noname_dir.join(RELEASE_DIRECTORY).join(STDLIB_DIRECTORY) +} + /// download package from github pub fn download_from_github(dep: &UserRepo) -> Result<()> { let url = format!( @@ -264,6 +276,44 @@ pub fn download_from_github(dep: &UserRepo) -> Result<()> { Ok(()) } +pub fn download_stdlib() -> Result<()> { + // Hardcoded repository details and target branch + let repo_owner = "zksecurity"; + let repo_name = "noname"; + let target_branch = "main"; + let repo_url = format!( + "https://github.com/{owner}/{repo}.git", + owner = repo_owner, + repo = repo_name + ); + + let home_dir: PathBuf = dirs::home_dir() + .expect("could not find home directory of current user") + .try_into() + .expect("invalid UTF8 path"); + let noname_dir = home_dir.join(NONAME_DIRECTORY); + let release_dir = noname_dir.join("release"); + + // Clone the repository and checkout the specified branch to the temporary directory + let output = process::Command::new("git") + .arg("clone") + .arg("--branch") + .arg(target_branch) + .arg("--single-branch") + .arg(repo_url) + .arg(release_dir) + .output() + .expect("failed to execute git clone command"); + + if !output.status.success() { + miette::bail!(format!( + "Could not clone branch `{target_branch}` of repository `{repo_owner}/{repo_name}`." + )); + } + + Ok(()) +} + pub fn is_lib(path: &PathBuf) -> bool { path.join("src").join("lib.no").exists() } diff --git a/src/error.rs b/src/error.rs index 19f2e1d9b..c06b9bf35 100644 --- a/src/error.rs +++ b/src/error.rs @@ -130,7 +130,7 @@ pub enum ErrorKind { #[error("invalid array size, expected [_; x] with x in [0,2^32]")] InvalidArraySize, - #[error("only allow a single generic parameter for the size of an array argument")] + #[error("Invalid expression in symbolic size")] InvalidSymbolicSize, #[error("invalid generic parameter, expected single uppercase letter, such as N, M, etc.")] @@ -140,7 +140,7 @@ pub enum ErrorKind { GenericValueExpected(String), #[error("conflict generic values during binding for `{0}`: `{1}` and `{2}`")] - ConflictGenericValue(String, u32, u32), + ConflictGenericValue(String, String, String), #[error("unexpected generic parameter: `{0}`")] UnexpectedGenericParameter(String), @@ -361,4 +361,7 @@ pub enum ErrorKind { #[error("invalid range, the end value can't be smaller than the start value")] InvalidRange, + + #[error("division by zero")] + DivisionByZero, } diff --git a/src/mast/mod.rs b/src/mast/mod.rs index f16a2cbac..f2521c156 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -9,7 +9,8 @@ use crate::{ imports::FnKind, parser::{ types::{ - FnSig, ForLoopArgument, GenericParameters, Range, Stmt, StmtKind, Symbolic, Ty, TyKind, + FnSig, ForLoopArgument, GenericParameters, Ident, Range, Stmt, StmtKind, Symbolic, Ty, + TyKind, }, CustomType, Expr, ExprKind, FunctionDef, Op2, }, @@ -29,19 +30,51 @@ pub struct ExprMonoInfo { /// The generic types shouldn't be presented in this field. pub typ: Option, - // todo: see if we can do constant folding on the expression nodes. - // - it is possible to remove this field, as the constant value can be extracted from folded expression node - /// Numeric value of the expression - /// applicable to BigInt type - pub constant: Option, + /// Propagated constant value + pub constant: Option, } -impl ExprMonoInfo { - pub fn new(expr: Expr, typ: Option, value: Option) -> Self { - if value.is_some() && !matches!(typ, Some(TyKind::Field { constant: true })) { - panic!("value can only be set for BigInt type"); +#[derive(Debug, Clone)] +pub enum PropagatedConstant { + Single(BigUint), + Array(Vec), + Custom(HashMap), +} + +impl PropagatedConstant { + pub fn as_single(&self) -> BigUint { + match self { + PropagatedConstant::Single(v) => v.clone(), + _ => panic!("expected single value"), + } + } + + pub fn as_array(&self) -> Vec { + match self { + PropagatedConstant::Array(v) => v.iter().map(|c| c.as_single()).collect(), + _ => panic!("expected array value"), + } + } + + pub fn as_custom(&self) -> HashMap { + match self { + PropagatedConstant::Custom(v) => { + v.iter().map(|(k, c)| (k.clone(), c.as_single())).collect() + } + _ => panic!("expected custom value"), } + } +} +/// impl From trait for single value +impl From for PropagatedConstant { + fn from(v: BigUint) -> Self { + PropagatedConstant::Single(v) + } +} + +impl ExprMonoInfo { + pub fn new(expr: Expr, typ: Option, value: Option) -> Self { Self { expr, typ, @@ -68,18 +101,18 @@ pub struct MTypeInfo { pub typ: TyKind, /// Store constant value - pub value: Option, + pub constant: Option, /// The span of the variable declaration. pub span: Span, } impl MTypeInfo { - pub fn new(typ: &TyKind, span: Span, value: Option) -> Self { + pub fn new(typ: &TyKind, span: Span, value: Option) -> Self { Self { typ: typ.clone(), span, - value, + constant: value, } } } @@ -191,7 +224,10 @@ impl FnSig { TyKind::Array(ty, size) => TyKind::Array(Box::new(self.resolve_type(ty, ctx)), *size), TyKind::GenericSizedArray(ty, sym) => { let val = sym.eval(&self.generics, &ctx.tast); - TyKind::Array(Box::new(self.resolve_type(ty, ctx)), val) + TyKind::Array( + Box::new(self.resolve_type(ty, ctx)), + val.to_u32().expect("array size exceeded u32"), + ) } _ => typ.clone(), } @@ -216,11 +252,11 @@ impl FnSig { } // const NN: Field _ => { - let cst = observed_arg.constant; + let cst = observed_arg.constant.clone(); if is_generic_parameter(sig_arg.name.value.as_str()) && cst.is_some() { self.generics.assign( &sig_arg.name.value, - cst.unwrap(), + cst.unwrap().as_single(), observed_arg.expr.span, )?; } @@ -269,6 +305,10 @@ where functions_instantiated: HashMap, // new method name as the key, old method name as the value methods_instantiated: HashMap<(FullyQualified, String), String>, + // cache for [PropagatedConstant] values from instantiated methods + cst_method_cache: HashMap<(FullyQualified, String), PropagatedConstant>, + // cache for [PropagatedConstant] values from instantiated functions + cst_fn_cache: HashMap, } impl MastCtx { @@ -278,6 +318,8 @@ impl MastCtx { generic_func_scope: Some(0), functions_instantiated: HashMap::new(), methods_instantiated: HashMap::new(), + cst_method_cache: HashMap::new(), + cst_fn_cache: HashMap::new(), } } @@ -342,9 +384,9 @@ impl MastCtx { impl Symbolic { /// Evaluate symbolic size to an integer. - pub fn eval(&self, gens: &GenericParameters, tast: &TypeChecker) -> u32 { + pub fn eval(&self, gens: &GenericParameters, tast: &TypeChecker) -> BigUint { match self { - Symbolic::Concrete(v) => *v, + Symbolic::Concrete(v) => v.clone(), Symbolic::Constant(var) => { let qualified = FullyQualified::local(var.value.clone()); let cst = tast.const_info(&qualified).expect("constant not found"); @@ -509,7 +551,12 @@ fn monomorphize_expr( }, ); - let cst = None; + // propagate the constant value + let cst = lhs_mono.constant.and_then(|c| match c { + PropagatedConstant::Custom(map) => map.get(rhs).cloned(), + _ => None, + }); + ExprMonoInfo::new(mexpr, typ, cst) } @@ -558,10 +605,19 @@ fn monomorphize_expr( .as_ref() .and_then(|sig| sig.return_type.clone().map(|t| t.kind)); - ExprMonoInfo::new(mexpr, typ, None) + // retrieve the constant value from the cache + let cst = ctx.cst_fn_cache.get(&mono_qualified).cloned(); + + ExprMonoInfo::new(mexpr, typ, cst) } else { // monomorphize the function call - let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + let (fn_info_mono, typ, cst) = + instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + + // cache the constant value + if let Some(cst) = cst.clone() { + ctx.cst_fn_cache.insert(mono_qualified.clone(), cst); + } let fn_name_mono = &fn_info_mono.sig().name; let mexpr = expr.to_mast( @@ -577,7 +633,7 @@ fn monomorphize_expr( let new_qualified = FullyQualified::new(module, &fn_name_mono.value); ctx.add_monomorphized_fn(old_qualified, new_qualified, fn_info_mono); - ExprMonoInfo::new(mexpr, typ, None) + ExprMonoInfo::new(mexpr, typ, cst) } } @@ -649,10 +705,23 @@ fn monomorphize_expr( }, ); let typ = resolved_sig.return_type.clone().map(|t| t.kind); - ExprMonoInfo::new(mexpr, typ, None) + + // retrieve the constant value from the cache + let cst = ctx + .cst_method_cache + .get(&(struct_qualified.clone(), method_name.value.clone())) + .cloned(); + + ExprMonoInfo::new(mexpr, typ, cst) } else { // monomorphize the function call - let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + let (fn_info_mono, typ, cst) = + instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + // cache the constant value + if let Some(cst) = cst.clone() { + ctx.cst_method_cache + .insert((struct_qualified.clone(), method_name.value.clone()), cst); + } let fn_name_mono = &fn_info_mono.sig().name; let mexpr = expr.to_mast( @@ -668,7 +737,7 @@ fn monomorphize_expr( ctx.tast .add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def); - ExprMonoInfo::new(mexpr, typ, None) + ExprMonoInfo::new(mexpr, typ, cst) } } @@ -739,9 +808,10 @@ fn monomorphize_expr( Some(v) => { let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(v.clone())); - ExprMonoInfo::new(mexpr, typ, v.to_u32()) + ExprMonoInfo::new(mexpr, typ, Some(PropagatedConstant::from(v))) } - None => { + // keep as is + _ => { let mexpr = expr.to_mast( ctx, &ExprKind::BinaryOp { @@ -775,10 +845,13 @@ fn monomorphize_expr( } ExprKind::BigUInt(inner) => { - let cst: u32 = inner.try_into().expect("biguint too large"); let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(inner.clone())); - ExprMonoInfo::new(mexpr, Some(TyKind::Field { constant: true }), Some(cst)) + ExprMonoInfo::new( + mexpr, + Some(TyKind::Field { constant: true }), + Some(PropagatedConstant::from(inner.clone())), + ) } ExprKind::Bool(inner) => { @@ -794,10 +867,10 @@ fn monomorphize_expr( let res = if is_generic_parameter(&name.value) { let mtype = mono_fn_env.get_type_info(&name.value).unwrap(); - let mexpr = - expr.to_mast(ctx, &ExprKind::BigUInt(BigUint::from(mtype.value.unwrap()))); + let cst = mtype.constant.clone().unwrap().as_single(); + let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(BigUint::from(cst))); - ExprMonoInfo::new(mexpr, Some(mtype.typ.clone()), mtype.value) + ExprMonoInfo::new(mexpr, Some(mtype.typ.clone()), mtype.constant.clone()) } else if is_type(&name.value) { let mtype = TyKind::Custom { module: module.clone(), @@ -817,10 +890,13 @@ fn monomorphize_expr( // if it's a variable, // check if it's a constant first let bigint: BigUint = cst.value[0].into(); - let cst: u32 = bigint.clone().try_into().expect("biguint too large"); - let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(bigint)); + let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(bigint.clone())); - ExprMonoInfo::new(mexpr, Some(TyKind::Field { constant: true }), Some(cst)) + ExprMonoInfo::new( + mexpr, + Some(TyKind::Field { constant: true }), + Some(PropagatedConstant::from(bigint)), + ) } else { // otherwise it's a local variable let mexpr = expr.to_mast( @@ -832,7 +908,7 @@ fn monomorphize_expr( ); let mtype = mono_fn_env.get_type_info(&name.value).unwrap().clone(); - ExprMonoInfo::new(mexpr, Some(mtype.typ), mtype.value) + ExprMonoInfo::new(mexpr, Some(mtype.typ), mtype.constant) }; res @@ -955,24 +1031,46 @@ fn monomorphize_expr( )); } - fields_mono.push((ident, observed_mono.expr.clone())); + fields_mono.push(( + ident, + observed_mono.expr.clone(), + observed_mono.constant.clone(), + )); } let mexpr = expr.to_mast( ctx, &ExprKind::CustomTypeDeclaration { custom: custom.clone(), - fields: fields_mono, + // extract a tuple of first two elements + fields: fields_mono + .iter() + .map(|(a, b, _)| (a.clone(), b.clone())) + .collect(), }, ); + let cst_fields = { + let csts = HashMap::from_iter( + fields_mono + .into_iter() + .filter(|(_, _, cst)| cst.is_some()) + .map(|(ident, _, cst)| (ident, cst.unwrap())), + ); + if csts.is_empty() { + None + } else { + Some(PropagatedConstant::Custom(csts)) + } + }; + ExprMonoInfo::new( mexpr, Some(TyKind::Custom { module: module.clone(), name: name.clone(), }), - None, + cst_fields, ) } ExprKind::RepeatedArrayInit { item, size } => { @@ -989,7 +1087,10 @@ fn monomorphize_expr( ); if let Some(cst) = size_mono.constant { - let arr_typ = TyKind::Array(Box::new(item_typ), cst); + let arr_typ = TyKind::Array( + Box::new(item_typ), + cst.as_single().to_u32().expect("array size too large"), + ); ExprMonoInfo::new(mexpr, Some(arr_typ), None) } else { return Err(error(ErrorKind::InvalidArraySize, expr.span)); @@ -1011,23 +1112,30 @@ pub fn monomorphize_block( mono_fn_env: &mut MonomorphizedFnEnv, stmts: &[Stmt], expected_return: Option<&Ty>, -) -> Result<(Vec, Option)> { +) -> Result<(Vec, Option)> { mono_fn_env.nest(); - let mut return_typ = None; + let mut ret_expr_mono = None; let mut stmts_mono = vec![]; for stmt in stmts { - if let Some((stmt, ret_typ)) = monomorphize_stmt(ctx, mono_fn_env, stmt)? { + if let Some((stmt, expr_mono)) = monomorphize_stmt(ctx, mono_fn_env, stmt)? { stmts_mono.push(stmt); - if ret_typ.is_some() { - return_typ = ret_typ; + // only return stmt can return `ExprMonoInfo` which contains propagated constants + if expr_mono.is_some() { + ret_expr_mono = expr_mono; } } } + let return_typ = if let Some(expr_mono) = ret_expr_mono.clone() { + expr_mono.typ + } else { + None + }; + // check the return if let (Some(expected), Some(observed)) = (expected_return, return_typ.clone()) { if !observed.match_expected(&expected.kind, true) { @@ -1040,7 +1148,7 @@ pub fn monomorphize_block( mono_fn_env.pop(); - Ok((stmts_mono, return_typ)) + Ok((stmts_mono, ret_expr_mono)) } /// Monomorphize a statement. @@ -1048,7 +1156,7 @@ pub fn monomorphize_stmt( ctx: &mut MastCtx, mono_fn_env: &mut MonomorphizedFnEnv, stmt: &Stmt, -) -> Result)>> { +) -> Result)>> { let res = match &stmt.kind { StmtKind::Assign { mutable, lhs, rhs } => { let rhs_mono = monomorphize_expr(ctx, rhs, mono_fn_env)?; @@ -1093,7 +1201,9 @@ pub fn monomorphize_stmt( return Err(error(ErrorKind::InvalidRangeSize, stmt.span)); } - if start_mono.constant.unwrap() > end_mono.constant.unwrap() { + if start_mono.constant.unwrap().as_single() + > end_mono.constant.unwrap().as_single() + { return Err(error(ErrorKind::InvalidRangeSize, stmt.span)); } @@ -1147,11 +1257,11 @@ pub fn monomorphize_stmt( StmtKind::Return(res) => { let expr_mono = monomorphize_expr(ctx, res, mono_fn_env)?; let stmt_mono = Stmt { - kind: StmtKind::Return(Box::new(expr_mono.expr)), + kind: StmtKind::Return(Box::new(expr_mono.expr.clone())), span: stmt.span, }; - Some((stmt_mono, expr_mono.typ)) + Some((stmt_mono, Some(expr_mono))) } StmtKind::Comment(_) => None, }; @@ -1168,7 +1278,7 @@ pub fn instantiate_fn_call( fn_info: FnInfo, args: &[ExprMonoInfo], span: Span, -) -> Result<(FnInfo, Option)> { +) -> Result<(FnInfo, Option, Option)> { ctx.start_monomorphize_func(); let fn_sig = fn_info.sig(); @@ -1192,7 +1302,11 @@ pub fn instantiate_fn_call( let val = fn_sig.generics.get(gen); mono_fn_env.store_type( gen, - &MTypeInfo::new(&TyKind::Field { constant: true }, span, Some(val)), + &MTypeInfo::new( + &TyKind::Field { constant: true }, + span, + Some(PropagatedConstant::from(val)), + ), )?; } @@ -1208,7 +1322,7 @@ pub fn instantiate_fn_call( let typ = mono_info.typ.as_ref().expect("expected a value"); mono_fn_env.store_type( arg_name, - &MTypeInfo::new(typ, mono_info.expr.span, mono_info.constant), + &MTypeInfo::new(typ, mono_info.expr.span, mono_info.constant.clone()), )?; } @@ -1216,32 +1330,40 @@ pub fn instantiate_fn_call( let ret_typed = sig_typed.return_type.clone(); // construct the monomorphized function AST - let func_def = match fn_info.kind { - FnKind::BuiltIn(_, handle) => FnInfo { - kind: FnKind::BuiltIn(sig_typed, handle), - is_hint: fn_info.is_hint, - span: fn_info.span, - }, + let (func_def, mono_info) = match fn_info.kind { + FnKind::BuiltIn(_, handle) => ( + FnInfo { + kind: FnKind::BuiltIn(sig_typed, handle), + ..fn_info + }, + // todo: we will need to propagate the constant value from builtin function as well + None, + ), FnKind::Native(fn_def) => { - let (stmts_typed, _) = + let (stmts_typed, mono_info) = monomorphize_block(ctx, mono_fn_env, &fn_def.body, ret_typed.as_ref())?; - FnInfo { - kind: FnKind::Native(FunctionDef { - sig: sig_typed, - body: stmts_typed, - span: fn_def.span, - is_hint: fn_def.is_hint, - }), - is_hint: fn_info.is_hint, - span: fn_info.span, - } + ( + FnInfo { + kind: FnKind::Native(FunctionDef { + sig: sig_typed, + body: stmts_typed, + span: fn_def.span, + is_hint: fn_def.is_hint, + }), + is_hint: fn_info.is_hint, + span: fn_info.span, + }, + mono_info, + ) } }; ctx.finish_monomorphize_func(); - Ok((func_def, ret_typed.map(|t| t.kind))) + let cst = mono_info.and_then(|c| c.constant); + + Ok((func_def, ret_typed.map(|t| t.kind), cst)) } pub fn error(kind: ErrorKind, span: Span) -> Error { Error::new("mast", kind, span) diff --git a/src/negative_tests.rs b/src/negative_tests.rs index 66fa0291b..2e1b1bce0 100644 --- a/src/negative_tests.rs +++ b/src/negative_tests.rs @@ -612,3 +612,55 @@ fn test_nonhint_call_with_unsafe() { ErrorKind::UnexpectedUnsafeAttribute )); } + +#[test] +fn test_no_cst_struct_field_prop() { + let code = r#" + struct Thing { + val: Field, + } + + fn gen(const LEN: Field) -> [Field; LEN] { + return [0; LEN]; + } + + fn main(pub xx: Field) { + let thing = Thing { val: xx }; + + let arr = gen(thing.val); + } + "#; + + let res = tast_pass(code).0; + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::ArgumentTypeMismatch(..) + )); +} + +#[test] +fn test_mut_cst_struct_field_prop() { + let code = r#" + struct Thing { + val: Field, + } + + fn gen(const LEN: Field) -> [Field; LEN] { + return [0; LEN]; + } + + fn main(pub xx: Field) { + let mut thing = Thing { val: 3 }; + thing.val = xx; + + let arr = gen(thing.val); + assert_eq(arr[0], xx); + } + "#; + + let res = tast_pass(code).0; + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::ArgumentTypeMismatch(..) + )); +} diff --git a/src/parser/expr.rs b/src/parser/expr.rs index d070290bb..db68f2eba 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -545,9 +545,11 @@ impl Expr { // sanity check if !matches!( self.kind, - ExprKind::Variable { .. } | ExprKind::FieldAccess { .. } + ExprKind::Variable { .. } + | ExprKind::FieldAccess { .. } + | ExprKind::ArrayAccess { .. } ) { - panic!("an array access can only follow a variable"); + panic!("an array access can only follow a variable or another array access"); } // array[idx] diff --git a/src/parser/types.rs b/src/parser/types.rs index 0691f1356..4ce8645c8 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -1,4 +1,5 @@ use educe::Educe; +use num_bigint::BigUint; use std::{ collections::{HashMap, HashSet}, fmt::Display, @@ -7,7 +8,7 @@ use std::{ }; use ark_ff::Field; -use num_traits::ToPrimitive; +use num_traits::FromPrimitive; use serde::{Deserialize, Serialize}; use crate::{ @@ -15,7 +16,6 @@ use crate::{ constants::Span, error::{Error, ErrorKind, Result}, lexer::{Keyword, Token, TokenKind, Tokens}, - mast::ExprMonoInfo, stdlib::builtins::BUILTIN_FN_NAMES, syntax::{is_generic_parameter, is_type}, }; @@ -181,7 +181,7 @@ pub enum ModulePath { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum Symbolic { /// A literal number - Concrete(u32), + Concrete(BigUint), /// Point to a constant variable Constant(Ident), /// Generic parameter @@ -230,7 +230,7 @@ impl Symbolic { /// Parse from an expression node recursively. pub fn parse(node: &Expr) -> Result { match &node.kind { - ExprKind::BigUInt(n) => Ok(Symbolic::Concrete(n.to_u32().unwrap())), + ExprKind::BigUInt(n) => Ok(Symbolic::Concrete(n.clone())), ExprKind::Variable { module: _, name } => { if is_generic_parameter(&name.value) { Ok(Symbolic::Generic(name.clone())) @@ -571,7 +571,11 @@ impl FnSig { // resolve the generic parameter match sym { Symbolic::Generic(ident) => { - self.generics.assign(&ident.value, *observed_size, span)?; + self.generics.assign( + &ident.value, + BigUint::from_u32(*observed_size).unwrap(), + span, + )?; } _ => unreachable!("no operation allowed on symbolic size in function argument"), } @@ -633,7 +637,7 @@ impl FnSig { let generics = generics .iter() - .map(|(name, value)| format!("{}={}", name, value.unwrap())) + .map(|(name, value)| format!("{}={}", name, value.as_ref().unwrap())) .collect::>() .join("#"); @@ -742,7 +746,7 @@ pub struct ResolvedSig { #[derive(Debug, Default, Clone, Serialize, Deserialize)] /// Generic parameters for a function signature pub struct GenericParameters { - pub parameters: HashMap>, + pub parameters: HashMap>, pub resolved_sig: Option, } @@ -758,11 +762,13 @@ impl GenericParameters { } /// Get the value of a generic parameter - pub fn get(&self, name: &str) -> u32 { + pub fn get(&self, name: &str) -> BigUint { self.parameters .get(name) .expect("generic parameter not found") + .as_ref() .expect("generic value not assigned") + .clone() } /// Returns whether the generic parameters are empty @@ -771,7 +777,7 @@ impl GenericParameters { } /// Bind a generic parameter to a value - pub fn assign(&mut self, name: &String, value: u32, span: Span) -> Result<()> { + pub fn assign(&mut self, name: &String, value: BigUint, span: Span) -> Result<()> { let existing = self.parameters.get(name); match existing { Some(Some(v)) => { @@ -781,7 +787,11 @@ impl GenericParameters { Err(Error::new( "mast", - ErrorKind::ConflictGenericValue(name.to_string(), *v, value), + ErrorKind::ConflictGenericValue( + name.to_string(), + v.to_str_radix(10), + value.to_str_radix(10), + ), span, )) } diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index 3fe256ee2..cc5898bfc 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -1,13 +1,11 @@ use std::vec; -use ark_ff::One; -use kimchi::o1_utils::FieldHelpers; +use kimchi::{o1_utils::FieldHelpers, turshi::helper::CairoFieldHelpers}; use crate::{ backends::Backend, circuit_writer::{CircuitWriter, VarInfo}, constants::Span, - constraints::boolean, error::Result, parser::types::GenericParameters, var::{ConstOrCell, Value, Var}, @@ -15,8 +13,7 @@ use crate::{ use super::{FnInfoType, Module}; -const TO_BITS_FN: &str = "to_bits(const LEN: Field, val: Field) -> [Bool; LEN]"; -const FROM_BITS_FN: &str = "from_bits(bits: [Bool; LEN]) -> Field"; +const NTH_BIT_FN: &str = "nth_bit(val: Field, const nth: Field) -> Field"; pub struct BitsLib {} @@ -24,143 +21,49 @@ impl Module for BitsLib { const MODULE: &'static str = "bits"; fn get_fns() -> Vec<(&'static str, FnInfoType)> { - vec![(TO_BITS_FN, to_bits), (FROM_BITS_FN, from_bits)] + vec![(NTH_BIT_FN, nth_bit)] } } -fn to_bits( +fn nth_bit( compiler: &mut CircuitWriter, - generics: &GenericParameters, + _generics: &GenericParameters, vars: &[VarInfo], span: Span, ) -> Result>> { // should be two input vars assert_eq!(vars.len(), 2); - // but the better practice would be to retrieve the value from the generics - let bitlen = generics.get("LEN") as usize; - - // num should be greater than 0 - assert!(bitlen > 0); - - let modulus_bits: usize = B::Field::modulus_biguint() - .bits() - .try_into() - .expect("modulus is too large"); - - assert!(bitlen <= (modulus_bits - 1)); - - // alternatively, it can be retrieved from the first var, but it is not recommended - // let num_var = &vars[0]; + // these should be type checked already, unless it is called by other low level functions + // eg. builtins + let var_info = &vars[0]; + let val = &var_info.var; + assert_eq!(val.len(), 1); - // second var is the value to convert let var_info = &vars[1]; - let var = &var_info.var; - assert_eq!(var.len(), 1); + let nth = &var_info.var; + assert_eq!(nth.len(), 1); - let val = match &var[0] { + let nth: usize = match &nth[0] { + ConstOrCell::Cell(_) => unreachable!("nth should be a constant"), + ConstOrCell::Const(cst) => cst.to_u64() as usize, + }; + + let val = match &val[0] { ConstOrCell::Cell(cvar) => cvar.clone(), ConstOrCell::Const(cst) => { - // extract the first bitlen bits - let bits = cst - .to_bits() - .iter() - .take(bitlen) - .copied() - // convert to ConstOrVar - .map(|b| ConstOrCell::Const(B::Field::from(b))) - .collect::>(); - - return Ok(Some(Var::new(bits, span))); + // directly return the nth bit without adding symbolic value as it doesn't depend on a cell var + let bit = cst.to_bits(); + return Ok(Some(Var::new_cvar( + ConstOrCell::Const(B::Field::from(bit[nth])), + span, + ))); } }; - // convert value to bits - let mut bits = Vec::with_capacity(bitlen); - let mut e2 = B::Field::one(); - let mut lc: Option = None; - - for i in 0..bitlen { - let bit = compiler - .backend - .new_internal_var(Value::NthBit(val.clone(), i), span); - - // constrain it to be either 0 or 1 - // bits[i] * (bits[i] - 1 ) === 0; - boolean::check(compiler, &ConstOrCell::Cell(bit.clone()), span); - - // lc += bits[i] * e2; - let weighted_bit = compiler.backend.mul_const(&bit, &e2, span); - lc = if i == 0 { - Some(weighted_bit) - } else { - Some(compiler.backend.add(&lc.unwrap(), &weighted_bit, span)) - }; - - bits.push(bit.clone()); - e2 = e2 + e2; - } - - compiler.backend.assert_eq_var(&val, &lc.unwrap(), span); - - let bits_cvars = bits.into_iter().map(ConstOrCell::Cell).collect(); - Ok(Some(Var::new(bits_cvars, span))) -} - -fn from_bits( - compiler: &mut CircuitWriter, - generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // only one input var - assert_eq!(vars.len(), 1); - - let var_info = &vars[0]; - let bitlen = generics.get("LEN") as usize; - - let modulus_bits: usize = B::Field::modulus_biguint() - .bits() - .try_into() - .expect("modulus is too large"); - - assert!(bitlen <= (modulus_bits - 1)); - - let bits_vars: Vec<_> = var_info - .var - .cvars - .iter() - .map(|c| match c { - ConstOrCell::Cell(c) => c.clone(), - ConstOrCell::Const(cst) => { - // use a cell var to represent the const for now - // later we will refactor the backend handle ConstOrCell arguments, so we don't have deal with this everywhere - compiler - .backend - .add_constant(Some("converted constant"), *cst, span) - } - }) - .collect(); - - // this might not be necessary since it should be checked in the type checker - assert_eq!(bitlen, bits_vars.len()); - - let mut e2 = B::Field::one(); - let mut lc: Option = None; - - // accumulate the contribution of each bit - for bit in bits_vars { - let weighted_bit = compiler.backend.mul_const(&bit, &e2, span); - - lc = match lc { - None => Some(weighted_bit), - Some(v) => Some(compiler.backend.add(&v, &weighted_bit, span)), - }; - - e2 = e2 + e2; - } - - let cvar = ConstOrCell::Cell(lc.unwrap()); + let bit = compiler + .backend + .new_internal_var(Value::NthBit(val.clone(), nth), span); - Ok(Some(Var::new_cvar(cvar, span))) + Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) } diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index 150d31517..e01489059 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -1,6 +1,8 @@ use crate::{ backends::Backend, circuit_writer::{CircuitWriter, VarInfo}, + cli::packages::UserRepo, + compiler::{typecheck_next_file, Sources}, constants::Span, error::Result, imports::FnKind, @@ -9,14 +11,18 @@ use crate::{ types::{FnSig, GenericParameters}, ParserCtx, }, - type_checker::FnInfo, + type_checker::{FnInfo, TypeChecker}, var::Var, }; +use std::path::Path; pub mod bits; pub mod builtins; pub mod crypto; +/// The directory under [NONAME_DIRECTORY] containing the native stdlib. +pub const STDLIB_DIRECTORY: &str = "src/stdlib/native/"; + pub enum AllStdModules { Builtins, Crypto, @@ -79,3 +85,25 @@ trait Module { res } } + +pub fn init_stdlib_dep( + sources: &mut Sources, + tast: &mut TypeChecker, + node_id: usize, + path_prefix: &str, +) -> usize { + // list the stdlib dependency in order + let libs = vec!["bits", "comparator", "multiplexer", "mimc", "int"]; + + let mut node_id = node_id; + + for lib in libs { + let module = UserRepo::new(&format!("std/{}", lib)); + let prefix_stdlib = Path::new(path_prefix); + let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}/lib.no"))).unwrap(); + node_id = + typecheck_next_file(tast, Some(module), sources, lib.to_string(), code, 0).unwrap(); + } + + node_id +} diff --git a/src/stdlib/native/bits/lib.no b/src/stdlib/native/bits/lib.no new file mode 100644 index 000000000..1d7097eb2 --- /dev/null +++ b/src/stdlib/native/bits/lib.no @@ -0,0 +1,77 @@ +/// a hint function (unconstrained) to extracts the `nth` bit from a given `value`. +/// Its current implementation points to `std::bits::nth_bit`. So it has an empty body in definition. +/// +/// # Parameters +/// - `value`: The `Field` value from which to extract the bit. +/// - `nth`: The position of the bit to extract (0-indexed). +/// +/// # Returns +/// - `Field`: The value of the `nth` bit (0 or 1). +/// +hint fn nth_bit(value: Field, const nth: Field) -> Field; + +/// Converts an array of boolean values (`bits`) into a `Field` value. +/// +/// # Parameters +/// - `bits`: An array of `Bool` values representing bits, where each `true` represents `1` and `false` represents `0`. +/// +/// # Returns +/// - `Field`: A `Field` value that represents the integer obtained from the binary representation of `bits`. +/// +/// # Example +/// ``` +/// let bits = [true, false, true]; // Represents the binary value 101 +/// let result = from_bits(bits); +/// `result` should be 5 as 101 in binary equals 5 in decimal. +/// ``` +fn from_bits(bits: [Bool; LEN]) -> Field { + let mut lc1 = 0; + let mut e2 = 1; + let zero = 0; + + for index in 0..LEN { + lc1 = lc1 + if bits[index] { e2 } else { zero }; + e2 = e2 + e2; + } + return lc1; +} + +/// Converts a `Field` value into an array of boolean values (`bits`) representing its binary form. +/// +/// # Parameters +/// - `LEN`: The length of the resulting bit array. Determines how many bits are considered in the conversion. +/// - `value`: The `Field` value to convert into binary representation. +/// +/// # Returns +/// - `[Bool; LEN]`: An array of boolean values where each `true` represents `1` and `false` represents `0`. +/// +/// # Example +/// ``` +/// let value = 5; // Binary representation: 101 +/// let bits = to_bits(3, value); +/// `bits` should be [true, false, true] corresponding to the binary 101. +/// ``` +/// +/// # Panics +/// - The function asserts that `from_bits(bits)` equals `value`, ensuring the conversion is correct. +fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { + let mut bits = [false; LEN]; + let mut lc1 = 0; + let mut e2 = 1; + + // TODO: ITE should allow literals. + let true_val = true; + let false_val = false; + + for index in 0..LEN { + let bit_num = unsafe nth_bit(value, index); + + // constraint the bit values to booleans + bits[index] = if bit_num == 1 { true_val } else { false_val }; + } + + // constraint the accumulative contributions of bits to be equal to the value + assert_eq(from_bits(bits), value); + + return bits; +} diff --git a/src/stdlib/native/comparator/lib.no b/src/stdlib/native/comparator/lib.no new file mode 100644 index 000000000..58ad95593 --- /dev/null +++ b/src/stdlib/native/comparator/lib.no @@ -0,0 +1,54 @@ +use std::bits; + +/// Checks if `lhs` is less than `rhs` by evaluating the carry bit after addition and subtraction. +/// +/// # Parameters +/// - `LEN`: The assumped bit length of both `lhs` and `rhs`. +/// - `lhs`: The left-hand side `Field` value to be compared. +/// - `rhs`: The right-hand side `Field` value to be compared. +/// +/// # Returns +/// - `Bool`: `true` if `lhs` is less than `rhs`, otherwise `false`. +/// +/// # Proof +/// - Adding `pow2` to `lhs` ensures a carry bit is added to the result, creating a bit array of length `LEN + 1`. +/// - If `lhs < rhs`, then `lhs - rhs < 0`, making `(1 << LEN) + lhs - rhs` less than `1 << LEN`, resulting in a carry bit of `0`. +/// - Otherwise, the carry bit will be `1`. +/// +fn less_than(const LEN: Field, lhs: Field, rhs: Field) -> Bool { + let carry_bit_len = LEN + 1; + + // Calculate 2^LEN using bit shifts. + let mut pow2 = 1; + for ii in 0..LEN { + pow2 = pow2 + pow2; + } + + // Calculate the adjusted sum to determine the carry bit. + let sum = (pow2 + lhs) - rhs; + let sum_bit = bits::to_bits(carry_bit_len, sum); + + let b1 = false; + let b2 = true; + let res = if sum_bit[LEN] { b1 } else { b2 }; + + return res; +} + +/// Checks if `lhs` is less than or equal to `rhs` using the `less_than` function. +/// +/// # Parameters +/// - `LEN`: The assumped bit length of both `lhs` and `rhs`. +/// - `lhs`: The left-hand side `Field` value to be compared. +/// - `rhs`: The right-hand side `Field` value to be compared. +/// +/// # Returns +/// - `Bool`: `true` if `lhs` is less than or equal to `rhs`, otherwise `false`. +/// +/// # Proof +/// By adding 1 to rhs can increase upper bound by 1 for the lhs. +/// Thus, `lhs < lhs + 1` => `lhs <= rhs`. +/// ``` +fn less_eq_than(const LEN: Field, lhs: Field, rhs: Field) -> Bool { + return less_than(LEN, lhs, rhs + 1); +} diff --git a/src/stdlib/native/int/lib.no b/src/stdlib/native/int/lib.no new file mode 100644 index 000000000..4ad93c3b1 --- /dev/null +++ b/src/stdlib/native/int/lib.no @@ -0,0 +1,100 @@ +use std::bits; +use std::comparator; + +// u8 +struct Uint8 { + inner: Field, +} + +fn Uint8.new(val: Field) -> Uint8 { + let bit_len = 8; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint8 { + inner: val + }; +} + +// u16 +struct Uint16 { + inner: Field +} + +fn Uint16.new(val: Field) -> Uint16 { + let bit_len = 16; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint16 { + inner: val + }; +} + +// u32 +struct Uint32 { + inner: Field +} + +fn Uint32.new(val: Field) -> Uint32 { + let bit_len = 32; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint32 { + inner: val + }; +} + +// u64 +struct Uint64 { + inner: Field +} + +fn Uint64.new(val: Field) -> Uint64 { + let bit_len = 64; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint64 { + inner: val + }; +} + +// implement comparator + +fn Uint8.less_than(self, rhs: Uint8) -> Bool { + return comparator::less_than(8, self.inner, rhs.inner); +} + +fn Uint8.less_eq_than(self, rhs: Uint8) -> Bool { + return comparator::less_eq_than(8, self.inner, rhs.inner); +} + +fn Uint16.less_than(self, rhs: Uint16) -> Bool { + return comparator::less_than(16, self.inner, rhs.inner); +} + +fn Uint16.less_eq_than(self, rhs: Uint16) -> Bool { + return comparator::less_eq_than(16, self.inner, rhs.inner); +} + +fn Uint32.less_than(self, rhs: Uint32) -> Bool { + return comparator::less_than(32, self.inner, rhs.inner); +} + +fn Uint32.less_eq_than(self, rhs: Uint32) -> Bool { + return comparator::less_eq_than(32, self.inner, rhs.inner); +} + +fn Uint64.less_than(self, rhs: Uint64) -> Bool { + return comparator::less_than(64, self.inner, rhs.inner); +} + +fn Uint64.less_eq_than(self, rhs: Uint64) -> Bool { + return comparator::less_eq_than(64, self.inner, rhs.inner); +} \ No newline at end of file diff --git a/src/stdlib/native/mimc/lib.no b/src/stdlib/native/mimc/lib.no new file mode 100644 index 000000000..7681f1e4d --- /dev/null +++ b/src/stdlib/native/mimc/lib.no @@ -0,0 +1,153 @@ +/// MIMC hash function using exponentiation with 7. +/// This function allows a maximum of 91 rounds, with each round using exponentiation of 7. +/// +/// # Parameters +/// - `value`: The input value to be hashed. +/// - `key`: The secret key used in the hash computation. +/// +/// # Returns +/// - `Field`: The resulting hash after `ROUNDS` of MIMC operations. +/// +/// # Constraints +/// - `ROUNDS` must be within the range of constants defined in `csts`. +fn mimc7_cipher(value: Field, key: Field) -> Field { + let rounds = 91; + // Initial value: sum of the key and input value. + let init = key + value; + // Variable for the accumulative result of exponentiation with 7. + let mut exp7 = 0; + + // Predefined constants for each round of the hash function. + let csts = [ + 0, + 20888961410941983456478427210666206549300505294776164667214940546594746570981, + 15265126113435022738560151911929040668591755459209400716467504685752745317193, + 8334177627492981984476504167502758309043212251641796197711684499645635709656, + 1374324219480165500871639364801692115397519265181803854177629327624133579404, + 11442588683664344394633565859260176446561886575962616332903193988751292992472, + 2558901189096558760448896669327086721003508630712968559048179091037845349145, + 11189978595292752354820141775598510151189959177917284797737745690127318076389, + 3262966573163560839685415914157855077211340576201936620532175028036746741754, + 17029914891543225301403832095880481731551830725367286980611178737703889171730, + 4614037031668406927330683909387957156531244689520944789503628527855167665518, + 19647356996769918391113967168615123299113119185942498194367262335168397100658, + 5040699236106090655289931820723926657076483236860546282406111821875672148900, + 2632385916954580941368956176626336146806721642583847728103570779270161510514, + 17691411851977575435597871505860208507285462834710151833948561098560743654671, + 11482807709115676646560379017491661435505951727793345550942389701970904563183, + 8360838254132998143349158726141014535383109403565779450210746881879715734773, + 12663821244032248511491386323242575231591777785787269938928497649288048289525, + 3067001377342968891237590775929219083706800062321980129409398033259904188058, + 8536471869378957766675292398190944925664113548202769136103887479787957959589, + 19825444354178182240559170937204690272111734703605805530888940813160705385792, + 16703465144013840124940690347975638755097486902749048533167980887413919317592, + 13061236261277650370863439564453267964462486225679643020432589226741411380501, + 10864774797625152707517901967943775867717907803542223029967000416969007792571, + 10035653564014594269791753415727486340557376923045841607746250017541686319774, + 3446968588058668564420958894889124905706353937375068998436129414772610003289, + 4653317306466493184743870159523234588955994456998076243468148492375236846006, + 8486711143589723036499933521576871883500223198263343024003617825616410932026, + 250710584458582618659378487568129931785810765264752039738223488321597070280, + 2104159799604932521291371026105311735948154964200596636974609406977292675173, + 16313562605837709339799839901240652934758303521543693857533755376563489378839, + 6032365105133504724925793806318578936233045029919447519826248813478479197288, + 14025118133847866722315446277964222215118620050302054655768867040006542798474, + 7400123822125662712777833064081316757896757785777291653271747396958201309118, + 1744432620323851751204287974553233986555641872755053103823939564833813704825, + 8316378125659383262515151597439205374263247719876250938893842106722210729522, + 6739722627047123650704294650168547689199576889424317598327664349670094847386, + 21211457866117465531949733809706514799713333930924902519246949506964470524162, + 13718112532745211817410303291774369209520657938741992779396229864894885156527, + 5264534817993325015357427094323255342713527811596856940387954546330728068658, + 18884137497114307927425084003812022333609937761793387700010402412840002189451, + 5148596049900083984813839872929010525572543381981952060869301611018636120248, + 19799686398774806587970184652860783461860993790013219899147141137827718662674, + 19240878651604412704364448729659032944342952609050243268894572835672205984837, + 10546185249390392695582524554167530669949955276893453512788278945742408153192, + 5507959600969845538113649209272736011390582494851145043668969080335346810411, + 18177751737739153338153217698774510185696788019377850245260475034576050820091, + 19603444733183990109492724100282114612026332366576932662794133334264283907557, + 10548274686824425401349248282213580046351514091431715597441736281987273193140, + 1823201861560942974198127384034483127920205835821334101215923769688644479957, + 11867589662193422187545516240823411225342068709600734253659804646934346124945, + 18718569356736340558616379408444812528964066420519677106145092918482774343613, + 10530777752259630125564678480897857853807637120039176813174150229243735996839, + 20486583726592018813337145844457018474256372770211860618687961310422228379031, + 12690713110714036569415168795200156516217175005650145422920562694422306200486, + 17386427286863519095301372413760745749282643730629659997153085139065756667205, + 2216432659854733047132347621569505613620980842043977268828076165669557467682, + 6309765381643925252238633914530877025934201680691496500372265330505506717193, + 20806323192073945401862788605803131761175139076694468214027227878952047793390, + 4037040458505567977365391535756875199663510397600316887746139396052445718861, + 19948974083684238245321361840704327952464170097132407924861169241740046562673, + 845322671528508199439318170916419179535949348988022948153107378280175750024, + 16222384601744433420585982239113457177459602187868460608565289920306145389382, + 10232118865851112229330353999139005145127746617219324244541194256766741433339, + 6699067738555349409504843460654299019000594109597429103342076743347235369120, + 6220784880752427143725783746407285094967584864656399181815603544365010379208, + 6129250029437675212264306655559561251995722990149771051304736001195288083309, + 10773245783118750721454994239248013870822765715268323522295722350908043393604, + 4490242021765793917495398271905043433053432245571325177153467194570741607167, + 19596995117319480189066041930051006586888908165330319666010398892494684778526, + 837850695495734270707668553360118467905109360511302468085569220634750561083, + 11803922811376367215191737026157445294481406304781326649717082177394185903907, + 10201298324909697255105265958780781450978049256931478989759448189112393506592, + 13564695482314888817576351063608519127702411536552857463682060761575100923924, + 9262808208636973454201420823766139682381973240743541030659775288508921362724, + 173271062536305557219323722062711383294158572562695717740068656098441040230, + 18120430890549410286417591505529104700901943324772175772035648111937818237369, + 20484495168135072493552514219686101965206843697794133766912991150184337935627, + 19155651295705203459475805213866664350848604323501251939850063308319753686505, + 11971299749478202793661982361798418342615500543489781306376058267926437157297, + 18285310723116790056148596536349375622245669010373674803854111592441823052978, + 7069216248902547653615508023941692395371990416048967468982099270925308100727, + 6465151453746412132599596984628739550147379072443683076388208843341824127379, + 16143532858389170960690347742477978826830511669766530042104134302796355145785, + 19362583304414853660976404410208489566967618125972377176980367224623492419647, + 1702213613534733786921602839210290505213503664731919006932367875629005980493, + 10781825404476535814285389902565833897646945212027592373510689209734812292327, + 4212716923652881254737947578600828255798948993302968210248673545442808456151, + 7594017890037021425366623750593200398174488805473151513558919864633711506220, + 18979889247746272055963929241596362599320706910852082477600815822482192194401, + 13602139229813231349386885113156901793661719180900395818909719758150455500533, + ]; + + // Iterate through each round to compute the hash. + for round in 0..rounds { + // Calculate intermediate values based on the round. + let exp1_else = (key + exp7) + csts[round]; + let exp1 = if round == 0 { init } else { exp1_else }; + let exp2 = exp1 * exp1; + let exp4 = exp2 * exp2; + let exp6 = exp4 * exp2; + let exp7_then = exp6 * exp1; + let exp7_else = exp7_then + key; + // Update exp7 based on whether it's the last round. + exp7 = if round != (rounds - 1) { exp7_then } else { exp7_else }; + } + + // Return the final hash value. + return exp7; +} + +/// MIMC hash function for multiple values. +/// Uses the `mimc7_cipher` function iteratively to hash an array of values. +/// +/// # Parameters +/// - `values`: An array of `Field` values to be hashed. +/// - `key`: The secret key used in the hash computation. +/// +/// # Returns +/// - `Field`: The resulting hash after processing all input values. +/// +fn mimc7_hash(values: [Field; LEN], key: Field) -> Field { + // Initialize with the key. + let mut res = key; + // Iterate over each value in the input array. + for value in values { + // Update the result with the MIMC hash of the value and the current result. + res = res + (value + mimc7_cipher(value, res)); + } + // Return the final accumulated result. + return res; +} diff --git a/src/stdlib/native/multiplexer/lib.no b/src/stdlib/native/multiplexer/lib.no new file mode 100644 index 000000000..54e4b5a11 --- /dev/null +++ b/src/stdlib/native/multiplexer/lib.no @@ -0,0 +1,105 @@ +use std::comparator; + +/// Multiplies two vectors of the same length and returns the accumulated sum (dot product). +/// +/// # Parameters +/// - `lhs`: A vector (array) of `Field` elements. +/// - `rhs`: A vector (array) of `Field` elements. +/// +/// # Returns +/// - `Field`: The accumulated sum resulting from the element-wise multiplication of `lhs` and `rhs`. +/// +/// # Panics +/// - The function assumes that `lhs` and `rhs` have the same length, `LEN`. +/// +/// # Example +/// ``` +/// let lhs = [1, 2, 3]; +/// let rhs = [4, 5, 6]; +/// let result = escalar_product(lhs, rhs); +/// result should be 1*4 + 2*5 + 3*6 = 32 +/// ``` +fn escalar_product(lhs: [Field; LEN], rhs: [Field; LEN]) -> Field { + let mut lc = 0; + for idx in 0..LEN { + lc = lc + (lhs[idx] * rhs[idx]); + } + return lc; +} + +/// Generates a selector array of a given length `LEN` with all zeros except for a one at the specified `target_idx`. +/// +/// # Parameters +/// - `LEN`: The length of the output array. +/// - `target_idx`: The index where the value should be 1. The rest of the array will be filled with zeros. +/// +/// # Returns +/// - `[Field; LEN]`: An array of length `LEN` where all elements are zero except for a single `1` at `target_idx`. +/// +/// # Panics +/// - This function asserts that there is exactly one `1` in the generated array, ensuring `target_idx` is within bounds. +/// +/// # Example +/// ``` +/// let selector = gen_selector_arr(5, 2); +/// `selector` should be [0, 0, 1, 0, 0] +/// ``` +fn gen_selector_arr(const LEN: Field, target_idx: Field) -> [Field; LEN] { + let mut selector = [0; LEN]; + let mut lc = 0; + let one = 1; + let zero = 0; + + for idx in 0..LEN { + selector[idx] = if idx == target_idx { one } else { zero }; + lc = lc + selector[idx]; + } + + // Ensures there is exactly one '1' in the range of LEN. + assert(lc == 1); + + return selector; +} + +/// Selects an element from a 2D array based on a `target_idx` and returns a vector of length `WIDLEN`. +/// +/// # Parameters +/// - `arr`: A 2D array of dimensions `[ARRLEN][WIDLEN]` containing `Field` elements. +/// - `target_idx`: The index that determines which row of `arr` to select. +/// +/// # Returns +/// - `[Field; WIDLEN]`: A vector representing the selected row from `arr`. +/// +/// # Algorithm +/// 1. Generate a selector array using `gen_selector_arr` that has a `1` at `target_idx` and `0`s elsewhere. +/// 2. For each column index `idx` of the 2D array: +/// - Extract the `idx`-th element from each row into a temporary array. +/// - Use `escalar_product` with the temporary array and the selector array to `select` the value corresponding to `target_idx`. +/// 3. Reset the temporary array for the next iteration. +/// 4. Return the vector containing the selected row. +/// +/// # Example +/// ``` +/// let arr = [[1, 2], [3, 4], [5, 6]]; +/// let result = select_element(arr, 1); +/// `result` should be [3, 4] as it selects the second row (index 1). +/// ``` +fn select_element(arr: [[Field; WIDLEN]; ARRLEN], target_idx: Field) -> [Field; WIDLEN] { + let mut result = [0; WIDLEN]; + + let selector_arr = gen_selector_arr(ARRLEN, target_idx); + let mut one_len_arr = [0; ARRLEN]; + + for idx in 0..WIDLEN { + for jdx in 0..ARRLEN { + one_len_arr[jdx] = arr[jdx][idx]; + } + // Only one element in `selector_arr` is `1`, so the result is the element in `one_len_arr` + // at the same index as the `1` in `selector_arr`. + result[idx] = escalar_product(one_len_arr, selector_arr); + + // Reset the temporary array for the next column. + one_len_arr = [0; ARRLEN]; + } + return result; +} diff --git a/src/tests/examples.rs b/src/tests/examples.rs index 3aade9de0..dde1093ef 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -10,6 +10,7 @@ use crate::{ }, compiler::{compile, typecheck_next_file, Sources}, inputs::{parse_inputs, ExtField}, + stdlib::{init_stdlib_dep, STDLIB_DIRECTORY}, type_checker::TypeChecker, }; @@ -36,6 +37,8 @@ fn test_file( // compile let mut sources = Sources::new(); let mut tast = TypeChecker::new(); + let mut node_id = 0; + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, STDLIB_DIRECTORY); let this_module = None; let _node_id = typecheck_next_file( &mut tast, @@ -43,7 +46,7 @@ fn test_file( &mut sources, file_name.to_string(), code.clone(), - 0, + node_id, ) .unwrap(); @@ -98,6 +101,8 @@ fn test_file( // compile let mut sources = Sources::new(); let mut tast = TypeChecker::new(); + let mut node_id = 0; + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, STDLIB_DIRECTORY); let this_module = None; let _node_id = typecheck_next_file( &mut tast, diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 3d2702bcb..3036fa55d 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,2 +1,3 @@ mod examples; mod modules; +mod stdlib; diff --git a/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm b/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm new file mode 100644 index 000000000..32f492c19 --- /dev/null +++ b/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm @@ -0,0 +1,20 @@ +@ noname.0.7.0 +@ public inputs: 2 + +1 == (v_5) * (1) +v_7 == (v_6) * (-1 * v_4 + v_5) +-1 * v_8 + 1 == (v_7) * (1) +v_9 == (v_8) * (-1 * v_4 + v_5) +0 == (v_9) * (1) +1 == (v_11) * (1) +v_13 == (v_12) * (-1 * v_10 + v_11) +-1 * v_14 + 1 == (v_13) * (1) +v_15 == (v_14) * (-1 * v_10 + v_11) +0 == (v_15) * (1) +1 == (v_17) * (1) +v_19 == (v_18) * (-1 * v_16 + v_17) +-1 * v_20 + 1 == (v_19) * (1) +v_21 == (v_20) * (-1 * v_16 + v_17) +0 == (v_21) * (1) +v_2 + -1 * v_3 + 3 == (v_8 + 2 * v_14 + 4 * v_20) * (1) +-1 * v_20 + 1 == (v_1) * (1) diff --git a/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no b/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no new file mode 100644 index 000000000..f916a96a2 --- /dev/null +++ b/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no @@ -0,0 +1,8 @@ +use std::comparator; + +fn main(pub lhs: Field, rhs: Field) -> Bool { + let bit_len = 2; + let res = comparator::less_eq_than(bit_len, lhs, rhs); + + return res; +} \ No newline at end of file diff --git a/src/tests/stdlib/comparator/less_than/less_than.asm b/src/tests/stdlib/comparator/less_than/less_than.asm new file mode 100644 index 000000000..cc615ad37 --- /dev/null +++ b/src/tests/stdlib/comparator/less_than/less_than.asm @@ -0,0 +1,20 @@ +@ noname.0.7.0 +@ public inputs: 2 + +1 == (v_5) * (1) +v_7 == (v_6) * (-1 * v_4 + v_5) +-1 * v_8 + 1 == (v_7) * (1) +v_9 == (v_8) * (-1 * v_4 + v_5) +0 == (v_9) * (1) +1 == (v_11) * (1) +v_13 == (v_12) * (-1 * v_10 + v_11) +-1 * v_14 + 1 == (v_13) * (1) +v_15 == (v_14) * (-1 * v_10 + v_11) +0 == (v_15) * (1) +1 == (v_17) * (1) +v_19 == (v_18) * (-1 * v_16 + v_17) +-1 * v_20 + 1 == (v_19) * (1) +v_21 == (v_20) * (-1 * v_16 + v_17) +0 == (v_21) * (1) +v_2 + -1 * v_3 + 4 == (v_8 + 2 * v_14 + 4 * v_20) * (1) +-1 * v_20 + 1 == (v_1) * (1) diff --git a/src/tests/stdlib/comparator/less_than/less_than_main.no b/src/tests/stdlib/comparator/less_than/less_than_main.no new file mode 100644 index 000000000..e39d9cf95 --- /dev/null +++ b/src/tests/stdlib/comparator/less_than/less_than_main.no @@ -0,0 +1,8 @@ +use std::comparator; + +fn main(pub lhs: Field, rhs: Field) -> Bool { + let bit_len = 2; + let res = comparator::less_than(bit_len, lhs, rhs); + + return res; +} \ No newline at end of file diff --git a/src/tests/stdlib/comparator/mod.rs b/src/tests/stdlib/comparator/mod.rs new file mode 100644 index 000000000..ed4ce618c --- /dev/null +++ b/src/tests/stdlib/comparator/mod.rs @@ -0,0 +1,177 @@ +use crate::error::{self, ErrorKind}; + +use super::{test_stdlib, test_stdlib_code}; +use error::Result; +use rstest::rstest; + +// code template +static LESS_THAN_TPL: &str = r#" +use std::comparator; +use std::int; + +fn main(pub lhs: Field, rhs: Field) -> Bool { + let lhs_u = int::{}.new(lhs); + let rhs_u = int::{}.new(rhs); + + let res = lhs_u.less_than(rhs_u); + + return res; +} +"#; + +static LESS_THAN_EQ_TPL: &str = r#" +use std::comparator; +use std::int; + +fn main(pub lhs: Field, rhs: Field) -> Bool { + let lhs_u = int::{}.new(lhs); + let rhs_u = int::{}.new(rhs); + + let res = lhs_u.less_eq_than(rhs_u); + + return res; +} +"#; + +#[rstest] +#[case(r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case(r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] +fn test_less_than( + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + test_stdlib( + "comparator/less_than/less_than_main.no", + Some("comparator/less_than/less_than.asm"), + public_inputs, + private_inputs, + expected_output, + )?; + + Ok(()) +} + +#[test] +fn test_less_than_witness_failure() -> Result<()> { + let public_inputs = r#"{"lhs": "4"}"#; + let private_inputs = r#"{"rhs": "0"}"#; + + let err = test_stdlib( + "comparator/less_than/less_than_main.no", + None, + public_inputs, + private_inputs, + vec![], + ) + .err() + .expect("expected witness error"); + + assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); + + Ok(()) +} + +#[rstest] +#[case("Uint8", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint16", r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] +#[case("Uint32", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint64", r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] +fn test_uint_less_than( + #[case] int_type: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + // Replace placeholders with the given integer type. + let code = LESS_THAN_TPL.replace("{}", int_type); + + // Call the test function with the given inputs and expected output. + test_stdlib_code(&code, None, public_inputs, private_inputs, expected_output)?; + + Ok(()) +} + +#[rstest] +#[case("Uint8", r#"{"lhs": "256"}"#, r#"{"rhs": "0"}"#)] // Uint8 overflow +#[case("Uint16", r#"{"lhs": "65536"}"#, r#"{"rhs": "0"}"#)] // Uint16 overflow +#[case("Uint32", r#"{"lhs": "4294967296"}"#, r#"{"rhs": "0"}"#)] // Uint32 overflow +#[case("Uint64", r#"{"lhs": "18446744073709551616"}"#, r#"{"rhs": "0"}"#)] // Uint64 overflow +fn test_uint_less_than_range_failure( + #[case] int_type: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, +) -> Result<()> { + let code = LESS_THAN_TPL.replace("{}", int_type); + + // Test that the provided inputs result in an error due to overflow. + let err = test_stdlib_code(&code, None, public_inputs, private_inputs, vec!["0"]) + .err() + .expect("expected witness error"); + + assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); + + Ok(()) +} + +// Test for less than or equal scenarios +#[rstest] +#[case(r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] // True case (lhs < rhs) +#[case(r#"{"lhs": "1"}"#, r#"{"rhs": "1"}"#, vec!["1"])] // True case (lhs == rhs) +#[case(r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] // False case +fn test_less_eq_than( + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + test_stdlib( + "comparator/less_eq_than/less_eq_than_main.no", + Some("comparator/less_eq_than/less_eq_than.asm"), + public_inputs, + private_inputs, + expected_output, + )?; + + Ok(()) +} + +// implement the rest for less than eq + +#[rstest] +#[case("Uint8", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint16", r#"{"lhs": "1"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint32", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint64", r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] +fn test_uint_less_eq_than( + #[case] int_type: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + let code = LESS_THAN_EQ_TPL.replace("{}", int_type); + + test_stdlib_code(&code, None, public_inputs, private_inputs, expected_output)?; + + Ok(()) +} + +#[rstest] +#[case("Uint8", r#"{"lhs": "256"}"#, r#"{"rhs": "0"}"#)] // Uint8 overflow +#[case("Uint16", r#"{"lhs": "65536"}"#, r#"{"rhs": "0"}"#)] // Uint16 overflow +#[case("Uint32", r#"{"lhs": "4294967296"}"#, r#"{"rhs": "0"}"#)] // Uint32 overflow +#[case("Uint64", r#"{"lhs": "18446744073709551616"}"#, r#"{"rhs": "0"}"#)] // Uint64 overflow +fn test_uint_less_eq_than_range_failure( + #[case] int_type: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, +) -> Result<()> { + let code = LESS_THAN_EQ_TPL.replace("{}", int_type); + + let err = test_stdlib_code(&code, None, public_inputs, private_inputs, vec!["0"]) + .err() + .expect("expected witness error"); + + assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); + + Ok(()) +} diff --git a/src/tests/stdlib/mimc/mimc_main.no b/src/tests/stdlib/mimc/mimc_main.no new file mode 100644 index 000000000..1e104d1f8 --- /dev/null +++ b/src/tests/stdlib/mimc/mimc_main.no @@ -0,0 +1,7 @@ +use std::mimc; + +fn main(pub key: Field, val: Field) -> Field { + let res = mimc::mimc7_cipher(val, key); + + return res; +} \ No newline at end of file diff --git a/src/tests/stdlib/mimc/mod.rs b/src/tests/stdlib/mimc/mod.rs new file mode 100644 index 000000000..cf11059be --- /dev/null +++ b/src/tests/stdlib/mimc/mod.rs @@ -0,0 +1,207 @@ +use crate::{ + backends::r1cs::R1csBn254Field, + error::{self}, +}; + +use super::test_stdlib; +use ark_ff::Field; +use error::Result; +use num_bigint::BigUint; +use num_traits::Zero; +use rstest::rstest; +use std::str::FromStr; + +/// Parses a decimal string into an Fq field element. +fn fq_from_str(s: &str) -> R1csBn254Field { + R1csBn254Field::from_str(s).unwrap() +} + +/// MiMC7 hash function implementation. +fn mimc7(x_in: R1csBn254Field, k: R1csBn254Field, n_rounds: usize) -> R1csBn254Field { + // Round constants c[91] + let c_strings = [ + "0", + "20888961410941983456478427210666206549300505294776164667214940546594746570981", + "15265126113435022738560151911929040668591755459209400716467504685752745317193", + "8334177627492981984476504167502758309043212251641796197711684499645635709656", + "1374324219480165500871639364801692115397519265181803854177629327624133579404", + "11442588683664344394633565859260176446561886575962616332903193988751292992472", + "2558901189096558760448896669327086721003508630712968559048179091037845349145", + "11189978595292752354820141775598510151189959177917284797737745690127318076389", + "3262966573163560839685415914157855077211340576201936620532175028036746741754", + "17029914891543225301403832095880481731551830725367286980611178737703889171730", + "4614037031668406927330683909387957156531244689520944789503628527855167665518", + "19647356996769918391113967168615123299113119185942498194367262335168397100658", + "5040699236106090655289931820723926657076483236860546282406111821875672148900", + "2632385916954580941368956176626336146806721642583847728103570779270161510514", + "17691411851977575435597871505860208507285462834710151833948561098560743654671", + "11482807709115676646560379017491661435505951727793345550942389701970904563183", + "8360838254132998143349158726141014535383109403565779450210746881879715734773", + "12663821244032248511491386323242575231591777785787269938928497649288048289525", + "3067001377342968891237590775929219083706800062321980129409398033259904188058", + "8536471869378957766675292398190944925664113548202769136103887479787957959589", + "19825444354178182240559170937204690272111734703605805530888940813160705385792", + "16703465144013840124940690347975638755097486902749048533167980887413919317592", + "13061236261277650370863439564453267964462486225679643020432589226741411380501", + "10864774797625152707517901967943775867717907803542223029967000416969007792571", + "10035653564014594269791753415727486340557376923045841607746250017541686319774", + "3446968588058668564420958894889124905706353937375068998436129414772610003289", + "4653317306466493184743870159523234588955994456998076243468148492375236846006", + "8486711143589723036499933521576871883500223198263343024003617825616410932026", + "250710584458582618659378487568129931785810765264752039738223488321597070280", + "2104159799604932521291371026105311735948154964200596636974609406977292675173", + "16313562605837709339799839901240652934758303521543693857533755376563489378839", + "6032365105133504724925793806318578936233045029919447519826248813478479197288", + "14025118133847866722315446277964222215118620050302054655768867040006542798474", + "7400123822125662712777833064081316757896757785777291653271747396958201309118", + "1744432620323851751204287974553233986555641872755053103823939564833813704825", + "8316378125659383262515151597439205374263247719876250938893842106722210729522", + "6739722627047123650704294650168547689199576889424317598327664349670094847386", + "21211457866117465531949733809706514799713333930924902519246949506964470524162", + "13718112532745211817410303291774369209520657938741992779396229864894885156527", + "5264534817993325015357427094323255342713527811596856940387954546330728068658", + "18884137497114307927425084003812022333609937761793387700010402412840002189451", + "5148596049900083984813839872929010525572543381981952060869301611018636120248", + "19799686398774806587970184652860783461860993790013219899147141137827718662674", + "19240878651604412704364448729659032944342952609050243268894572835672205984837", + "10546185249390392695582524554167530669949955276893453512788278945742408153192", + "5507959600969845538113649209272736011390582494851145043668969080335346810411", + "18177751737739153338153217698774510185696788019377850245260475034576050820091", + "19603444733183990109492724100282114612026332366576932662794133334264283907557", + "10548274686824425401349248282213580046351514091431715597441736281987273193140", + "1823201861560942974198127384034483127920205835821334101215923769688644479957", + "11867589662193422187545516240823411225342068709600734253659804646934346124945", + "18718569356736340558616379408444812528964066420519677106145092918482774343613", + "10530777752259630125564678480897857853807637120039176813174150229243735996839", + "20486583726592018813337145844457018474256372770211860618687961310422228379031", + "12690713110714036569415168795200156516217175005650145422920562694422306200486", + "17386427286863519095301372413760745749282643730629659997153085139065756667205", + "2216432659854733047132347621569505613620980842043977268828076165669557467682", + "6309765381643925252238633914530877025934201680691496500372265330505506717193", + "20806323192073945401862788605803131761175139076694468214027227878952047793390", + "4037040458505567977365391535756875199663510397600316887746139396052445718861", + "19948974083684238245321361840704327952464170097132407924861169241740046562673", + "845322671528508199439318170916419179535949348988022948153107378280175750024", + "16222384601744433420585982239113457177459602187868460608565289920306145389382", + "10232118865851112229330353999139005145127746617219324244541194256766741433339", + "6699067738555349409504843460654299019000594109597429103342076743347235369120", + "6220784880752427143725783746407285094967584864656399181815603544365010379208", + "6129250029437675212264306655559561251995722990149771051304736001195288083309", + "10773245783118750721454994239248013870822765715268323522295722350908043393604", + "4490242021765793917495398271905043433053432245571325177153467194570741607167", + "19596995117319480189066041930051006586888908165330319666010398892494684778526", + "837850695495734270707668553360118467905109360511302468085569220634750561083", + "11803922811376367215191737026157445294481406304781326649717082177394185903907", + "10201298324909697255105265958780781450978049256931478989759448189112393506592", + "13564695482314888817576351063608519127702411536552857463682060761575100923924", + "9262808208636973454201420823766139682381973240743541030659775288508921362724", + "173271062536305557219323722062711383294158572562695717740068656098441040230", + "18120430890549410286417591505529104700901943324772175772035648111937818237369", + "20484495168135072493552514219686101965206843697794133766912991150184337935627", + "19155651295705203459475805213866664350848604323501251939850063308319753686505", + "11971299749478202793661982361798418342615500543489781306376058267926437157297", + "18285310723116790056148596536349375622245669010373674803854111592441823052978", + "7069216248902547653615508023941692395371990416048967468982099270925308100727", + "6465151453746412132599596984628739550147379072443683076388208843341824127379", + "16143532858389170960690347742477978826830511669766530042104134302796355145785", + "19362583304414853660976404410208489566967618125972377176980367224623492419647", + "1702213613534733786921602839210290505213503664731919006932367875629005980493", + "10781825404476535814285389902565833897646945212027592373510689209734812292327", + "4212716923652881254737947578600828255798948993302968210248673545442808456151", + "7594017890037021425366623750593200398174488805473151513558919864633711506220", + "18979889247746272055963929241596362599320706910852082477600815822482192194401", + "13602139229813231349386885113156901793661719180900395818909719758150455500533", + ]; + + // Convert constants to field elements + let c: Vec = c_strings + .iter() + .map(|&s| R1csBn254Field::from_str(s).unwrap()) + .collect(); + + let mut t7 = Vec::with_capacity(n_rounds - 1); + let mut out = R1csBn254Field::zero(); + + for i in 0..n_rounds { + let t = if i == 0 { + k + x_in + } else { + k + t7[i - 1] + c[i] + }; + + let t2 = t.square(); // t^2 + let t4 = t2.square(); // t^4 + let t6 = t4 * t2; // t^6 + + if i < n_rounds - 1 { + let t7_i = t6 * t; // t^7 + t7.push(t7_i); + } else { + out = t6 * t + k; // Final output: t^7 + k + } + } + + out +} + +fn multi_mimc7(k: R1csBn254Field, values: Vec, n_rounds: usize) -> R1csBn254Field { + let mut res = k; + for x in values { + res = res + x + mimc7(x, res, n_rounds); + } + res +} + +#[rstest] +#[case(0, 1, 91)] +fn test_mimc(#[case] key: u32, #[case] val: u32, #[case] n_rounds: usize) -> Result<()> { + let public_inputs = format!(r#"{{"key": "{}"}}"#, key); + let private_inputs = format!(r#"{{"val": "{}"}}"#, val); + + let x = fq_from_str(val.to_string().as_str()); + let k = fq_from_str(key.to_string().as_str()); + + let expected_output: BigUint = mimc7(x, k, n_rounds).into(); + + test_stdlib( + "mimc/mimc_main.no", + None, + &public_inputs, + &private_inputs, + vec![&expected_output.to_string()], + )?; + + Ok(()) +} + +#[rstest] +#[case(0, vec![1, 2, 3])] +fn test_multi_mimc(#[case] key: u32, #[case] values: Vec) -> Result<()> { + let k = fq_from_str(key.to_string().as_str()); + let x = values + .iter() + .map(|v| fq_from_str(v.to_string().as_str())) + .collect(); + + let expected_output: BigUint = multi_mimc7(k, x, 91).into(); + + let public_inputs = format!(r#"{{"key": "{}"}}"#, key); + // convert to ["1", "2", ...] + let private_inputs = format!( + r#"{{"values": {:?}}}"#, + values + .iter() + .map(|v| v.to_string()) + .collect::>() + ); + + test_stdlib( + "mimc/multi_mimc_main.no", + None, + &public_inputs, + &private_inputs, + vec![&expected_output.to_string()], + )?; + + Ok(()) +} diff --git a/src/tests/stdlib/mimc/multi_mimc_main.no b/src/tests/stdlib/mimc/multi_mimc_main.no new file mode 100644 index 000000000..73f56d104 --- /dev/null +++ b/src/tests/stdlib/mimc/multi_mimc_main.no @@ -0,0 +1,7 @@ +use std::mimc; + +fn main(pub key: Field, values: [Field; 3]) -> Field { + let res = mimc::mimc7_hash(values, key); + + return res; +} \ No newline at end of file diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs new file mode 100644 index 000000000..3bac61a6f --- /dev/null +++ b/src/tests/stdlib/mod.rs @@ -0,0 +1,116 @@ +mod comparator; +mod mimc; +mod multiplexer; + +use std::{path::Path, str::FromStr}; + +use crate::{ + backends::r1cs::{R1csBn254Field, R1CS}, + circuit_writer::CircuitWriter, + compiler::{typecheck_next_file, Sources}, + error::Result, + inputs::parse_inputs, + mast, + stdlib::{init_stdlib_dep, STDLIB_DIRECTORY}, + type_checker::TypeChecker, + witness::CompiledCircuit, +}; + +fn test_stdlib( + path: &str, + asm_path: Option<&str>, + public_inputs: &str, + private_inputs: &str, + expected_public_output: Vec<&str>, +) -> Result>> { + let root = env!("CARGO_MANIFEST_DIR"); + let prefix_path = Path::new(root).join("src/tests/stdlib"); + + // read noname file + let code = std::fs::read_to_string(prefix_path.clone().join(path)).unwrap(); + + let compiled_circuit = test_stdlib_code( + &code, + asm_path, + public_inputs, + private_inputs, + expected_public_output, + )?; + + Ok(compiled_circuit) +} + +fn test_stdlib_code( + code: &str, + asm_path: Option<&str>, + public_inputs: &str, + private_inputs: &str, + expected_public_output: Vec<&str>, +) -> Result>> { + let r1cs = R1CS::new(); + let root = env!("CARGO_MANIFEST_DIR"); + + // parse inputs + let public_inputs = parse_inputs(public_inputs).unwrap(); + let private_inputs = parse_inputs(private_inputs).unwrap(); + + // compile + let mut sources = Sources::new(); + let mut tast = TypeChecker::new(); + let mut node_id = 0; + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, STDLIB_DIRECTORY); + + let this_module = None; + let _node_id = typecheck_next_file( + &mut tast, + this_module, + &mut sources, + "test.no".to_string(), + code.to_string(), + node_id, + ) + .unwrap(); + + let mast = mast::monomorphize(tast)?; + let compiled_circuit = CircuitWriter::generate_circuit(mast, r1cs)?; + + // this should check the constraints + let generated_witness = + compiled_circuit.generate_witness(public_inputs.clone(), private_inputs.clone())?; + + let expected_public_output = expected_public_output + .iter() + .map(|x| crate::backends::r1cs::R1csBn254Field::from_str(x).unwrap()) + .collect::>(); + + if generated_witness.outputs != expected_public_output { + eprintln!("obtained by executing the circuit:"); + generated_witness + .outputs + .iter() + .for_each(|x| eprintln!("- {x}")); + eprintln!("passed as output by the verifier:"); + expected_public_output + .iter() + .for_each(|x| eprintln!("- {x}")); + panic!("Obtained output does not match expected output"); + } + + // check the ASM + if asm_path.is_some() && compiled_circuit.circuit.backend.num_constraints() < 100 { + let prefix_asm = Path::new(root).join("src/tests/stdlib/"); + let expected_asm = + std::fs::read_to_string(prefix_asm.clone().join(asm_path.unwrap())).unwrap(); + let obtained_asm = compiled_circuit.asm(&Sources::new(), false); + + if obtained_asm != expected_asm { + eprintln!("obtained:"); + eprintln!("{obtained_asm}"); + eprintln!("expected:"); + eprintln!("{expected_asm}"); + panic!("Obtained ASM does not match expected ASM"); + } + } + + Ok(compiled_circuit) +} diff --git a/src/tests/stdlib/multiplexer/mod.rs b/src/tests/stdlib/multiplexer/mod.rs new file mode 100644 index 000000000..cd921dd1d --- /dev/null +++ b/src/tests/stdlib/multiplexer/mod.rs @@ -0,0 +1,48 @@ +use crate::error::{self}; + +use super::test_stdlib; +use error::Result; +use rstest::rstest; + +#[rstest] +#[case(r#"{"xx": [["0", "1", "2"], ["3", "4", "5"], ["6", "7", "8"]]}"#, r#"{"sel": "1"}"#, vec!["3", "4", "5"])] +fn test_in_range( + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + test_stdlib( + "multiplexer/select_element/main.no", + Some("multiplexer/select_element/main.asm"), + public_inputs, + private_inputs, + expected_output, + )?; + + Ok(()) +} + +// require the select idx to be in range +#[rstest] +#[case(r#"{"xx": [["0", "1", "2"], ["3", "4", "5"], ["6", "7", "8"]]}"#, r#"{"sel": "3"}"#, vec![])] +fn test_out_range( + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + use crate::error::ErrorKind; + + let err = test_stdlib( + "multiplexer/select_element/main.no", + Some("multiplexer/select_element/main.asm"), + public_inputs, + private_inputs, + expected_output, + ) + .err() + .expect("Expected error"); + + assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); + + Ok(()) +} diff --git a/src/tests/stdlib/multiplexer/select_element/main.asm b/src/tests/stdlib/multiplexer/select_element/main.asm new file mode 100644 index 000000000..c2fca003d --- /dev/null +++ b/src/tests/stdlib/multiplexer/select_element/main.asm @@ -0,0 +1,36 @@ +@ noname.0.7.0 +@ public inputs: 12 + +0 == (v_14) * (1) +v_16 == (v_15) * (v_13 + -1 * v_14) +-1 * v_17 + 1 == (v_16) * (1) +v_18 == (v_17) * (v_13 + -1 * v_14) +0 == (v_18) * (1) +1 == (v_19) * (1) +v_21 == (v_20) * (v_13 + -1 * v_19) +-1 * v_22 + 1 == (v_21) * (1) +v_23 == (v_22) * (v_13 + -1 * v_19) +0 == (v_23) * (1) +2 == (v_24) * (1) +v_26 == (v_25) * (v_13 + -1 * v_24) +-1 * v_27 + 1 == (v_26) * (1) +v_28 == (v_27) * (v_13 + -1 * v_24) +0 == (v_28) * (1) +1 == (v_29) * (1) +v_31 == (v_30) * (-1 * v_17 + -1 * v_22 + -1 * v_27 + v_29) +-1 * v_32 + 1 == (v_31) * (1) +v_33 == (v_32) * (-1 * v_17 + -1 * v_22 + -1 * v_27 + v_29) +0 == (v_33) * (1) +1 == (v_32) * (1) +v_34 == (v_4) * (v_17) +v_35 == (v_7) * (v_22) +v_36 == (v_10) * (v_27) +v_37 == (v_5) * (v_17) +v_38 == (v_8) * (v_22) +v_39 == (v_11) * (v_27) +v_40 == (v_6) * (v_17) +v_41 == (v_9) * (v_22) +v_42 == (v_12) * (v_27) +v_34 + v_35 + v_36 == (v_1) * (1) +v_37 + v_38 + v_39 == (v_2) * (1) +v_40 + v_41 + v_42 == (v_3) * (1) diff --git a/src/tests/stdlib/multiplexer/select_element/main.no b/src/tests/stdlib/multiplexer/select_element/main.no new file mode 100644 index 000000000..b06ae625c --- /dev/null +++ b/src/tests/stdlib/multiplexer/select_element/main.no @@ -0,0 +1,6 @@ +use std::multiplexer; + +fn main(pub xx: [[Field; 3]; 3], sel: Field) -> [Field; 3] { + let chosen_elements = multiplexer::select_element(xx, sel); + return chosen_elements; +} \ No newline at end of file diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 75900ec5b..cb69360c9 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use num_traits::ToPrimitive; use serde::{Deserialize, Serialize}; use crate::{ @@ -252,6 +253,7 @@ impl TypeChecker { .compute_type(lhs, typed_fn_env)? .expect("type-checker bug: lhs access on an empty var"); + // todo: check and update the const field type for other cases // lhs can be a local variable or a path to an array let lhs_name = match &lhs.kind { // `name = ` @@ -318,6 +320,31 @@ impl TypeChecker { )); } + // update struct field type + if let ExprKind::FieldAccess { + lhs, + rhs: field_name, + } = &lhs.kind + { + // get variable behind lhs + let lhs_node = self + .compute_type(lhs, typed_fn_env)? + .expect("type-checker bug: lhs access on an empty var"); + + // obtain the qualified name of the struct + let (module, struct_name) = match lhs_node.typ { + TyKind::Custom { module, name } => (module, name), + _ => { + return Err( + self.error(ErrorKind::FieldAccessOnNonCustomStruct, lhs.span) + ) + } + }; + + let qualified = FullyQualified::new(&module, &struct_name); + self.update_struct_field(&qualified, &field_name.value, rhs_typ.typ); + } + None } @@ -566,6 +593,12 @@ impl TypeChecker { expr.span, )); } + + // If the observed type is a Field type, then init that struct field as the observed type. + // This is because the field type can be a constant or not, which needs to be propagated. + if matches!(observed_typ.typ, TyKind::Field { .. }) { + self.update_struct_field(&qualified, &defined.0, observed_typ.typ.clone()); + } } let res = ExprTyInfo::new_anon(TyKind::Custom { @@ -587,7 +620,10 @@ impl TypeChecker { let sym = Symbolic::parse(size)?; let res = if let Symbolic::Concrete(size) = sym { // if sym is a concrete variant, then just return concrete array type - ExprTyInfo::new_anon(TyKind::Array(Box::new(item_node.typ), size)) + ExprTyInfo::new_anon(TyKind::Array( + Box::new(item_node.typ), + size.to_u32().expect("array size too large"), + )) } else { // use generic array as the size node might include generic parameters or constant vars ExprTyInfo::new_anon(TyKind::GenericSizedArray( diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index 50308a376..3aef5dcd3 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -150,6 +150,29 @@ impl TypeChecker { .expect("couldn't find the struct for storing the method"); struct_info.methods.remove(method_name); } + + /// Update the type of a struct field. + /// When the assignment is done, we need to update the type of the field. + /// This is only for the case of updating field types to either a constant or a variable. + pub fn update_struct_field( + &mut self, + qualified: &FullyQualified, + field_name: &str, + typ: TyKind, + ) { + let struct_info = self + .structs + .get_mut(qualified) + .expect("couldn't find the struct for storing the method"); + + // update the field type + for field in struct_info.fields.iter_mut() { + if field.0 == field_name { + field.1 = typ; + return; + } + } + } } impl TypeChecker {