Skip to content

Commit

Permalink
Properly infer types with type casts
Browse files Browse the repository at this point in the history
  • Loading branch information
lowr committed Jul 12, 2023
1 parent 75ac37f commit 074488b
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 28 deletions.
34 changes: 23 additions & 11 deletions crates/hir-ty/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
//! to certain types. To record this, we use the union-find implementation from
//! the `ena` crate, which is extracted from rustc.

mod cast;
pub(crate) mod closure;
mod coerce;
mod expr;
mod mutability;
mod pat;
mod path;
pub(crate) mod unify;

use std::{convert::identity, ops::Index};

use chalk_ir::{
Expand Down Expand Up @@ -60,15 +69,8 @@ pub use coerce::could_coerce;
#[allow(unreachable_pub)]
pub use unify::could_unify;

pub(crate) use self::closure::{CaptureKind, CapturedItem, CapturedItemWithoutTy};

pub(crate) mod unify;
mod path;
mod expr;
mod pat;
mod coerce;
pub(crate) mod closure;
mod mutability;
use cast::CastCheck;
pub(crate) use closure::{CaptureKind, CapturedItem, CapturedItemWithoutTy};

/// The entry point of type inference.
pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<InferenceResult> {
Expand Down Expand Up @@ -508,6 +510,8 @@ pub(crate) struct InferenceContext<'a> {
diverges: Diverges,
breakables: Vec<BreakableContext>,

deferred_cast_checks: Vec<CastCheck>,

// fields related to closure capture
current_captures: Vec<CapturedItemWithoutTy>,
current_closure: Option<ClosureId>,
Expand Down Expand Up @@ -582,7 +586,8 @@ impl<'a> InferenceContext<'a> {
resolver,
diverges: Diverges::Maybe,
breakables: Vec::new(),
current_captures: vec![],
deferred_cast_checks: Vec::new(),
current_captures: Vec::new(),
current_closure: None,
deferred_closures: FxHashMap::default(),
closure_dependencies: FxHashMap::default(),
Expand All @@ -594,7 +599,7 @@ impl<'a> InferenceContext<'a> {
// used this function for another workaround, mention it here. If you really need this function and believe that
// there is no problem in it being `pub(crate)`, remove this comment.
pub(crate) fn resolve_all(self) -> InferenceResult {
let InferenceContext { mut table, mut result, .. } = self;
let InferenceContext { mut table, mut result, deferred_cast_checks, .. } = self;
// Destructure every single field so whenever new fields are added to `InferenceResult` we
// don't forget to handle them here.
let InferenceResult {
Expand Down Expand Up @@ -622,6 +627,13 @@ impl<'a> InferenceContext<'a> {

table.fallback_if_possible();

// Comment from rustc:
// Even though coercion casts provide type hints, we check casts after fallback for
// backwards compatibility. This makes fallback a stronger type hint than a cast coercion.
for cast in deferred_cast_checks {
cast.check(&mut table);
}

// FIXME resolve obligations as well (use Guidance if necessary)
table.resolve_obligations_as_possible();

Expand Down
46 changes: 46 additions & 0 deletions crates/hir-ty/src/infer/cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//! Type cast logic. Basically coercion + additional casts.

use crate::{infer::unify::InferenceTable, Interner, Ty, TyExt, TyKind};

#[derive(Clone, Debug)]
pub(super) struct CastCheck {
expr_ty: Ty,
cast_ty: Ty,
}

impl CastCheck {
pub(super) fn new(expr_ty: Ty, cast_ty: Ty) -> Self {
Self { expr_ty, cast_ty }
}

pub(super) fn check(self, table: &mut InferenceTable<'_>) {
// FIXME: This function currently only implements the bits that influence the type
// inference. We should return the adjustments on success and report diagnostics on error.
let expr_ty = table.resolve_ty_shallow(&self.expr_ty);
let cast_ty = table.resolve_ty_shallow(&self.cast_ty);

if expr_ty.contains_unknown() || cast_ty.contains_unknown() {
return;
}

if table.coerce(&expr_ty, &cast_ty).is_ok() {
return;
}

if check_ref_to_ptr_cast(expr_ty, cast_ty, table) {
// Note that this type of cast is actually split into a coercion to a
// pointer type and a cast:
// &[T; N] -> *[T; N] -> *T
return;
}

// FIXME: Check other kinds of non-coercion casts and report error if any?
}
}

fn check_ref_to_ptr_cast(expr_ty: Ty, cast_ty: Ty, table: &mut InferenceTable<'_>) -> bool {
let Some((expr_inner_ty, _, _)) = expr_ty.as_reference() else { return false; };
let Some((cast_inner_ty, _)) = cast_ty.as_raw_ptr() else { return false; };
let TyKind::Array(expr_elt_ty, _) = expr_inner_ty.kind(Interner) else { return false; };
table.coerce(expr_elt_ty, cast_inner_ty).is_ok()
}
18 changes: 5 additions & 13 deletions crates/hir-ty/src/infer/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ use crate::{
};

use super::{
coerce::auto_deref_adjust_steps, find_breakable, BreakableContext, Diverges, Expectation,
InferenceContext, InferenceDiagnostic, TypeMismatch,
cast::CastCheck, coerce::auto_deref_adjust_steps, find_breakable, BreakableContext, Diverges,
Expectation, InferenceContext, InferenceDiagnostic, TypeMismatch,
};

impl InferenceContext<'_> {
Expand Down Expand Up @@ -574,16 +574,8 @@ impl InferenceContext<'_> {
}
Expr::Cast { expr, type_ref } => {
let cast_ty = self.make_ty(type_ref);
// FIXME: propagate the "castable to" expectation
let inner_ty = self.infer_expr_no_expect(*expr);
match (inner_ty.kind(Interner), cast_ty.kind(Interner)) {
(TyKind::Ref(_, _, inner), TyKind::Raw(_, cast)) => {
// FIXME: record invalid cast diagnostic in case of mismatch
self.unify(inner, cast);
}
// FIXME check the other kinds of cast...
_ => (),
}
let expr_ty = self.infer_expr(*expr, &Expectation::Castable(cast_ty.clone()));
self.deferred_cast_checks.push(CastCheck::new(expr_ty, cast_ty.clone()));
cast_ty
}
Expr::Ref { expr, rawness, mutability } => {
Expand Down Expand Up @@ -1592,7 +1584,7 @@ impl InferenceContext<'_> {
output: Ty,
inputs: Vec<Ty>,
) -> Vec<Ty> {
if let Some(expected_ty) = expected_output.to_option(&mut self.table) {
if let Some(expected_ty) = expected_output.only_has_type(&mut self.table) {
self.table.fudge_inference(|table| {
if table.try_unify(&expected_ty, &output).is_ok() {
table.resolve_with_fallback(inputs, &|var, kind, _, _| match kind {
Expand Down
20 changes: 20 additions & 0 deletions crates/hir-ty/src/tests/regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1978,3 +1978,23 @@ fn x(a: [i32; 4]) {
"#,
);
}

#[test]
fn dont_unify_on_casts() {
// #15246
check_types(
r#"
fn unify(_: [bool; 1]) {}
fn casted(_: *const bool) {}
fn default<T>() -> T { loop {} }
fn test() {
let foo = default();
//^^^ [bool; 1]
casted(&foo as *const _);
unify(foo);
}
"#,
);
}
22 changes: 18 additions & 4 deletions crates/hir-ty/src/tests/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3513,7 +3513,6 @@ fn func() {
);
}

// FIXME
#[test]
fn castable_to() {
check_infer(
Expand All @@ -3538,10 +3537,10 @@ fn func() {
120..122 '{}': ()
138..184 '{ ...0]>; }': ()
148..149 'x': Box<[i32; 0]>
152..160 'Box::new': fn new<[{unknown}; 0]>([{unknown}; 0]) -> Box<[{unknown}; 0]>
152..164 'Box::new([])': Box<[{unknown}; 0]>
152..160 'Box::new': fn new<[i32; 0]>([i32; 0]) -> Box<[i32; 0]>
152..164 'Box::new([])': Box<[i32; 0]>
152..181 'Box::n...2; 0]>': Box<[i32; 0]>
161..163 '[]': [{unknown}; 0]
161..163 '[]': [i32; 0]
"#]],
);
}
Expand Down Expand Up @@ -3577,6 +3576,21 @@ fn f<T>(t: Ark<T>) {
);
}

#[test]
fn ref_to_array_to_ptr_cast() {
check_types(
r#"
fn default<T>() -> T { loop {} }
fn foo() {
let arr = [default()];
//^^^ [i32; 1]
let ref_to_arr = &arr;
let casted = ref_to_arr as *const i32;
}
"#,
);
}

#[test]
fn const_dependent_on_local() {
check_types(
Expand Down

0 comments on commit 074488b

Please sign in to comment.