Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve binop elaboration in synth mode #501

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 152 additions & 127 deletions fathom/src/surface/elaboration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1856,134 +1856,162 @@ impl<'arena> Context<'arena> {

// de-sugar into function application
let (lhs_expr, lhs_type) = self.synth_and_insert_implicit_apps(lhs);
let (rhs_expr, rhs_type) = self.synth_and_insert_implicit_apps(rhs);
let lhs_type = self.elim_env().force(&lhs_type);
let rhs_type = self.elim_env().force(&rhs_type);
let operand_types = Option::zip(lhs_type.match_prim_spine(), rhs_type.match_prim_spine());

let (fun, body_type) = match (op, operand_types) {
(Mul(_), Some(((U8Type, []), (U8Type, [])))) => (U8Mul, U8Type),
(Mul(_), Some(((U16Type, []), (U16Type, [])))) => (U16Mul, U16Type),
(Mul(_), Some(((U32Type, []), (U32Type, [])))) => (U32Mul, U32Type),
(Mul(_), Some(((U64Type, []), (U64Type, [])))) => (U64Mul, U64Type),

(Mul(_), Some(((S8Type, []), (S8Type, [])))) => (S8Mul, S8Type),
(Mul(_), Some(((S16Type, []), (S16Type, [])))) => (S16Mul, S16Type),
(Mul(_), Some(((S32Type, []), (S32Type, [])))) => (S32Mul, S32Type),
(Mul(_), Some(((S64Type, []), (S64Type, [])))) => (S64Mul, S64Type),

(Div(_), Some(((U8Type, []), (U8Type, [])))) => (U8Div, U8Type),
(Div(_), Some(((U16Type, []), (U16Type, [])))) => (U16Div, U16Type),
(Div(_), Some(((U32Type, []), (U32Type, [])))) => (U32Div, U32Type),
(Div(_), Some(((U64Type, []), (U64Type, [])))) => (U64Div, U64Type),

(Div(_), Some(((S8Type, []), (S8Type, [])))) => (S8Div, S8Type),
(Div(_), Some(((S16Type, []), (S16Type, [])))) => (S16Div, S16Type),
(Div(_), Some(((S32Type, []), (S32Type, [])))) => (S32Div, S32Type),
(Div(_), Some(((S64Type, []), (S64Type, [])))) => (S64Div, S64Type),

(Add(_), Some(((U8Type, []), (U8Type, [])))) => (U8Add, U8Type),
(Add(_), Some(((U16Type, []), (U16Type, [])))) => (U16Add, U16Type),
(Add(_), Some(((U32Type, []), (U32Type, [])))) => (U32Add, U32Type),
(Add(_), Some(((U64Type, []), (U64Type, [])))) => (U64Add, U64Type),

(Add(_), Some(((S8Type, []), (S8Type, [])))) => (S8Add, S8Type),
(Add(_), Some(((S16Type, []), (S16Type, [])))) => (S16Add, S16Type),
(Add(_), Some(((S32Type, []), (S32Type, [])))) => (S32Add, S32Type),
(Add(_), Some(((S64Type, []), (S64Type, [])))) => (S64Add, S64Type),

(Add(_), Some(((PosType, []), (U8Type, [])))) => (PosAddU8, PosType),
(Add(_), Some(((PosType, []), (U16Type, [])))) => (PosAddU16, PosType),
(Add(_), Some(((PosType, []), (U32Type, [])))) => (PosAddU32, PosType),
(Add(_), Some(((PosType, []), (U64Type, [])))) => (PosAddU64, PosType),

(Sub(_), Some(((U8Type, []), (U8Type, [])))) => (U8Sub, U8Type),
(Sub(_), Some(((U16Type, []), (U16Type, [])))) => (U16Sub, U16Type),
(Sub(_), Some(((U32Type, []), (U32Type, [])))) => (U32Sub, U32Type),
(Sub(_), Some(((U64Type, []), (U64Type, [])))) => (U64Sub, U64Type),

(Sub(_), Some(((S8Type, []), (S8Type, [])))) => (S8Sub, S8Type),
(Sub(_), Some(((S16Type, []), (S16Type, [])))) => (S16Sub, S16Type),
(Sub(_), Some(((S32Type, []), (S32Type, [])))) => (S32Sub, S32Type),
(Sub(_), Some(((S64Type, []), (S64Type, [])))) => (S64Sub, S64Type),

(Eq(_), Some(((BoolType, []), (BoolType, [])))) => (BoolEq, BoolType),
(Neq(_), Some(((BoolType, []), (BoolType, [])))) => (BoolNeq, BoolType),

(Eq(_), Some(((U8Type, []), (U8Type, [])))) => (U8Eq, BoolType),
(Eq(_), Some(((U16Type, []), (U16Type, [])))) => (U16Eq, BoolType),
(Eq(_), Some(((U32Type, []), (U32Type, [])))) => (U32Eq, BoolType),
(Eq(_), Some(((U64Type, []), (U64Type, [])))) => (U64Eq, BoolType),

(Eq(_), Some(((S8Type, []), (S8Type, [])))) => (S8Eq, BoolType),
(Eq(_), Some(((S16Type, []), (S16Type, [])))) => (S16Eq, BoolType),
(Eq(_), Some(((S32Type, []), (S32Type, [])))) => (S32Eq, BoolType),
(Eq(_), Some(((S64Type, []), (S64Type, [])))) => (S64Eq, BoolType),

(Neq(_), Some(((U8Type, []), (U8Type, [])))) => (U8Neq, BoolType),
(Neq(_), Some(((U16Type, []), (U16Type, [])))) => (U16Neq, BoolType),
(Neq(_), Some(((U32Type, []), (U32Type, [])))) => (U32Neq, BoolType),
(Neq(_), Some(((U64Type, []), (U64Type, [])))) => (U64Neq, BoolType),

(Neq(_), Some(((S8Type, []), (S8Type, [])))) => (S8Neq, BoolType),
(Neq(_), Some(((S16Type, []), (S16Type, [])))) => (S16Neq, BoolType),
(Neq(_), Some(((S32Type, []), (S32Type, [])))) => (S32Neq, BoolType),
(Neq(_), Some(((S64Type, []), (S64Type, [])))) => (S64Neq, BoolType),

(Lt(_), Some(((U8Type, []), (U8Type, [])))) => (U8Lt, BoolType),
(Lt(_), Some(((U16Type, []), (U16Type, [])))) => (U16Lt, BoolType),
(Lt(_), Some(((U32Type, []), (U32Type, [])))) => (U32Lt, BoolType),
(Lt(_), Some(((U64Type, []), (U64Type, [])))) => (U64Lt, BoolType),

(Lt(_), Some(((S8Type, []), (S8Type, [])))) => (S8Lt, BoolType),
(Lt(_), Some(((S16Type, []), (S16Type, [])))) => (S16Lt, BoolType),
(Lt(_), Some(((S32Type, []), (S32Type, [])))) => (S32Lt, BoolType),
(Lt(_), Some(((S64Type, []), (S64Type, [])))) => (S64Lt, BoolType),

(Lte(_), Some(((U8Type, []), (U8Type, [])))) => (U8Lte, BoolType),
(Lte(_), Some(((U16Type, []), (U16Type, [])))) => (U16Lte, BoolType),
(Lte(_), Some(((U32Type, []), (U32Type, [])))) => (U32Lte, BoolType),
(Lte(_), Some(((U64Type, []), (U64Type, [])))) => (U64Lte, BoolType),

(Lte(_), Some(((S8Type, []), (S8Type, [])))) => (S8Lte, BoolType),
(Lte(_), Some(((S16Type, []), (S16Type, [])))) => (S16Lte, BoolType),
(Lte(_), Some(((S32Type, []), (S32Type, [])))) => (S32Lte, BoolType),
(Lte(_), Some(((S64Type, []), (S64Type, [])))) => (S64Lte, BoolType),

(Gt(_), Some(((U8Type, []), (U8Type, [])))) => (U8Gt, BoolType),
(Gt(_), Some(((U16Type, []), (U16Type, [])))) => (U16Gt, BoolType),
(Gt(_), Some(((U32Type, []), (U32Type, [])))) => (U32Gt, BoolType),
(Gt(_), Some(((U64Type, []), (U64Type, [])))) => (U64Gt, BoolType),

(Gt(_), Some(((S8Type, []), (S8Type, [])))) => (S8Gt, BoolType),
(Gt(_), Some(((S16Type, []), (S16Type, [])))) => (S16Gt, BoolType),
(Gt(_), Some(((S32Type, []), (S32Type, [])))) => (S32Gt, BoolType),
(Gt(_), Some(((S64Type, []), (S64Type, [])))) => (S64Gt, BoolType),

(Gte(_), Some(((U8Type, []), (U8Type, [])))) => (U8Gte, BoolType),
(Gte(_), Some(((U16Type, []), (U16Type, [])))) => (U16Gte, BoolType),
(Gte(_), Some(((U32Type, []), (U32Type, [])))) => (U32Gte, BoolType),
(Gte(_), Some(((U64Type, []), (U64Type, [])))) => (U64Gte, BoolType),

(Gte(_), Some(((S8Type, []), (S8Type, [])))) => (S8Gte, BoolType),
(Gte(_), Some(((S16Type, []), (S16Type, [])))) => (S16Gte, BoolType),
(Gte(_), Some(((S32Type, []), (S32Type, [])))) => (S32Gte, BoolType),
(Gte(_), Some(((S64Type, []), (S64Type, [])))) => (S64Gte, BoolType),

let lhs_prim = lhs_type.match_prim_spine();
let (prim, rhs_prim_type, ret_prim_type) = match (op, lhs_prim) {
(Add(_), Some((U8Type, []))) => (U8Add, U8Type, U8Type),
(Add(_), Some((U16Type, []))) => (U16Add, U16Type, U16Type),
(Add(_), Some((U32Type, []))) => (U32Add, U32Type, U32Type),
(Add(_), Some((U64Type, []))) => (U64Add, U64Type, U64Type),

(Add(_), Some((S8Type, []))) => (S8Add, S8Type, S8Type),
(Add(_), Some((S16Type, []))) => (S16Add, S16Type, S16Type),
(Add(_), Some((S32Type, []))) => (S32Add, S32Type, S32Type),
(Add(_), Some((S64Type, []))) => (S64Add, S64Type, S64Type),

(Sub(_), Some((U8Type, []))) => (U8Sub, U8Type, U8Type),
(Sub(_), Some((U16Type, []))) => (U16Sub, U16Type, U16Type),
(Sub(_), Some((U32Type, []))) => (U32Sub, U32Type, U32Type),
(Sub(_), Some((U64Type, []))) => (U64Sub, U64Type, U64Type),

(Sub(_), Some((S8Type, []))) => (S8Sub, S8Type, S8Type),
(Sub(_), Some((S16Type, []))) => (S16Sub, S16Type, S16Type),
(Sub(_), Some((S32Type, []))) => (S32Sub, S32Type, S32Type),
(Sub(_), Some((S64Type, []))) => (S64Sub, S64Type, S64Type),

(Mul(_), Some((U8Type, []))) => (U8Mul, U8Type, U8Type),
(Mul(_), Some((U16Type, []))) => (U16Mul, U16Type, U16Type),
(Mul(_), Some((U32Type, []))) => (U32Mul, U32Type, U32Type),
(Mul(_), Some((U64Type, []))) => (U64Mul, U64Type, U64Type),

(Mul(_), Some((S8Type, []))) => (S8Mul, S8Type, S8Type),
(Mul(_), Some((S16Type, []))) => (S16Mul, S16Type, S16Type),
(Mul(_), Some((S32Type, []))) => (S32Mul, S32Type, S32Type),
(Mul(_), Some((S64Type, []))) => (S64Mul, S64Type, S64Type),

(Div(_), Some((U8Type, []))) => (U8Div, U8Type, U8Type),
(Div(_), Some((U16Type, []))) => (U16Div, U16Type, U16Type),
(Div(_), Some((U32Type, []))) => (U32Div, U32Type, U32Type),
(Div(_), Some((U64Type, []))) => (U64Div, U64Type, U64Type),

(Div(_), Some((S8Type, []))) => (S8Div, S8Type, S8Type),
(Div(_), Some((S16Type, []))) => (S16Div, S16Type, S16Type),
(Div(_), Some((S32Type, []))) => (S32Div, S32Type, S32Type),
(Div(_), Some((S64Type, []))) => (S64Div, S64Type, S64Type),

(Eq(_), Some((BoolType, []))) => (BoolEq, BoolType, BoolType),
(Neq(_), Some((BoolType, []))) => (BoolNeq, BoolType, BoolType),

(Eq(_), Some((U8Type, []))) => (U8Eq, U8Type, BoolType),
(Eq(_), Some((U16Type, []))) => (U16Eq, U16Type, BoolType),
(Eq(_), Some((U32Type, []))) => (U32Eq, U32Type, BoolType),
(Eq(_), Some((U64Type, []))) => (U64Eq, U64Type, BoolType),

(Eq(_), Some((S8Type, []))) => (S8Eq, S8Type, BoolType),
(Eq(_), Some((S16Type, []))) => (S16Eq, S16Type, BoolType),
(Eq(_), Some((S32Type, []))) => (S32Eq, S32Type, BoolType),
(Eq(_), Some((S64Type, []))) => (S64Eq, S64Type, BoolType),

(Neq(_), Some((U8Type, []))) => (U8Neq, U8Type, BoolType),
(Neq(_), Some((U16Type, []))) => (U16Neq, U16Type, BoolType),
(Neq(_), Some((U32Type, []))) => (U32Neq, U32Type, BoolType),
(Neq(_), Some((U64Type, []))) => (U64Neq, U64Type, BoolType),

(Neq(_), Some((S8Type, []))) => (S8Neq, S8Type, BoolType),
(Neq(_), Some((S16Type, []))) => (S16Neq, S16Type, BoolType),
(Neq(_), Some((S32Type, []))) => (S32Neq, S32Type, BoolType),
(Neq(_), Some((S64Type, []))) => (S64Neq, S64Type, BoolType),

(Lt(_), Some((U8Type, []))) => (U8Lt, U8Type, BoolType),
(Lt(_), Some((U16Type, []))) => (U16Lt, U16Type, BoolType),
(Lt(_), Some((U32Type, []))) => (U32Lt, U32Type, BoolType),
(Lt(_), Some((U64Type, []))) => (U64Lt, U64Type, BoolType),

(Lt(_), Some((S8Type, []))) => (S8Lt, S8Type, BoolType),
(Lt(_), Some((S16Type, []))) => (S16Lt, S16Type, BoolType),
(Lt(_), Some((S32Type, []))) => (S32Lt, S32Type, BoolType),
(Lt(_), Some((S64Type, []))) => (S64Lt, S64Type, BoolType),

(Lte(_), Some((U8Type, []))) => (U8Lte, U8Type, BoolType),
(Lte(_), Some((U16Type, []))) => (U16Lte, U16Type, BoolType),
(Lte(_), Some((U32Type, []))) => (U32Lte, U32Type, BoolType),
(Lte(_), Some((U64Type, []))) => (U64Lte, U64Type, BoolType),

(Lte(_), Some((S8Type, []))) => (S8Lte, S8Type, BoolType),
(Lte(_), Some((S16Type, []))) => (S16Lte, S16Type, BoolType),
(Lte(_), Some((S32Type, []))) => (S32Lte, S32Type, BoolType),
(Lte(_), Some((S64Type, []))) => (S64Lte, S64Type, BoolType),

(Gt(_), Some((U8Type, []))) => (U8Gt, U8Type, BoolType),
(Gt(_), Some((U16Type, []))) => (U16Gt, U16Type, BoolType),
(Gt(_), Some((U32Type, []))) => (U32Gt, U32Type, BoolType),
(Gt(_), Some((U64Type, []))) => (U64Gt, U64Type, BoolType),

(Gt(_), Some((S8Type, []))) => (S8Gt, S8Type, BoolType),
(Gt(_), Some((S16Type, []))) => (S16Gt, S16Type, BoolType),
(Gt(_), Some((S32Type, []))) => (S32Gt, S32Type, BoolType),
(Gt(_), Some((S64Type, []))) => (S64Gt, S64Type, BoolType),

(Gte(_), Some((U8Type, []))) => (U8Gte, U8Type, BoolType),
(Gte(_), Some((U16Type, []))) => (U16Gte, U16Type, BoolType),
(Gte(_), Some((U32Type, []))) => (U32Gte, U32Type, BoolType),
(Gte(_), Some((U64Type, []))) => (U64Gte, U64Type, BoolType),

(Gte(_), Some((S8Type, []))) => (S8Gte, S8Type, BoolType),
(Gte(_), Some((S16Type, []))) => (S16Gte, S16Type, BoolType),
(Gte(_), Some((S32Type, []))) => (S32Gte, S32Type, BoolType),
(Gte(_), Some((S64Type, []))) => (S64Gte, S64Type, BoolType),

_ => {
self.push_message(Message::BinOpMismatchedTypes {
range: self.file_range(range),
lhs_range: self.file_range(lhs.range()),
rhs_range: self.file_range(rhs.range()),
op: op.map_range(|range| self.file_range(range)),
lhs: self.pretty_value(&lhs_type),
rhs: self.pretty_value(&rhs_type),
});
return self.synth_reported_error(range);
let (rhs_expr, rhs_type) = self.synth_and_insert_implicit_apps(rhs);
let rhs_type = self.elim_env().force(&rhs_type);
let rhs_prim = rhs_type.match_prim_spine();

let operand_types = Option::zip(lhs_prim, rhs_prim);
let (prim, ret_prim_type) = match (op, operand_types) {
(Add(_), Some(((PosType, []), (U8Type, [])))) => (PosAddU8, PosType),
(Add(_), Some(((PosType, []), (U16Type, [])))) => (PosAddU16, PosType),
(Add(_), Some(((PosType, []), (U32Type, [])))) => (PosAddU32, PosType),
(Add(_), Some(((PosType, []), (U64Type, [])))) => (PosAddU64, PosType),
_ => {
self.push_message(Message::BinOpMismatchedTypes {
range: self.file_range(range),
lhs_range: self.file_range(lhs.range()),
rhs_range: self.file_range(rhs.range()),
op: op.map_range(|range| self.file_range(range)),
lhs: self.pretty_value(&lhs_type),
rhs: self.pretty_value(&rhs_type),
});
return self.synth_reported_error(range);
}
};

let ret_type = Spanned::empty(Arc::new(Value::prim(ret_prim_type, [])));

let fun_head = core::Term::Prim(self.file_range(op.range()).into(), prim);
let fun_app = core::Term::FunApp(
self.file_range(range).into(),
Plicity::Explicit,
self.scope.to_scope(core::Term::FunApp(
Span::merge(&lhs_expr.span(), &rhs_expr.span()),
Plicity::Explicit,
self.scope.to_scope(fun_head),
self.scope.to_scope(lhs_expr),
)),
self.scope.to_scope(rhs_expr),
);

return (fun_app, ret_type);
}
};

let fun_head = core::Term::Prim(self.file_range(op.range()).into(), fun);
let rhs_type = Spanned::empty(Arc::new(Value::prim(rhs_prim_type, [])));
let ret_type = Spanned::empty(Arc::new(Value::prim(ret_prim_type, [])));

let rhs_expr = self.check(rhs, &rhs_type);

let fun_head = core::Term::Prim(self.file_range(op.range()).into(), prim);
let fun_app = core::Term::FunApp(
self.file_range(range).into(),
Plicity::Explicit,
Expand All @@ -1996,11 +2024,8 @@ impl<'arena> Context<'arena> {
self.scope.to_scope(rhs_expr),
);

// TODO: Maybe it would be good to reuse lhs_type here if body_type is the same
(
fun_app,
Spanned::empty(Arc::new(Value::prim(body_type, []))),
)
// TODO: Maybe it would be good to reuse lhs_type here if ret_type is the same
(fun_app, ret_type)
}

fn check_bin_op(
Expand Down
Loading