From 68ccc89d348ce5687349e2b8567bc0b4e9399245 Mon Sep 17 00:00:00 2001 From: Folkert Date: Sat, 8 Feb 2020 16:04:00 +0100 Subject: [PATCH 1/3] instantiate rigids in nested annotations --- src/constrain/expr.rs | 28 ++++++++++++++++++++++--- src/types/mod.rs | 5 ++++- tests/test_infer.rs | 48 +++++++++++++++++++++++++++++++++++++++---- 3 files changed, 73 insertions(+), 8 deletions(-) diff --git a/src/constrain/expr.rs b/src/constrain/expr.rs index 00e9477f0d..c0c7545830 100644 --- a/src/constrain/expr.rs +++ b/src/constrain/expr.rs @@ -720,13 +720,15 @@ pub fn constrain_def(env: &Env, def: &Def, body_con: Constraint) -> Constraint { let expr_con = match &def.annotation { Some((annotation, free_vars)) => { + let mut annotation = annotation.clone(); let rigids = &env.rigids; let mut ftv: ImMap = rigids.clone(); + let mut rigid_substitution: ImMap = ImMap::default(); for (var, name) in free_vars { - // if the rigid is known already, nothing needs to happen - // otherwise register it. - if !rigids.contains_key(name) { + if let Some(existing_rigid) = rigids.get(name) { + rigid_substitution.insert(*var, existing_rigid.clone()); + } else { // It's possible to use this rigid in nested defs ftv.insert(name.clone(), Type::Variable(*var)); @@ -734,6 +736,19 @@ pub fn constrain_def(env: &Env, def: &Def, body_con: Constraint) -> Constraint { } } + // Instantiate rigid variables + if !rigid_substitution.is_empty() { + annotation.substitute(&rigid_substitution); + } + + // TODO also do this for more complex patterns + if let Pattern::Identifier(symbol) = def.loc_pattern.value { + pattern_state.headers.insert( + symbol, + Located::at(def.loc_pattern.region, annotation.clone()), + ); + } + let annotation_expected = FromAnnotation( def.loc_pattern.clone(), annotation.arity(), @@ -843,6 +858,13 @@ pub fn rec_defs_help( } Some((annotation, seen_rigids)) => { + // TODO also do this for more complex patterns + if let Pattern::Identifier(symbol) = def.loc_pattern.value { + pattern_state.headers.insert( + symbol, + Located::at(def.loc_pattern.region, annotation.clone()), + ); + } let rigids = &env.rigids; let mut ftv: ImMap = rigids.clone(); diff --git a/src/types/mod.rs b/src/types/mod.rs index e25e8be5f1..ad3ae58b9f 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -296,7 +296,10 @@ impl Type { } ext.substitute(substitutions); } - Alias(_, _zipped, actual_type) => { + Alias(_, zipped, actual_type) => { + for (_, value) in zipped.iter_mut() { + value.substitute(substitutions); + } actual_type.substitute(substitutions); } Apply(_, args) => { diff --git a/tests/test_infer.rs b/tests/test_infer.rs index 9c75767056..e7de2ddbf0 100644 --- a/tests/test_infer.rs +++ b/tests/test_infer.rs @@ -58,10 +58,8 @@ mod test_infer { if !problems.is_empty() { // fail with an assert, but print the problems normally so rust doesn't try to diff // an empty vec with the problems. - panic!( - "PROBLEMS\n{:?}\nexpected:\n{:?}\ninferred:\n{:?}", - problems, expected, actual - ); + dbg!(&problems); + panic!("expected:\n{:?}\ninferred:\n{:?}", expected, actual); } assert_eq!(actual, expected.to_string()); } @@ -1675,4 +1673,46 @@ mod test_infer { "{ x : [ Attr [ Shared ]* Str ]*, y : [ Attr [ Shared ]* Str ]* }", ); } + + #[test] + fn rigid_in_let() { + infer_eq_without_problem( + indoc!( + r#" + List q : [ Cons q (List q), Nil ] + + toEmpty : List a -> List a + toEmpty = \_ -> + result : List a + result = Nil + + result + + toEmpty + "# + ), + "List a -> List a", + ); + } + + // #[test] + // fn rigid_in_let_pattern() { + // infer_eq_without_problem( + // indoc!( + // r#" + // List q : [ Cons q (List q), Nil ] + // + // { x, y } : { x : List a, y : List b } + // { x, y } = + // result : List a + // result = Nil + // + // { x : result, y : Nil } + // + // 43 + // "# + // ), + // "List a -> List a", + // ); + // } } From 2790e3717ad46a44cb8e5e9022a69f65291b99af Mon Sep 17 00:00:00 2001 From: Folkert Date: Sat, 8 Feb 2020 16:44:48 +0100 Subject: [PATCH 2/3] extract symbols from non-identifier patterns Basically, if the pattern was not just an Identifier, the defined symbols were not defined --- src/can/def.rs | 51 ++++++++++++++++++++++++++++++++----------- src/can/pattern.rs | 23 +++++++++++++++++-- src/constrain/expr.rs | 1 + src/solve.rs | 1 + tests/test_infer.rs | 31 +++++++++++++++++++------- 5 files changed, 84 insertions(+), 23 deletions(-) diff --git a/src/can/def.rs b/src/can/def.rs index 0647182ca0..bbd21b3ff4 100644 --- a/src/can/def.rs +++ b/src/can/def.rs @@ -7,9 +7,7 @@ use crate::can::expr::{ }; use crate::can::ident::{Ident, Lowercase}; use crate::can::pattern::PatternType; -use crate::can::pattern::{ - bindings_from_patterns, canonicalize_pattern, symbols_from_pattern, Pattern, -}; +use crate::can::pattern::{bindings_from_patterns, canonicalize_pattern, Pattern}; use crate::can::problem::Problem; use crate::can::problem::RuntimeError; use crate::can::procedure::References; @@ -616,13 +614,42 @@ fn group_to_declaration( } } +fn pattern_to_vars_by_symbol( + vars_by_symbol: &mut SendMap, + pattern: &Pattern, + expr_var: Variable, +) { + use Pattern::*; + match pattern { + Identifier(symbol) => { + vars_by_symbol.insert(symbol.clone(), expr_var); + } + + AppliedTag(_, _, arguments) => { + for (var, nested) in arguments { + pattern_to_vars_by_symbol(vars_by_symbol, &nested.value, *var); + } + } + + RecordDestructure(_, destructs) => { + for destruct in destructs { + vars_by_symbol.insert(destruct.value.symbol.clone(), destruct.value.var); + } + } + + IntLiteral(_) | FloatLiteral(_) | StrLiteral(_) | Underscore | UnsupportedPattern(_) => {} + + Shadowed(_, _) => {} + } +} + // TODO trim down these arguments! #[allow(clippy::too_many_arguments)] fn canonicalize_pending_def<'a>( env: &mut Env<'a>, found_rigids: &mut SendMap, pending_def: PendingDef<'a>, - original_scope: &Scope, + _original_scope: &Scope, scope: &mut Scope, can_defs_by_symbol: &mut MutMap, var_store: &VarStore, @@ -646,6 +673,8 @@ fn canonicalize_pending_def<'a>( found_rigids.insert(k, v); } + pattern_to_vars_by_symbol(&mut vars_by_symbol, &loc_can_pattern.value, expr_var); + let typ = ann.typ; let arity = typ.arity(); @@ -693,9 +722,8 @@ fn canonicalize_pending_def<'a>( } }; - for (ident, (symbol, _)) in scope.idents() { - // TODO Could we do this by symbol instead, to avoid cloning idents? - if original_scope.contains_ident(ident) { + for (_, (symbol, _)) in scope.idents() { + if !vars_by_symbol.contains_key(&symbol) { continue; } @@ -750,9 +778,10 @@ fn canonicalize_pending_def<'a>( if let Pattern::Identifier(ref defined_symbol) = &loc_can_pattern.value { env.tailcallable_symbol = Some(*defined_symbol); - vars_by_symbol.insert(defined_symbol.clone(), expr_var); }; + pattern_to_vars_by_symbol(&mut vars_by_symbol, &loc_can_pattern.value, expr_var); + let (mut loc_can_expr, can_output) = canonicalize_expr(env, var_store, scope, loc_expr.region, &loc_expr.value); @@ -818,14 +847,10 @@ fn canonicalize_pending_def<'a>( ); } - let symbols_defined_here: ImSet = symbols_from_pattern(&loc_can_pattern.value) - .into_iter() - .collect(); - // Store the referenced locals in the refs_by_symbol map, so we can later figure out // which defined names reference each other. for (ident, (symbol, region)) in scope.idents() { - if !symbols_defined_here.contains(&symbol) { + if !vars_by_symbol.contains_key(&symbol) { continue; } diff --git a/src/can/pattern.rs b/src/can/pattern.rs index 2af037daf1..d958102772 100644 --- a/src/can/pattern.rs +++ b/src/can/pattern.rs @@ -43,8 +43,27 @@ pub fn symbols_from_pattern(pattern: &Pattern) -> Vec { } pub fn symbols_from_pattern_help(pattern: &Pattern, symbols: &mut Vec) { - if let Pattern::Identifier(symbol) = pattern { - symbols.push(symbol.clone()); + use Pattern::*; + + match pattern { + Identifier(symbol) => { + symbols.push(symbol.clone()); + } + + AppliedTag(_, _, arguments) => { + for (_, nested) in arguments { + symbols_from_pattern_help(&nested.value, symbols); + } + } + RecordDestructure(_, destructs) => { + for destruct in destructs { + symbols.push(destruct.value.symbol.clone()); + } + } + + IntLiteral(_) | FloatLiteral(_) | StrLiteral(_) | Underscore | UnsupportedPattern(_) => {} + + Shadowed(_, _) => {} } } diff --git a/src/constrain/expr.rs b/src/constrain/expr.rs index c0c7545830..471af1f3fe 100644 --- a/src/constrain/expr.rs +++ b/src/constrain/expr.rs @@ -61,6 +61,7 @@ pub fn constrain_expr( expr: &Expr, expected: Expected, ) -> Constraint { + dbg!(&expr); match expr { Int(var, _) => int_literal(*var, expected, region), Float(var, _) => float_literal(*var, expected, region), diff --git a/src/solve.rs b/src/solve.rs index 86dad89151..6102baf9fa 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -240,6 +240,7 @@ pub fn run( mut subs: Subs, constraint: &Constraint, ) -> (Solved, Env) { + dbg!(&constraint); let mut pools = Pools::default(); let state = State { vars_by_symbol: vars_by_symbol.clone(), diff --git a/tests/test_infer.rs b/tests/test_infer.rs index e7de2ddbf0..b8caca0ee2 100644 --- a/tests/test_infer.rs +++ b/tests/test_infer.rs @@ -1695,22 +1695,37 @@ mod test_infer { ); } + #[test] + fn let_record_pattern_with_annotation() { + infer_eq_without_problem( + indoc!( + r#" + { x, y } : { x : Str.Str, y : Num.Num Float.FloatingPoint } + { x, y } = { x : "foo", y : 3.14 } + + x + "# + ), + "Str", + ); + } + // #[test] // fn rigid_in_let_pattern() { // infer_eq_without_problem( // indoc!( // r#" - // List q : [ Cons q (List q), Nil ] + // List q : [ Cons q (List q), Nil ] // - // { x, y } : { x : List a, y : List b } - // { x, y } = - // result : List a - // result = Nil + // { x, y } : { x : List a, y : List b } + // { x, y } = + // result : List a + // result = Nil // - // { x : result, y : Nil } + // { x : result, y : Nil } // - // 43 - // "# + // 43 + // "# // ), // "List a -> List a", // ); From ecd451e84bec1dcf3ba9ef0927f7e47080cf22fa Mon Sep 17 00:00:00 2001 From: Folkert Date: Sat, 8 Feb 2020 21:12:33 +0100 Subject: [PATCH 3/3] turn annotations into headers --- src/constrain/expr.rs | 20 +++++----- src/constrain/pattern.rs | 86 ++++++++++++++++++++++++++++++++++++++++ src/solve.rs | 1 - src/types/mod.rs | 8 ++++ tests/test_infer.rs | 35 ++++++++++------ 5 files changed, 128 insertions(+), 22 deletions(-) diff --git a/src/constrain/expr.rs b/src/constrain/expr.rs index 471af1f3fe..2c9fef40bf 100644 --- a/src/constrain/expr.rs +++ b/src/constrain/expr.rs @@ -61,7 +61,6 @@ pub fn constrain_expr( expr: &Expr, expected: Expected, ) -> Constraint { - dbg!(&expr); match expr { Int(var, _) => int_literal(*var, expected, region), Float(var, _) => float_literal(*var, expected, region), @@ -742,19 +741,22 @@ pub fn constrain_def(env: &Env, def: &Def, body_con: Constraint) -> Constraint { annotation.substitute(&rigid_substitution); } - // TODO also do this for more complex patterns - if let Pattern::Identifier(symbol) = def.loc_pattern.value { - pattern_state.headers.insert( - symbol, - Located::at(def.loc_pattern.region, annotation.clone()), - ); + let arity = annotation.arity(); + + if let Some(headers) = crate::constrain::pattern::headers_from_annotation( + &def.loc_pattern.value, + &Located::at(def.loc_pattern.region, annotation.clone()), + ) { + for (k, v) in headers { + pattern_state.headers.insert(k, v); + } } let annotation_expected = FromAnnotation( def.loc_pattern.clone(), - annotation.arity(), + arity, AnnotationSource::TypedBody, - annotation.clone(), + annotation, ); pattern_state.constraints.push(Eq( diff --git a/src/constrain/pattern.rs b/src/constrain/pattern.rs index 060c5f4960..4878758c56 100644 --- a/src/constrain/pattern.rs +++ b/src/constrain/pattern.rs @@ -13,6 +13,92 @@ pub struct PatternState { pub constraints: Vec, } +/// If there is a type annotation, the pattern state headers can be optimized by putting the +/// annotation in the headers. Normally +/// +/// x = 4 +/// +/// Would add `x => <42>` to the headers (i.e., symbol points to a type variable). If the +/// definition has an annotation, we instead now add `x => Int`. +pub fn headers_from_annotation( + pattern: &Pattern, + annotation: &Located, +) -> Option>> { + let mut headers = SendMap::default(); + // Check that the annotation structurally agrees with the pattern, preventing e.g. `{ x, y } : Int` + // in such incorrect cases we don't put the full annotation in headers, just a variable, and let + // inference generate a proper error. + let is_structurally_valid = headers_from_annotation_help(pattern, annotation, &mut headers); + + if is_structurally_valid { + Some(headers) + } else { + None + } +} + +pub fn headers_from_annotation_help( + pattern: &Pattern, + annotation: &Located, + headers: &mut SendMap>, +) -> bool { + match pattern { + Identifier(symbol) => { + headers.insert(symbol.clone(), annotation.clone()); + true + } + Underscore + | Shadowed(_, _) + | UnsupportedPattern(_) + | IntLiteral(_) + | FloatLiteral(_) + | StrLiteral(_) => true, + + RecordDestructure(_, destructs) => match annotation.value.shallow_dealias() { + Type::Record(fields, _) => { + for destruct in destructs { + // NOTE ignores the .guard field. + if let Some(field_type) = fields.get(&destruct.value.label) { + headers.insert( + destruct.value.symbol.clone(), + Located::at(annotation.region, field_type.clone()), + ); + } else { + return false; + } + } + true + } + Type::EmptyRec => destructs.is_empty(), + _ => false, + }, + + AppliedTag(_, tag_name, arguments) => match annotation.value.shallow_dealias() { + Type::TagUnion(tags, _) => { + if let Some((_, arg_types)) = tags.iter().find(|(name, _)| name == tag_name) { + if !arguments.len() == arg_types.len() { + return false; + } + + arguments + .iter() + .zip(arg_types.iter()) + .all(|(arg_pattern, arg_type)| { + headers_from_annotation_help( + &arg_pattern.1.value, + &Located::at(annotation.region, arg_type.clone()), + headers, + ) + }) + } else { + false + } + } + _ => false, + }, + } +} + /// This accepts PatternState (rather than returning it) so that the caller can /// intiialize the Vecs in PatternState using with_capacity /// based on its knowledge of their lengths. diff --git a/src/solve.rs b/src/solve.rs index 6102baf9fa..86dad89151 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -240,7 +240,6 @@ pub fn run( mut subs: Subs, constraint: &Constraint, ) -> (Solved, Env) { - dbg!(&constraint); let mut pools = Pools::default(); let state = State { vars_by_symbol: vars_by_symbol.clone(), diff --git a/src/types/mod.rs b/src/types/mod.rs index ad3ae58b9f..d6db1c2b03 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -386,6 +386,14 @@ impl Type { EmptyRec | EmptyTagUnion | Erroneous(_) | Variable(_) | Boolean(_) => false, } } + + /// a shallow dealias, continue until the first constructor is not an alias. + pub fn shallow_dealias(&self) -> &Self { + match self { + Type::Alias(_, _, actual) => actual.shallow_dealias(), + _ => self, + } + } } fn variables_help(tipe: &Type, accum: &mut ImSet) { diff --git a/tests/test_infer.rs b/tests/test_infer.rs index b8caca0ee2..24ac556bb5 100644 --- a/tests/test_infer.rs +++ b/tests/test_infer.rs @@ -1710,24 +1710,35 @@ mod test_infer { ); } + #[test] + fn let_record_pattern_with_alias_annotation() { + infer_eq_without_problem( + indoc!( + r#" + Foo : { x : Str.Str, y : Num.Num Float.FloatingPoint } + + { x, y } : Foo + { x, y } = { x : "foo", y : 3.14 } + + x + "# + ), + "Str", + ); + } + // #[test] - // fn rigid_in_let_pattern() { + // fn let_tag_pattern_with_annotation() { // infer_eq_without_problem( // indoc!( // r#" - // List q : [ Cons q (List q), Nil ] - // - // { x, y } : { x : List a, y : List b } - // { x, y } = - // result : List a - // result = Nil - // - // { x : result, y : Nil } + // UserId x : [ UserId Int ] + // UserId x = UserId 42 // - // 43 - // "# + // x + // "# // ), - // "List a -> List a", + // "Int", // ); // } }