Skip to content

Commit

Permalink
Merge pull request #182 from tweag/feature/improve-row-types-syntax
Browse files Browse the repository at this point in the history
Simplify row types syntax
  • Loading branch information
edolstra committed Nov 13, 2020
2 parents 05ce1dd + f3b9cd6 commit 6a93ad8
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 72 deletions.
6 changes: 3 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,14 @@ pub enum TypecheckError {
/// This is similar to `RowKindMismatch` but occurs in a slightly different situation. Consider a a
/// unification variable `t`, which is a placeholder to be filled by a concrete type later in
/// the typechecking phase. If `t` appears as the tail of a row type, i.e. the type of some
/// expression is inferred to be `{| field: Type | t}`, then `t` must not be unified later with
/// expression is inferred to be `{ field: Type | t}`, then `t` must not be unified later with
/// a type including a different declaration for field, such as `field: Type2`.
///
/// A [constraint](../typecheck/type.RowConstr.html) is added accordingly, and if this
/// constraint is violated (that is if `t` does end up being unified with a type of the form
/// `{| .., field: Type2, .. }`), `RowConflict` is raised. We do not have access to the
/// `{ .., field: Type2, .. }`), `RowConflict` is raised. We do not have access to the
/// original `field: Type` declaration, as opposed to `RowKindMismatch`, which corresponds to the
/// direct failure to unify `{| .. , x: T1, .. }` and `{| .., x: T2, .. }`.
/// direct failure to unify `{ .. , x: T1, .. }` and `{ .., x: T2, .. }`.
RowConflict(
Ident,
/* the second type assignment which violates the constraint */ Option<Types>,
Expand Down
25 changes: 14 additions & 11 deletions src/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ subType : Types = {
<Ident> => Types(AbsType::Var(<>)),
"#" <SpTerm<Atom>> => Types(AbsType::Flat(<>)),
"(" <Types> ")" => <>,
"(" "|" <r:(<Ident> ",")*> <last: (<Ident>)?>"|" <rest: Ident?> ")" =>
r.into_iter()
"<" <rows:(<Ident> ",")*> <last: (<Ident>)?> <tail: ("|" <Ident>)?> ">" => {
let ty = rows.into_iter()
.chain(last.into_iter())
// As we build row types as a linked list via a fold on the original
// iterator, the order of identifiers is reversed. This not a big deal
Expand All @@ -345,16 +345,19 @@ subType : Types = {
.rev()
.fold(
Types(
match rest {
match tail {
Some(id) => AbsType::Var(id),
None => AbsType::RowEmpty(),
}
),
|t, i| Types(AbsType::RowExtend(i, None, Box::new(t)))
),
"{" "|" <r:(<Ident> ":" <Types> ",")*> <last:(<Ident> ":" <Types>)?> "|"
<rest: Ident?> "}" =>
r.into_iter()
);
Types(AbsType::Enum(Box::new(ty)))
},
"{" <rows:(<Ident> ":" <Types> ",")*>
<last:(<Ident> ":" <Types>)?>
<tail: ("|" <Ident>)?> "}" => {
let ty = rows.into_iter()
.chain(last.into_iter())
// As we build row types as a linked list via a fold on the original
// iterator, the order of identifiers is reversed. This not a big deal
Expand All @@ -363,7 +366,7 @@ subType : Types = {
.rev()
.fold(
Types(
match rest {
match tail {
Some(id) => AbsType::Var(id),
None => AbsType::RowEmpty(),
}
Expand All @@ -372,9 +375,9 @@ subType : Types = {
let (i, ty) = i_ty;
Types(AbsType::RowExtend(i, Some(Box::new(ty)), Box::new(t)))
}
),
"<" <subType> ">" => Types(AbsType::Enum(Box::new(<>))),
"{" <subType> "}" => Types(AbsType::StaticRecord(Box::new(<>))),
);
Types(AbsType::StaticRecord(Box::new(ty)))
},
"{" "_" ":" <Types> "}" => Types(AbsType::DynRecord(Box::new(<>))),
};

Expand Down
22 changes: 11 additions & 11 deletions src/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ Assume(#alwaysTrue -> #alwaysFalse, not ) true
#[test]
fn safe_id() {
let res = eval_string(
"let id = Assume(forall a . a -> a, fun x => x) in
"let id = Assume(forall a. a -> a, fun x => x) in
id false",
);
assert_eq!(Ok(Term::Bool(false)), res);
Expand All @@ -631,7 +631,7 @@ Assume(#alwaysTrue -> #alwaysFalse, not ) true
#[test]
fn id_fail() {
let res = eval_string(
"let id = Assume(forall a . a -> a, fun x => false) in
"let id = Assume(forall a. a -> a, fun x => false) in
id false",
);
if let Ok(_) = res {
Expand All @@ -642,7 +642,7 @@ Assume(#alwaysTrue -> #alwaysFalse, not ) true
#[test]
fn safe_higher_order() {
let res = eval_string(
"let to_bool = Assume(forall a . (a -> Bool) -> a -> Bool,
"let to_bool = Assume(forall a. (a -> Bool) -> a -> Bool,
fun f => fun x => f x) in
to_bool (fun x => true) 4 ",
);
Expand All @@ -652,7 +652,7 @@ Assume(#alwaysTrue -> #alwaysFalse, not ) true
#[test]
fn apply_twice() {
let res = eval_string(
"let twice = Assume(forall a . (a -> a) -> a -> a,
"let twice = Assume(forall a. (a -> a) -> a -> a,
fun f => fun x => f (f x)) in
twice (fun x => x + 1) 3",
);
Expand All @@ -670,33 +670,33 @@ Assume(#alwaysTrue -> #alwaysFalse, not ) true

#[test]
fn enum_simple() {
let res = eval_string("Promise(< (| foo, bar, |) >, `foo)");
let res = eval_string("Promise(<foo, bar>, `foo)");
assert_eq!(res, Ok(Term::Enum(Ident("foo".to_string()))));

let res = eval_string("Promise(forall r. (< (| foo, bar, | r ) >), `bar)");
let res = eval_string("Promise(forall r. <foo, bar | r>, `bar)");
assert_eq!(res, Ok(Term::Enum(Ident("bar".to_string()))));

eval_string("Promise(< (| foo, bar, |) >, `far)").unwrap_err();
eval_string("Promise(<foo, bar>, `far)").unwrap_err();
}

#[test]
fn enum_complex() {
let res = eval_string(
"let f = Promise(forall r. < (| foo, bar, | r ) > -> Num,
"let f = Promise(forall r. <foo, bar | r> -> Num,
fun x => switch { foo => 1, bar => 2, _ => 3, } x) in
f `bar",
);
assert_eq!(res, Ok(Term::Num(2.)));

let res = eval_string(
"let f = Promise(forall r. < (| foo, bar, | r ) > -> Num,
"let f = Promise(forall r. <foo, bar | r> -> Num,
fun x => switch { foo => 1, bar => 2, _ => 3, } x) in
f `boo",
);
assert_eq!(res, Ok(Term::Num(3.)));

eval_string(
"let f = Promise(< (| foo, bar, |) > -> Num,
"let f = Promise(<foo, bar> -> Num,
fun x => switch { foo => 1, bar => 2, } x) in
f `boo",
)
Expand All @@ -714,7 +714,7 @@ Assume(#alwaysTrue -> #alwaysFalse, not ) true

#[test]
fn row_types() {
eval_string("Assume((| |), 123)").unwrap_err();
eval_string("Assume(< >, 123)").unwrap_err();
}

#[test]
Expand Down
71 changes: 31 additions & 40 deletions src/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1994,12 +1994,12 @@ mod tests {

#[test]
fn enum_simple() {
parse_and_typecheck("Promise(< (| bla, |) >, `bla)").unwrap();
parse_and_typecheck("Promise(< (| bla, |) >, `blo)").unwrap_err();
parse_and_typecheck("Promise(<bla>, `bla)").unwrap();
parse_and_typecheck("Promise(<bla>, `blo)").unwrap_err();

parse_and_typecheck("Promise(< (| bla, blo, |) >, `blo)").unwrap();
parse_and_typecheck("Promise(forall r. < (| bla, | r ) >, `bla)").unwrap();
parse_and_typecheck("Promise(forall r. < (| bla, blo, | r ) >, `bla)").unwrap();
parse_and_typecheck("Promise(<bla, blo>, `blo)").unwrap();
parse_and_typecheck("Promise(forall r. <bla | r>, `bla)").unwrap();
parse_and_typecheck("Promise(forall r. <bla, blo | r>, `bla)").unwrap();

parse_and_typecheck("Promise(Num, switch { bla => 3, } `bla)").unwrap();
parse_and_typecheck("Promise(Num, switch { bla => 3, } `blo)").unwrap_err();
Expand All @@ -2010,17 +2010,15 @@ mod tests {

#[test]
fn enum_complex() {
parse_and_typecheck("Promise(<bla, ble> -> Num, fun x => switch {bla => 1, ble => 2,} x)")
.unwrap();
parse_and_typecheck(
"Promise(< (| bla, ble, |) > -> Num, fun x => switch {bla => 1, ble => 2,} x)",
)
.unwrap();
parse_and_typecheck(
"Promise(< (| bla, ble, |) > -> Num,
"Promise(<bla, ble> -> Num,
fun x => switch {bla => 1, ble => 2, bli => 4,} x)",
)
.unwrap_err();
parse_and_typecheck(
"Promise(< (| bla, ble, |) > -> Num,
"Promise(<bla, ble> -> Num,
fun x => switch {bla => 1, ble => 2, bli => 4,} (embed bli x))",
)
.unwrap();
Expand All @@ -2043,29 +2041,29 @@ mod tests {

parse_and_typecheck(
"let f = Promise(
forall r. < (| blo, ble, | r )> -> Num,
forall r. <blo, ble | r> -> Num,
fun x => (switch {blo => 1, ble => 2, _ => 3, } x ) ) in
Promise(Num, f `bli)",
)
.unwrap();
parse_and_typecheck(
"let f = Promise(
forall r. < (| blo, ble, | r )> -> Num,
forall r. <blo, ble, | r> -> Num,
fun x => (switch {blo => 1, ble => 2, bli => 3, } x ) ) in
f",
)
.unwrap_err();

parse_and_typecheck(
"let f = Promise(
forall r. (forall p. < (| blo, ble, | r )> -> < (| bla, bli, | p) > ),
forall r. (forall p. <blo, ble | r> -> <bla, bli | p> ),
fun x => (switch {blo => `bla, ble => `bli, _ => `bla, } x ) ) in
f `bli",
)
.unwrap();
parse_and_typecheck(
"let f = Promise(
forall r. (forall p. < (| blo, ble, | r )> -> < (| bla, bli, | p) > ),
forall r. (forall p. <blo, ble | r> -> <bla, bli | p> ),
fun x => (switch {blo => `bla, ble => `bli, _ => `blo, } x ) ) in
f `bli",
)
Expand All @@ -2074,27 +2072,26 @@ mod tests {

#[test]
fn static_record_simple() {
parse_and_typecheck("Promise({ {| bla : Num, |} }, { bla = 1; })").unwrap();
parse_and_typecheck("Promise({ {| bla : Num, |} }, { bla = true; })").unwrap_err();
parse_and_typecheck("Promise({ {| bla : Num, |} }, { blo = 1; })").unwrap_err();
parse_and_typecheck("Promise({bla : Num, }, { bla = 1; })").unwrap();
parse_and_typecheck("Promise({bla : Num, }, { bla = true; })").unwrap_err();
parse_and_typecheck("Promise({bla : Num, }, { blo = 1; })").unwrap_err();

parse_and_typecheck("Promise({ {| bla : Num, blo : Bool, |} }, { blo = true; bla = 1; })")
.unwrap();
parse_and_typecheck("Promise({bla : Num, blo : Bool}, { blo = true; bla = 1; })").unwrap();

parse_and_typecheck("Promise(Num, { blo = 1; }.blo)").unwrap();
parse_and_typecheck("Promise(Num, { bla = true; blo = 1; }.blo)").unwrap();
parse_and_typecheck("Promise(Bool, { blo = 1; }.blo)").unwrap_err();

parse_and_typecheck(
"let r = Promise({ {| bla : Bool, blo : Num, |} }, {blo = 1; bla = true; }) in
"let r = Promise({bla : Bool, blo : Num}, {blo = 1; bla = true; }) in
Promise(Num, if r.bla then r.blo else 2)",
)
.unwrap();

// It worked at first try :O
parse_and_typecheck(
"let f = Promise(
forall a. (forall r. { {| bla : Bool, blo : a, ble : a, | r } } -> a),
forall a. (forall r. {bla : Bool, blo : a, ble : a | r} -> a),
fun r => if r.bla then r.blo else r.ble)
in
Promise(Num,
Expand All @@ -2107,7 +2104,7 @@ mod tests {

parse_and_typecheck(
"let f = Promise(
forall a. (forall r. { {| bla : Bool, blo : a, ble : a, | r } } -> a),
forall a. (forall r. {bla : Bool, blo : a, ble : a | r} -> a),
fun r => if r.bla then r.blo else r.ble)
in
Promise(Num,
Expand All @@ -2117,7 +2114,7 @@ mod tests {
.unwrap_err();
parse_and_typecheck(
"let f = Promise(
forall a. (forall r. { {| bla : Bool, blo : a, ble : a, | r } } -> a),
forall a. (forall r. {bla : Bool, blo : a, ble : a | r} -> a),
fun r => if r.bla then (r.blo + 1) else r.ble)
in
Promise(Num,
Expand Down Expand Up @@ -2227,19 +2224,13 @@ mod tests {

#[test]
fn recursive_records() {
parse_and_typecheck(
"Promise({ {| a : Num, b : Num, |} }, { a = Promise(Num,1); b = a + 1})",
)
.unwrap();
parse_and_typecheck(
"Promise({ {| a : Num, b : Num, |} }, { a = Promise(Num,true); b = a + 1})",
)
.unwrap_err();
parse_and_typecheck(
"Promise({ {| a : Num, b : Bool, |} }, { a = 1; b = Promise(Bool, a) } )",
)
.unwrap_err();
parse_and_typecheck("Promise({ {| a : Num, |} }, { a = Promise(Num, 1 + a) })").unwrap();
parse_and_typecheck("Promise({a : Num, b : Num}, { a = Promise(Num,1); b = a + 1})")
.unwrap();
parse_and_typecheck("Promise({a : Num, b : Num}, { a = Promise(Num,true); b = a + 1})")
.unwrap_err();
parse_and_typecheck("Promise({a : Num, b : Bool}, { a = 1; b = Promise(Bool, a) } )")
.unwrap_err();
parse_and_typecheck("Promise({a : Num}, { a = Promise(Num, 1 + a) })").unwrap();
}

#[test]
Expand All @@ -2250,12 +2241,12 @@ mod tests {
.unwrap_err();

// Fields in recursive records are treated in the type environment in the same way as let-bound expressions
parse_and_typecheck("Promise({ {| a : Num, b : Num, |} }, { a = 1; b = 1 + a })").unwrap();
parse_and_typecheck("Promise({a : Num, b : Num}, { a = 1; b = 1 + a })").unwrap();
parse_and_typecheck(
"Promise({ {| f : Num -> Num, |} }, { f = fun x => if isZero x then 1 else 1 + (f (x + (-1)));})"
"Promise({f : Num -> Num}, { f = fun x => if isZero x then 1 else 1 + (f (x + (-1)));})"
).unwrap();
parse_and_typecheck(
"Promise({ {| f : Num -> Num, |} }, { f = fun x => if isZero x then false else 1 + (f (x + (-1)))})"
"Promise({f : Num -> Num}, { f = fun x => if isZero x then false else 1 + (f (x + (-1)))})"
).unwrap_err();
}

Expand Down
15 changes: 8 additions & 7 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,10 @@ impl fmt::Display for Types {
}
write!(f, ". {}", curr)
}
AbsType::Enum(row) => write!(f, "< (| {}) >", row),
AbsType::StaticRecord(row) => write!(f, "{{ {{| {}}} }}", row),
AbsType::Enum(row) => write!(f, "<{}>", row),
AbsType::StaticRecord(row) => write!(f, "{{{}}}", row),
AbsType::DynRecord(ty) => write!(f, "{{_: {}}}", ty),
AbsType::RowEmpty() => write!(f, " |"),
AbsType::RowEmpty() => Ok(()),
AbsType::RowExtend(Ident(id), ty_opt, tail) => {
write!(f, "{}", id)?;

Expand Down Expand Up @@ -431,10 +431,11 @@ mod test {
assert_format_eq("{_: Str}");
assert_format_eq("{_: (Str -> Str) -> Str}");

assert_format_eq("{ {| x: (Bool -> Bool) -> Bool, y: Bool |} }");
assert_format_eq("{ {| x: Bool, y: Bool, z: Bool | r} }");
assert_format_eq("{x: (Bool -> Bool) -> Bool, y: Bool}");
assert_format_eq("{x: Bool, y: Bool, z: Bool | r}");
assert_format_eq("{x: Bool, y: Bool, z: Bool}");

assert_format_eq("< (| a, b, c, d |) >");
assert_format_eq("< (| tag1, tag2, tag3 | r) >");
assert_format_eq("<a, b, c, d>");
assert_format_eq("<tag1, tag2, tag3 | r>");
}
}

0 comments on commit 6a93ad8

Please sign in to comment.