From d214b3875536a4243ccd0919af375fccda6344f3 Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Tue, 5 Apr 2022 19:14:35 -0400 Subject: [PATCH] interp/validity: enforce Scalar::Initialized --- .../src/interpret/validity.rs | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs index e5fd182f8b3a9..349806d997945 100644 --- a/compiler/rustc_const_eval/src/interpret/validity.rs +++ b/compiler/rustc_const_eval/src/interpret/validity.rs @@ -629,11 +629,24 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, ' op: &OpTy<'tcx, M::PointerTag>, scalar_layout: ScalarAbi, ) -> InterpResult<'tcx> { - if scalar_layout.valid_range(self.ecx).is_full_for(op.layout.size) { + // We check `is_full_range` in a slightly complicated way because *if* we are checking + // number validity, then we want to ensure that `Scalar::Initialized` is indeed initialized, + // i.e. that we go over the `check_init` below. + let is_full_range = match scalar_layout { + ScalarAbi::Initialized { valid_range, .. } => { + if M::enforce_number_validity(self.ecx) { + false // not "full" since uninit is not accepted + } else { + valid_range.is_full_for(op.layout.size) + } + } + ScalarAbi::Union { .. } => true, + }; + if is_full_range { // Nothing to check return Ok(()); } - // At least one value is excluded. + // We have something to check. let valid_range = scalar_layout.valid_range(self.ecx); let WrappingRange { start, end } = valid_range; let max_value = op.layout.size.unsigned_int_max(); @@ -647,9 +660,11 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, ' expected { "something {}", wrapping_range_format(valid_range, max_value) }, ); let bits = match value.try_to_int() { + Ok(int) => int.assert_bits(op.layout.size), Err(_) => { // So this is a pointer then, and casting to an int failed. // Can only happen during CTFE. + // We support 2 kinds of ranges here: full range, and excluding zero. if start == 1 && end == max_value { // Only null is the niche. So make sure the ptr is NOT null. if self.ecx.scalar_may_be_null(value) { @@ -660,7 +675,11 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, ' wrapping_range_format(valid_range, max_value) } ) + } else { + return Ok(()); } + } else if scalar_layout.valid_range(self.ecx).is_full_for(op.layout.size) { + // Easy. (This is reachable if `enforce_number_validity` is set.) return Ok(()); } else { // Conservatively, we reject, because the pointer *could* have a bad @@ -674,9 +693,8 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, ' ) } } - Ok(int) => int.assert_bits(op.layout.size), }; - // Now compare. This is slightly subtle because this is a special "wrap-around" range. + // Now compare. if valid_range.contains(bits) { Ok(()) } else {