From 216dab3b8e5dd47ce5c2500f2c68eac4ff931d33 Mon Sep 17 00:00:00 2001 From: Nadrieril Date: Sat, 30 Mar 2024 13:15:35 +0100 Subject: [PATCH] Use contiguous indices for enum variants --- .../diagnostics/match_check/pat_analysis.rs | 83 +++++++++++++++---- 1 file changed, 68 insertions(+), 15 deletions(-) diff --git a/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs b/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs index f45beb4c92bf..40878de8c78a 100644 --- a/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs +++ b/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs @@ -3,7 +3,7 @@ use std::fmt; use tracing::debug; -use hir_def::{DefWithBodyId, EnumVariantId, HasModule, LocalFieldId, ModuleId, VariantId}; +use hir_def::{DefWithBodyId, EnumId, EnumVariantId, HasModule, LocalFieldId, ModuleId, VariantId}; use rustc_hash::FxHashMap; use rustc_pattern_analysis::{ constructor::{Constructor, ConstructorSet, VariantVisibility}, @@ -36,6 +36,32 @@ pub(crate) type WitnessPat<'p> = rustc_pattern_analysis::pat::WitnessPat Option { + // Find the index of this variant in the list of variants. + db.enum_data(eid) + .variants + .iter() + .map(|(evid, _name)| *evid) + .enumerate() + .find(|(_, evid)| *evid == target_evid) + .map(|(i, _)| EnumVariantContiguousIndex(i)) + } + + fn to_enum_variant_id(self, db: &dyn HirDatabase, eid: EnumId) -> EnumVariantId { + db.enum_data(eid).variants[self.0].0 + } +} + #[derive(Clone)] pub(crate) struct MatchCheckCtx<'p> { module: ModuleId, @@ -89,9 +115,18 @@ impl<'p> MatchCheckCtx<'p> { } } - fn variant_id_for_adt(ctor: &Constructor, adt: hir_def::AdtId) -> Option { + fn variant_id_for_adt( + db: &'p dyn HirDatabase, + ctor: &Constructor, + adt: hir_def::AdtId, + ) -> Option { match ctor { - &Variant(id) => Some(id.into()), + Variant(id) => { + let hir_def::AdtId::EnumId(eid) = adt else { + panic!("bad constructor {ctor:?} for adt {adt:?}") + }; + Some(id.to_enum_variant_id(db, eid).into()) + } Struct | UnionField => match adt { hir_def::AdtId::EnumId(_) => None, hir_def::AdtId::StructId(id) => Some(id.into()), @@ -175,19 +210,37 @@ impl<'p> MatchCheckCtx<'p> { ctor = Struct; arity = 1; } - &TyKind::Adt(adt, _) => { + &TyKind::Adt(AdtId(adt), _) => { ctor = match pat.kind.as_ref() { - PatKind::Leaf { .. } if matches!(adt.0, hir_def::AdtId::UnionId(_)) => { + PatKind::Leaf { .. } if matches!(adt, hir_def::AdtId::UnionId(_)) => { UnionField } PatKind::Leaf { .. } => Struct, - PatKind::Variant { enum_variant, .. } => Variant(*enum_variant), + PatKind::Variant { enum_variant, .. } => { + if let hir_def::AdtId::EnumId(eid) = adt { + if let Some(id) = + EnumVariantContiguousIndex::from_enum_variant_id( + self.db, + eid, + *enum_variant, + ) + { + Variant(id) + } else { + never!(); + Wildcard + } + } else { + never!(); + Wildcard + } + } _ => { never!(); Wildcard } }; - let variant = Self::variant_id_for_adt(&ctor, adt.0).unwrap(); + let variant = Self::variant_id_for_adt(self.db, &ctor, adt).unwrap(); arity = variant.variant_data(self.db.upcast()).fields().len(); } _ => { @@ -239,7 +292,7 @@ impl<'p> MatchCheckCtx<'p> { PatKind::Deref { subpattern: subpatterns.next().unwrap() } } TyKind::Adt(adt, substs) => { - let variant = Self::variant_id_for_adt(pat.ctor(), adt.0).unwrap(); + let variant = Self::variant_id_for_adt(self.db, pat.ctor(), adt.0).unwrap(); let subpatterns = self .list_variant_fields(pat.ty(), variant) .zip(subpatterns) @@ -277,7 +330,7 @@ impl<'p> MatchCheckCtx<'p> { impl<'p> PatCx for MatchCheckCtx<'p> { type Error = (); type Ty = Ty; - type VariantIdx = EnumVariantId; + type VariantIdx = EnumVariantContiguousIndex; type StrLit = Void; type ArmData = (); type PatData = PatData<'p>; @@ -303,7 +356,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> { // patterns. If we're here we can assume this is a box pattern. 1 } else { - let variant = Self::variant_id_for_adt(ctor, adt).unwrap(); + let variant = Self::variant_id_for_adt(self.db, ctor, adt).unwrap(); variant.variant_data(self.db.upcast()).fields().len() } } @@ -343,7 +396,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> { let subst_ty = substs.at(Interner, 0).assert_ty_ref(Interner).clone(); single(subst_ty) } else { - let variant = Self::variant_id_for_adt(ctor, adt).unwrap(); + let variant = Self::variant_id_for_adt(self.db, ctor, adt).unwrap(); let (adt, _) = ty.as_adt().unwrap(); let adt_is_local = @@ -421,7 +474,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> { ConstructorSet::NoConstructors } else { let mut variants = FxHashMap::default(); - for &(variant, _) in enum_data.variants.iter() { + for (i, &(variant, _)) in enum_data.variants.iter().enumerate() { let is_uninhabited = is_enum_variant_uninhabited_from(variant, subst, cx.module, cx.db); let visibility = if is_uninhabited { @@ -429,7 +482,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> { } else { VariantVisibility::Visible }; - variants.insert(variant, visibility); + variants.insert(EnumVariantContiguousIndex(i), visibility); } ConstructorSet::Variants { @@ -453,10 +506,10 @@ impl<'p> PatCx for MatchCheckCtx<'p> { f: &mut fmt::Formatter<'_>, pat: &rustc_pattern_analysis::pat::DeconstructedPat, ) -> fmt::Result { + let db = pat.data().db; let variant = - pat.ty().as_adt().and_then(|(adt, _)| Self::variant_id_for_adt(pat.ctor(), adt)); + pat.ty().as_adt().and_then(|(adt, _)| Self::variant_id_for_adt(db, pat.ctor(), adt)); - let db = pat.data().db; if let Some(variant) = variant { match variant { VariantId::EnumVariantId(v) => {