Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shallow type inference #273

Merged
merged 6 commits into from
Feb 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/repl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::error::{Error, EvalError, IOError};
use crate::identifier::Ident;
use crate::parser::extended::{ExtendedParser, ExtendedTerm};
use crate::term::{RichTerm, Term};
use crate::types::{AbsType, Types};
use crate::types::Types;
use crate::{eval, transformations, typecheck};
use simple_counter::*;
use std::ffi::{OsStr, OsString};
Expand Down Expand Up @@ -146,7 +146,11 @@ impl REPL for REPLImpl {
let term = self.cache.parse_nocache(file_id)?;
typecheck::type_check_in_env(&term, &self.type_env, &self.cache)?;

Ok(typecheck::apparent_type(term.as_ref()).unwrap_or(Types(AbsType::Dyn())))
Ok(typecheck::apparent_type(
term.as_ref(),
Some(&typecheck::Envs::from_global(&self.type_env)),
)
.into())
}

fn query(&mut self, exp: &str) -> Result<Term, Error> {
Expand Down
154 changes: 119 additions & 35 deletions src/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ use crate::term::{BinaryOp, MetaValue, RichTerm, StrChunk, Term, UnaryOp};
use crate::types::{AbsType, Types};
use crate::{mk_tyw_arrow, mk_tyw_enum, mk_tyw_enum_row, mk_tyw_record, mk_tyw_row};
use std::collections::{HashMap, HashSet};
use std::convert::TryInto;

/// Error during the unification of two row types.
#[derive(Debug, PartialEq)]
Expand Down Expand Up @@ -361,9 +362,7 @@ impl<'a> Envs<'a> {
.map(|(id, (rc, _))| {
(
id.clone(),
apparent_type(rc.borrow().body.as_ref())
.map(to_typewrapper)
.unwrap_or_else(mk_typewrapper::dynamic),
to_typewrapper(apparent_type(rc.borrow().body.as_ref(), None).into()),
)
})
.collect()
Expand All @@ -376,16 +375,12 @@ impl<'a> Envs<'a> {

match term.as_ref() {
Term::Record(bindings) | Term::RecRecord(bindings) => {
let ext = bindings.into_iter().map(|(id, t)| {
(
id.clone(),
apparent_type(t.as_ref())
.map(to_typewrapper)
.unwrap_or_else(mk_typewrapper::dynamic),
)
});
for (id, t) in bindings {
let tyw: TypeWrapper =
apparent_type(t.as_ref(), Some(&Envs::from_global(env))).into();
env.insert(id.clone(), tyw);
}

env.extend(ext);
Ok(())
}
t => Err(eval::EnvBuildError::NotARecord(RichTerm::new(
Expand All @@ -399,9 +394,7 @@ impl<'a> Envs<'a> {
pub fn env_add(env: &mut Environment, id: Ident, rt: &RichTerm) {
env.insert(
id,
apparent_type(rt.as_ref())
.map(to_typewrapper)
.unwrap_or_else(mk_typewrapper::dynamic),
to_typewrapper(apparent_type(rt.as_ref(), Some(&Envs::from_global(env))).into()),
);
}

Expand Down Expand Up @@ -555,7 +548,7 @@ fn type_check_(
.map_err(|err| err.to_typecheck_err(state, &rt.pos))
}
Term::Let(x, re, rt) => {
let ty_let = binding_type(re.as_ref(), state.table, strict);
let ty_let = binding_type(re.as_ref(), &envs, state.table, strict);
type_check_(state, envs.clone(), strict, re, ty_let.clone())?;

// TODO move this up once lets are rec
Expand Down Expand Up @@ -615,11 +608,10 @@ fn type_check_(
// For recursive records, we look at the apparent type of each field and bind it in
// env before actually typechecking the content of fields
if let Term::RecRecord(_) = t.as_ref() {
envs.local.extend(
stat_map.iter().map(|(id, rt)| {
(id.clone(), binding_type(rt.as_ref(), state.table, strict))
}),
);
for (id, rt) in stat_map {
let tyw = binding_type(rt.as_ref(), &envs, state.table, strict);
envs.insert(id.clone(), tyw);
}
}

let root_ty = if let TypeWrapper::Ptr(p) = ty {
Expand Down Expand Up @@ -734,36 +726,95 @@ fn type_check_(
/// return `Dyn`.
/// * in strict mode, we will typecheck `bound_exp`: return a new unification variable to be
/// associated to `bound_exp`.
fn binding_type(t: &Term, table: &mut UnifTable, strict: bool) -> TypeWrapper {
match apparent_type(t) {
Some(ty) => to_typewrapper(ty),
None if strict => TypeWrapper::Ptr(new_var(table)),
None => mk_typewrapper::dynamic(),
fn binding_type(t: &Term, envs: &Envs, table: &mut UnifTable, strict: bool) -> TypeWrapper {
let ty_apt = apparent_type(t, Some(envs));

match ty_apt {
ApparentType::Approximated(_) if strict => TypeWrapper::Ptr(new_var(table)),
ty_apt => ty_apt.into(),
}
}

/// Different kinds of apparent types (see [`apparent_type`](fn.apparent_type.html)).
///
/// Indicate the nature of an apparent type. In particular, when in strict mode, the typechecker
/// throws away approximations as it can do better and infer the actual type of an expression by
/// generating a fresh unification variable. In non-strict mode, however, the approximation is the
/// best we can do. This type allows the caller of `apparent_type` to determine which situation it
/// is.
pub enum ApparentType {
/// The apparent type is given by a user-provided annotation, such as an `Assume`, a `Promise`,
/// or a metavalue.
Annotated(Types),
/// The apparent type has been inferred from a simple expression.
Inferred(Types),
/// The term is a variable and its type was retrieved from the typing environment.
FromEnv(TypeWrapper),
/// The apparent type wasn't trivial to determine, and an approximation (most of the time,
/// `Dyn`) has been returned.
Approximated(Types),
}

impl Into<Types> for ApparentType {
fn into(self) -> Types {
match self {
ApparentType::Annotated(ty)
| ApparentType::Inferred(ty)
| ApparentType::Approximated(ty) => ty,
ApparentType::FromEnv(tyw) => tyw.try_into().ok().unwrap_or(Types(AbsType::Dyn())),
}
}
}

impl Into<TypeWrapper> for ApparentType {
fn into(self) -> TypeWrapper {
match self {
ApparentType::Annotated(ty)
| ApparentType::Inferred(ty)
| ApparentType::Approximated(ty) => to_typewrapper(ty),
ApparentType::FromEnv(tyw) => tyw,
}
}
}

/// Determine the apparent type of a let-bound expression.
///
/// When a let-binding `let x = bound_exp in body` is processed, the type of `bound_exp` must be
/// determined to be associated to the bound variable `x` in the typing environment (`typed_vars`).
/// determined in order to be bound to the variable `x` in the typing environment.
/// Then, future occurrences of `x` can be given this type when used in a `Promise` block.
///
/// The role of `apparent_type` is precisely to determine the type of `bound_exp`:
/// - if `bound_exp` is annotated by an `Assume` or a `Promise`, return the user-provided type.
/// - Otherwise, `None` is returned.
pub fn apparent_type(t: &Term) -> Option<Types> {
/// - if `bound_exp` is annotated by an `Assume`, a `Promise` or a metavalue, return the
/// user-provided type.
/// - if `bound_exp` is a constant (string, number, boolean or symbol) which type can be deduced
/// directly without unfolding the expression further, return the corresponding exact type.
/// - if `bound_exp` is a list, return `List Dyn`.
/// - Otherwise, return an approximation of the type (currently `Dyn`, but could be more precise in
/// the future, such as `Dyn -> Dyn` for functions, `{ | Dyn}` for records, and so on).
pub fn apparent_type(t: &Term, envs: Option<&Envs>) -> ApparentType {
match t {
Term::Assume(ty, _, _) | Term::Promise(ty, _, _) => Some(ty.clone()),
Term::Assume(ty, _, _) | Term::Promise(ty, _, _) => ApparentType::Annotated(ty.clone()),
Term::MetaValue(MetaValue {
contract: Some((ty, _)),
..
}) => Some(ty.clone()),
}) => ApparentType::Annotated(ty.clone()),
Term::MetaValue(MetaValue {
contract: None,
value: Some(v),
..
}) => apparent_type(v.as_ref()),
_ => None,
}) => apparent_type(v.as_ref(), envs),
Term::Num(_) => ApparentType::Inferred(Types(AbsType::Num())),
Term::Bool(_) => ApparentType::Inferred(Types(AbsType::Bool())),
Term::Sym(_) => ApparentType::Inferred(Types(AbsType::Sym())),
Term::Str(_) | Term::StrChunks(_) => ApparentType::Inferred(Types(AbsType::Str())),
Term::List(_) => {
ApparentType::Inferred(Types(AbsType::List(Box::new(Types(AbsType::Dyn())))))
}
Term::Var(id) => envs
.and_then(|envs| envs.get(id))
.map(ApparentType::FromEnv)
.unwrap_or(ApparentType::Approximated(Types(AbsType::Dyn()))),
_ => ApparentType::Approximated(Types(AbsType::Dyn())),
}
}

Expand All @@ -779,6 +830,23 @@ pub enum TypeWrapper {
Ptr(usize),
}

impl std::convert::TryInto<Types> for TypeWrapper {
type Error = ();

fn try_into(self) -> Result<Types, ()> {
match self {
TypeWrapper::Concrete(ty) => {
let converted: AbsType<Box<Types>> = ty.try_map(|tyw_boxed| {
let ty: Types = (*tyw_boxed).try_into()?;
Ok(Box::new(ty))
})?;
Ok(Types(converted))
}
_ => Err(()),
}
}
}

impl TypeWrapper {
/// Substitute all the occurrences of a type variable for a typewrapper.
pub fn subst(self, id: Ident, to: TypeWrapper) -> TypeWrapper {
Expand Down Expand Up @@ -1794,6 +1862,7 @@ mod tests {
use crate::parser::lexer;
use crate::term::make as mk_term;
use crate::transformations::transform;
use assert_matches::assert_matches;
use codespan::Files;

use crate::parser;
Expand Down Expand Up @@ -2144,7 +2213,7 @@ mod tests {
fn seq() {
parse_and_typecheck("%seq% false 1 : Num").unwrap();
parse_and_typecheck("(fun x y => %seq% x y) : forall a. (forall b. a -> b -> b)").unwrap();
parse_and_typecheck("let xDyn = false in let yDyn = 1 in (%seq% xDyn yDyn : Dyn)").unwrap();
parse_and_typecheck("let xDyn = if false then true else false in let yDyn = 1 + 1 in (%seq% xDyn yDyn : Dyn)").unwrap();
}

#[test]
Expand Down Expand Up @@ -2299,4 +2368,19 @@ mod tests {
parse_and_typecheck("{gen_ = fun acc x => if x == 0 then acc else gen_ (acc @ [x]) (x - 1)}.gen_ : List Num -> Num -> List Num").unwrap();
parse_and_typecheck("{f = fun x => f x}.f : forall a. a -> a").unwrap();
}

#[test]
fn shallow_type_inference() {
parse_and_typecheck("let x = 1 in (x + 1 : Num)").unwrap();
assert_matches!(
parse_and_typecheck("let x = (1 + 1) in (x + 1 : Num)"),
Err(TypecheckError::TypeMismatch(..))
);

parse_and_typecheck("let x = \"a\" in (x ++ \"a\" : Str)").unwrap();
parse_and_typecheck("let x = \"a#{\"some str inside\"}\" in (x ++ \"a\" : Str)").unwrap();

parse_and_typecheck("let x = false in (x || true : Bool)").unwrap();
parse_and_typecheck("let x = false in let y = x in let z = y in (z : Bool)").unwrap();
}
}
58 changes: 32 additions & 26 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,38 +102,44 @@ pub enum AbsType<Ty> {
}

impl<Ty> AbsType<Ty> {
pub fn map<To, F: FnMut(Ty) -> To>(self, mut f: F) -> AbsType<To> {
pub fn try_map<To, F, E>(self, mut f: F) -> Result<AbsType<To>, E>
where
F: FnMut(Ty) -> Result<To, E>,
{
match self {
AbsType::Dyn() => AbsType::Dyn(),
AbsType::Num() => AbsType::Num(),
AbsType::Bool() => AbsType::Bool(),
AbsType::Str() => AbsType::Str(),
AbsType::Sym() => AbsType::Sym(),
AbsType::Flat(t) => AbsType::Flat(t),
AbsType::Arrow(s, t) => {
let fs = f(s);
let ft = f(t);

AbsType::Arrow(fs, ft)
}
AbsType::Var(i) => AbsType::Var(i),
AbsType::Forall(i, t) => {
let ft = f(t);

AbsType::Forall(i, ft)
}
AbsType::RowEmpty() => AbsType::RowEmpty(),
AbsType::Dyn() => Ok(AbsType::Dyn()),
AbsType::Num() => Ok(AbsType::Num()),
AbsType::Bool() => Ok(AbsType::Bool()),
AbsType::Str() => Ok(AbsType::Str()),
AbsType::Sym() => Ok(AbsType::Sym()),
AbsType::Flat(t) => Ok(AbsType::Flat(t)),
AbsType::Arrow(s, t) => Ok(AbsType::Arrow(f(s)?, f(t)?)),
AbsType::Var(i) => Ok(AbsType::Var(i)),
AbsType::Forall(i, t) => Ok(AbsType::Forall(i, f(t)?)),
AbsType::RowEmpty() => Ok(AbsType::RowEmpty()),
AbsType::RowExtend(id, t1, t2) => {
let t2_mapped = f(t2);
AbsType::RowExtend(id, t1.map(f), t2_mapped)
let t1_mapped = match t1 {
Some(ty) => Some(f(ty)?),
None => None,
};

Ok(AbsType::RowExtend(id, t1_mapped, f(t2)?))
}
AbsType::Enum(t) => AbsType::Enum(f(t)),
AbsType::StaticRecord(t) => AbsType::StaticRecord(f(t)),
AbsType::DynRecord(t) => AbsType::DynRecord(f(t)),
AbsType::List(t) => AbsType::List(f(t)),
AbsType::Enum(t) => Ok(AbsType::Enum(f(t)?)),
AbsType::StaticRecord(t) => Ok(AbsType::StaticRecord(f(t)?)),
AbsType::DynRecord(t) => Ok(AbsType::DynRecord(f(t)?)),
AbsType::List(t) => Ok(AbsType::List(f(t)?)),
}
}

pub fn map<To, F>(self, mut f: F) -> AbsType<To>
where
F: FnMut(Ty) -> To,
{
let f_lift = |ty: Ty| -> Result<To, ()> { Ok(f(ty)) };
self.try_map(f_lift).unwrap()
}

/// Determine if a type is a row type.
pub fn is_row_type(&self) -> bool {
match self {
Expand Down