Skip to content

Commit

Permalink
Store references to captured variables. (#1554)
Browse files Browse the repository at this point in the history
  • Loading branch information
chriseth committed Jul 10, 2024
1 parent a2a7c26 commit fbf2b40
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 7 deletions.
1 change: 1 addition & 0 deletions asm-to-pil/src/vm_to_constrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ impl<T: FieldElement> VMConverter<T> {
}
.into(),
),
outer_var_references: Default::default(),
}
.into();

Expand Down
5 changes: 5 additions & 0 deletions ast/src/parsed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,11 @@ pub struct LambdaExpression<E = Expression<NamespacedPolynomialReference>> {
pub kind: FunctionKind,
pub params: Vec<Pattern>,
pub body: Box<E>,
/// The IDs of the variables outside the functions that are referenced,
/// i.e. the environment that is captured by the closure.
/// This is filled in by the expression processor.
#[schemars(skip)]
pub outer_var_references: BTreeSet<u64>,
}

impl<Ref> From<LambdaExpression<Expression<Ref>>> for Expression<Ref> {
Expand Down
1 change: 1 addition & 0 deletions importer/src/path_canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@ fn check_expression(
kind: _,
params,
body,
outer_var_references: _,
},
) => {
// Add the local variables, ignore collisions.
Expand Down
21 changes: 17 additions & 4 deletions parser/src/powdr.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,13 @@ PolynomialConstantDefinition: PilStatement = {
}

FunctionDefinition: FunctionDefinition = {
<start:@L> "(" <params:ParameterList> ")" <body:BlockExpression> <end:@R> => FunctionDefinition::Expression(Expression::LambdaExpression(ctx.source_ref(start, end), LambdaExpression{kind: FunctionKind::Pure, params, body})),
<start:@L> "(" <params:ParameterList> ")" <body:BlockExpression> <end:@R>
=> FunctionDefinition::Expression(Expression::LambdaExpression(ctx.source_ref(start, end), LambdaExpression{
kind: FunctionKind::Pure,
params,
body,
outer_var_references: Default::default()
})),
<start:@L> "=" <array:ArrayLiteralExpression> <end:@R> => FunctionDefinition::Array(array),
}

Expand All @@ -185,7 +191,12 @@ PolynomialCommitDeclaration: PilStatement = {
ctx.source_ref(start, end),
stage,
vec![name],
Some(FunctionDefinition::Expression(Expression::LambdaExpression(ctx.source_ref(start, end), LambdaExpression{kind: FunctionKind::Query, params, body})))
Some(FunctionDefinition::Expression(Expression::LambdaExpression(ctx.source_ref(start, end), LambdaExpression{
kind: FunctionKind::Query,
params,
body,
outer_var_references: Default::default()
})))
)
}

Expand Down Expand Up @@ -448,8 +459,10 @@ BoxedExpression: Box<Expression> = {
}

LambdaExpression: Box<Expression> = {
<start:@L> <kind:FunctionKind> "||" <body:BoxedExpression> <end:@R> => ctx.to_expr_with_source_ref(LambdaExpression{kind, params: vec![], body}, start, end),
<start:@L> <kind:FunctionKind> "|" <params:ParameterList> "|" <body:BoxedExpression> <end:@R> => ctx.to_expr_with_source_ref(LambdaExpression{kind, params, body}, start, end),
<start:@L> <kind:FunctionKind> "||" <body:BoxedExpression> <end:@R>
=> ctx.to_expr_with_source_ref(LambdaExpression{kind, params: vec![], body, outer_var_references: Default::default()}, start, end),
<start:@L> <kind:FunctionKind> "|" <params:ParameterList> "|" <body:BoxedExpression> <end:@R>
=> ctx.to_expr_with_source_ref(LambdaExpression{kind, params, body, outer_var_references: Default::default()}, start, end),
LogicalOr
}

Expand Down
25 changes: 23 additions & 2 deletions pil-analyzer/src/expression_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ pub struct ExpressionProcessor<'a, D: AnalysisDriver> {
type_vars: &'a HashSet<&'a String>,
local_variables: HashMap<String, u64>,
local_variable_counter: u64,
/// Tracks references to local variables to record them for closures.
local_var_references: HashSet<u64>,
}

impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> {
Expand All @@ -34,6 +36,7 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> {
type_vars,
local_variables: Default::default(),
local_variable_counter: 0,
local_var_references: Default::default(),
}
}

