diff --git a/src/can/def.rs b/src/can/def.rs index 0647182ca0..bbd21b3ff4 100644 --- a/src/can/def.rs +++ b/src/can/def.rs @@ -7,9 +7,7 @@ use crate::can::expr::{ }; use crate::can::ident::{Ident, Lowercase}; use crate::can::pattern::PatternType; -use crate::can::pattern::{ - bindings_from_patterns, canonicalize_pattern, symbols_from_pattern, Pattern, -}; +use crate::can::pattern::{bindings_from_patterns, canonicalize_pattern, Pattern}; use crate::can::problem::Problem; use crate::can::problem::RuntimeError; use crate::can::procedure::References; @@ -616,13 +614,42 @@ fn group_to_declaration( } } +fn pattern_to_vars_by_symbol( + vars_by_symbol: &mut SendMap, + pattern: &Pattern, + expr_var: Variable, +) { + use Pattern::*; + match pattern { + Identifier(symbol) => { + vars_by_symbol.insert(symbol.clone(), expr_var); + } + + AppliedTag(_, _, arguments) => { + for (var, nested) in arguments { + pattern_to_vars_by_symbol(vars_by_symbol, &nested.value, *var); + } + } + + RecordDestructure(_, destructs) => { + for destruct in destructs { + vars_by_symbol.insert(destruct.value.symbol.clone(), destruct.value.var); + } + } + + IntLiteral(_) | FloatLiteral(_) | StrLiteral(_) | Underscore | UnsupportedPattern(_) => {} + + Shadowed(_, _) => {} + } +} + // TODO trim down these arguments! #[allow(clippy::too_many_arguments)] fn canonicalize_pending_def<'a>( env: &mut Env<'a>, found_rigids: &mut SendMap, pending_def: PendingDef<'a>, - original_scope: &Scope, + _original_scope: &Scope, scope: &mut Scope, can_defs_by_symbol: &mut MutMap, var_store: &VarStore, @@ -646,6 +673,8 @@ fn canonicalize_pending_def<'a>( found_rigids.insert(k, v); } + pattern_to_vars_by_symbol(&mut vars_by_symbol, &loc_can_pattern.value, expr_var); + let typ = ann.typ; let arity = typ.arity(); @@ -693,9 +722,8 @@ fn canonicalize_pending_def<'a>( } }; - for (ident, (symbol, _)) in scope.idents() { - // TODO Could we do this by symbol instead, to avoid cloning idents? - if original_scope.contains_ident(ident) { + for (_, (symbol, _)) in scope.idents() { + if !vars_by_symbol.contains_key(&symbol) { continue; } @@ -750,9 +778,10 @@ fn canonicalize_pending_def<'a>( if let Pattern::Identifier(ref defined_symbol) = &loc_can_pattern.value { env.tailcallable_symbol = Some(*defined_symbol); - vars_by_symbol.insert(defined_symbol.clone(), expr_var); }; + pattern_to_vars_by_symbol(&mut vars_by_symbol, &loc_can_pattern.value, expr_var); + let (mut loc_can_expr, can_output) = canonicalize_expr(env, var_store, scope, loc_expr.region, &loc_expr.value); @@ -818,14 +847,10 @@ fn canonicalize_pending_def<'a>( ); } - let symbols_defined_here: ImSet = symbols_from_pattern(&loc_can_pattern.value) - .into_iter() - .collect(); - // Store the referenced locals in the refs_by_symbol map, so we can later figure out // which defined names reference each other. for (ident, (symbol, region)) in scope.idents() { - if !symbols_defined_here.contains(&symbol) { + if !vars_by_symbol.contains_key(&symbol) { continue; } diff --git a/src/can/pattern.rs b/src/can/pattern.rs index 2af037daf1..d958102772 100644 --- a/src/can/pattern.rs +++ b/src/can/pattern.rs @@ -43,8 +43,27 @@ pub fn symbols_from_pattern(pattern: &Pattern) -> Vec { } pub fn symbols_from_pattern_help(pattern: &Pattern, symbols: &mut Vec) { - if let Pattern::Identifier(symbol) = pattern { - symbols.push(symbol.clone()); + use Pattern::*; + + match pattern { + Identifier(symbol) => { + symbols.push(symbol.clone()); + } + + AppliedTag(_, _, arguments) => { + for (_, nested) in arguments { + symbols_from_pattern_help(&nested.value, symbols); + } + } + RecordDestructure(_, destructs) => { + for destruct in destructs { + symbols.push(destruct.value.symbol.clone()); + } + } + + IntLiteral(_) | FloatLiteral(_) | StrLiteral(_) | Underscore | UnsupportedPattern(_) => {} + + Shadowed(_, _) => {} } } diff --git a/src/constrain/expr.rs b/src/constrain/expr.rs index 00e9477f0d..2c9fef40bf 100644 --- a/src/constrain/expr.rs +++ b/src/constrain/expr.rs @@ -720,13 +720,15 @@ pub fn constrain_def(env: &Env, def: &Def, body_con: Constraint) -> Constraint { let expr_con = match &def.annotation { Some((annotation, free_vars)) => { + let mut annotation = annotation.clone(); let rigids = &env.rigids; let mut ftv: ImMap = rigids.clone(); + let mut rigid_substitution: ImMap = ImMap::default(); for (var, name) in free_vars { - // if the rigid is known already, nothing needs to happen - // otherwise register it. - if !rigids.contains_key(name) { + if let Some(existing_rigid) = rigids.get(name) { + rigid_substitution.insert(*var, existing_rigid.clone()); + } else { // It's possible to use this rigid in nested defs ftv.insert(name.clone(), Type::Variable(*var)); @@ -734,11 +736,27 @@ pub fn constrain_def(env: &Env, def: &Def, body_con: Constraint) -> Constraint { } } + // Instantiate rigid variables + if !rigid_substitution.is_empty() { + annotation.substitute(&rigid_substitution); + } + + let arity = annotation.arity(); + + if let Some(headers) = crate::constrain::pattern::headers_from_annotation( + &def.loc_pattern.value, + &Located::at(def.loc_pattern.region, annotation.clone()), + ) { + for (k, v) in headers { + pattern_state.headers.insert(k, v); + } + } + let annotation_expected = FromAnnotation( def.loc_pattern.clone(), - annotation.arity(), + arity, AnnotationSource::TypedBody, - annotation.clone(), + annotation, ); pattern_state.constraints.push(Eq( @@ -843,6 +861,13 @@ pub fn rec_defs_help( } Some((annotation, seen_rigids)) => { + // TODO also do this for more complex patterns + if let Pattern::Identifier(symbol) = def.loc_pattern.value { + pattern_state.headers.insert( + symbol, + Located::at(def.loc_pattern.region, annotation.clone()), + ); + } let rigids = &env.rigids; let mut ftv: ImMap = rigids.clone(); diff --git a/src/constrain/pattern.rs b/src/constrain/pattern.rs index 060c5f4960..4878758c56 100644 --- a/src/constrain/pattern.rs +++ b/src/constrain/pattern.rs @@ -13,6 +13,92 @@ pub struct PatternState { pub constraints: Vec, } +/// If there is a type annotation, the pattern state headers can be optimized by putting the +/// annotation in the headers. Normally +/// +/// x = 4 +/// +/// Would add `x => <42>` to the headers (i.e., symbol points to a type variable). If the +/// definition has an annotation, we instead now add `x => Int`. +pub fn headers_from_annotation( + pattern: &Pattern, + annotation: &Located, +) -> Option>> { + let mut headers = SendMap::default(); + // Check that the annotation structurally agrees with the pattern, preventing e.g. `{ x, y } : Int` + // in such incorrect cases we don't put the full annotation in headers, just a variable, and let + // inference generate a proper error. + let is_structurally_valid = headers_from_annotation_help(pattern, annotation, &mut headers); + + if is_structurally_valid { + Some(headers) + } else { + None + } +} + +pub fn headers_from_annotation_help( + pattern: &Pattern, + annotation: &Located, + headers: &mut SendMap>, +) -> bool { + match pattern { + Identifier(symbol) => { + headers.insert(symbol.clone(), annotation.clone()); + true + } + Underscore + | Shadowed(_, _) + | UnsupportedPattern(_) + | IntLiteral(_) + | FloatLiteral(_) + | StrLiteral(_) => true, + + RecordDestructure(_, destructs) => match annotation.value.shallow_dealias() { + Type::Record(fields, _) => { + for destruct in destructs { + // NOTE ignores the .guard field. + if let Some(field_type) = fields.get(&destruct.value.label) { + headers.insert( + destruct.value.symbol.clone(), + Located::at(annotation.region, field_type.clone()), + ); + } else { + return false; + } + } + true + } + Type::EmptyRec => destructs.is_empty(), + _ => false, + }, + + AppliedTag(_, tag_name, arguments) => match annotation.value.shallow_dealias() { + Type::TagUnion(tags, _) => { + if let Some((_, arg_types)) = tags.iter().find(|(name, _)| name == tag_name) { + if !arguments.len() == arg_types.len() { + return false; + } + + arguments + .iter() + .zip(arg_types.iter()) + .all(|(arg_pattern, arg_type)| { + headers_from_annotation_help( + &arg_pattern.1.value, + &Located::at(annotation.region, arg_type.clone()), + headers, + ) + }) + } else { + false + } + } + _ => false, + }, + } +} + /// This accepts PatternState (rather than returning it) so that the caller can /// intiialize the Vecs in PatternState using with_capacity /// based on its knowledge of their lengths. diff --git a/src/types/mod.rs b/src/types/mod.rs index dab1864b12..c7929c5aa2 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -298,7 +298,10 @@ impl Type { } ext.substitute(substitutions); } - Alias(_, _zipped, actual_type) => { + Alias(_, zipped, actual_type) => { + for (_, value) in zipped.iter_mut() { + value.substitute(substitutions); + } actual_type.substitute(substitutions); } Apply(_, args) => { @@ -385,6 +388,14 @@ impl Type { EmptyRec | EmptyTagUnion | Erroneous(_) | Variable(_) | Boolean(_) => false, } } + + /// a shallow dealias, continue until the first constructor is not an alias. + pub fn shallow_dealias(&self) -> &Self { + match self { + Type::Alias(_, _, actual) => actual.shallow_dealias(), + _ => self, + } + } } fn variables_help(tipe: &Type, accum: &mut ImSet) { diff --git a/tests/test_infer.rs b/tests/test_infer.rs index bac954ab74..f71dbe5b37 100644 --- a/tests/test_infer.rs +++ b/tests/test_infer.rs @@ -58,10 +58,8 @@ mod test_infer { if !problems.is_empty() { // fail with an assert, but print the problems normally so rust doesn't try to diff // an empty vec with the problems. - panic!( - "PROBLEMS\n{:?}\nexpected:\n{:?}\ninferred:\n{:?}", - problems, expected, actual - ); + dbg!(&problems); + panic!("expected:\n{:?}\ninferred:\n{:?}", expected, actual); } assert_eq!(actual, expected.to_string()); } @@ -1676,6 +1674,27 @@ mod test_infer { ); } + #[test] + fn rigid_in_let() { + infer_eq_without_problem( + indoc!( + r#" + List q : [ Cons q (List q), Nil ] + + toEmpty : List a -> List a + toEmpty = \_ -> + result : List a + result = Nil + + result + + toEmpty + "# + ), + "List a -> List a", + ); + } + #[test] fn peano_map_alias() { infer_eq( @@ -1698,6 +1717,21 @@ mod test_infer { ); } + #[test] + fn let_record_pattern_with_annotation() { + infer_eq_without_problem( + indoc!( + r#" + { x, y } : { x : Str.Str, y : Num.Num Float.FloatingPoint } + { x, y } = { x : "foo", y : 3.14 } + + x + "# + ), + "Str", + ); + } + #[test] fn peano_map_infer() { infer_eq( @@ -1717,6 +1751,38 @@ mod test_infer { ); } + #[test] + fn let_record_pattern_with_alias_annotation() { + infer_eq_without_problem( + indoc!( + r#" + Foo : { x : Str.Str, y : Num.Num Float.FloatingPoint } + + { x, y } : Foo + { x, y } = { x : "foo", y : 3.14 } + + x + "# + ), + "Str", + ); + } + + // #[test] + // fn let_tag_pattern_with_annotation() { + // infer_eq_without_problem( + // indoc!( + // r#" + // UserId x : [ UserId Int ] + // UserId x = UserId 42 + // + // x + // "# + // ), + // "Int", + // ); + // } + #[test] fn typecheck_record_linked_list_map() { infer_eq_without_problem(