Expand Down Expand Up @@ -291,6 +294,7 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> {
match reference.try_to_identifier() {
Some(name) if self.local_variables.contains_key(name) => {
let id = self.local_variables[name];
self.local_var_references.insert(id);
Reference::LocalVar(id, name.to_string())
}
_ => Reference::Poly(self.process_namespaced_polynomial_reference(reference)),
Expand All @@ -299,9 +303,16 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> {

pub fn process_lambda_expression(
&mut self,
LambdaExpression { kind, params, body }: LambdaExpression,
LambdaExpression {
kind,
params,
body,
outer_var_references: _,
}: LambdaExpression,
) -> LambdaExpression<Expression> {
let previous_local_vars = self.save_local_variables();
let previous_local_var_refs = self.local_var_references.clone();
let local_variable_height = self.local_variable_counter;

let params = params
.into_iter()
Expand All @@ -315,8 +326,18 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> {
}
let body = Box::new(self.process_expression(*body));

let outer_var_references =
std::mem::replace(&mut self.local_var_references, previous_local_var_refs)
.into_iter()
.filter(|id| *id < local_variable_height)
.collect();
self.reset_local_variables(previous_local_vars);
LambdaExpression { kind, params, body }
LambdaExpression {
kind,
params,
body,
outer_var_references,
}
}

fn process_block_expression(
Expand Down
1 change: 1 addition & 0 deletions pil-analyzer/src/side_effect_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ impl<'a> SideEffectChecker<'a> {
kind,
params: _,
body,
outer_var_references: _,
},
) => {
if *kind != FunctionKind::Pure && *kind != self.context {
Expand Down
10 changes: 9 additions & 1 deletion pil-analyzer/src/type_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,15 @@ impl TypeChecker {
.map(|item| self.infer_type_of_expression(item))
.collect::<Result<_, _>>()?,
}),
Expression::LambdaExpression(_, LambdaExpression { kind, params, body }) => {
Expression::LambdaExpression(
_,
LambdaExpression {
kind,
params,
body,
outer_var_references: _,
},
) => {
let old_len = self.local_var_types.len();
let result = params
.iter()
Expand Down
59 changes: 59 additions & 0 deletions pil-analyzer/tests/processor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use std::collections::BTreeSet;

use powdr_ast::{
analyzed::{Analyzed, Expression, FunctionValueDefinition},
parsed::{LambdaExpression, TypedExpression},
};
use powdr_number::GoldilocksField;

use powdr_pil_analyzer::analyze_string;

use pretty_assertions::assert_eq;

fn extract_expression<'a, T>(analyzed: &'a Analyzed<T>, name: &str) -> &'a Expression {
match analyzed.definitions[name].1.as_ref().unwrap() {
FunctionValueDefinition::Expression(TypedExpression { e, .. }) => e,
_ => panic!(),
}
}

fn outer_vars_of_lambda(expr: &Expression) -> &BTreeSet<u64> {
match expr {
Expression::LambdaExpression(
_,
LambdaExpression {
outer_var_references,
..
},
) => outer_var_references,
_ => panic!(),
}
}

#[test]
fn determine_outer_var_refs() {
let input = "
let f: int, int -> int = |i, _| i;
let g: int, int -> int = |_, i| i;
let h: int -> (int -> (int, int)) = |i| |j| (f(i, j), g(i, j));
let k: int, int -> (int -> (int, int)) = |k, i| |j| (f(i, j), g(i, j));
";

let analyzed = analyze_string::<GoldilocksField>(input);
assert!(outer_vars_of_lambda(extract_expression(&analyzed, "f")).is_empty());
assert!(outer_vars_of_lambda(extract_expression(&analyzed, "g")).is_empty());
let h = extract_expression(&analyzed, "h");
assert!(outer_vars_of_lambda(h).is_empty());
let h_body = match h {
Expression::LambdaExpression(_, LambdaExpression { body, .. }) => body,
_ => panic!(),
};
assert_eq!(outer_vars_of_lambda(h_body), &[0].into_iter().collect());
let k = extract_expression(&analyzed, "k");
assert!(outer_vars_of_lambda(k).is_empty());
let k_body = match k {
Expression::LambdaExpression(_, LambdaExpression { body, .. }) => body,
_ => panic!(),
};
assert_eq!(outer_vars_of_lambda(k_body), &[1].into_iter().collect());
}

0 comments on commit fbf2b40

Please sign in to comment.