Skip to content

Commit f83c72f

Browse files
committed
fix: branches-sharing-code suggests wrongly on const and static
1 parent 45168a7 commit f83c72f

File tree

4 files changed

+388
-28
lines changed

4 files changed

+388
-28
lines changed

clippy_lints/src/ifs/branches_sharing_code.rs

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use clippy_utils::{
99
use core::iter;
1010
use core::ops::ControlFlow;
1111
use rustc_errors::Applicability;
12-
use rustc_hir::{Block, Expr, ExprKind, HirId, HirIdSet, LetStmt, Node, Stmt, StmtKind, intravisit};
12+
use rustc_hir::{Block, Expr, ExprKind, HirId, HirIdSet, ItemKind, LetStmt, Node, Stmt, StmtKind, UseKind, intravisit};
1313
use rustc_lint::LateContext;
1414
use rustc_span::hygiene::walk_chain;
1515
use rustc_span::source_map::SourceMap;
@@ -108,6 +108,7 @@ struct BlockEq {
108108
/// The name and id of every local which can be moved at the beginning and the end.
109109
moved_locals: Vec<(HirId, Symbol)>,
110110
}
111+
111112
impl BlockEq {
112113
fn start_span(&self, b: &Block<'_>, sm: &SourceMap) -> Option<Span> {
113114
match &b.stmts[..self.start_end_eq] {
@@ -129,20 +130,33 @@ impl BlockEq {
129130
}
130131

131132
/// If the statement is a local, checks if the bound names match the expected list of names.
132-
fn eq_binding_names(s: &Stmt<'_>, names: &[(HirId, Symbol)]) -> bool {
133-
if let StmtKind::Let(l) = s.kind {
134-
let mut i = 0usize;
135-
let mut res = true;
136-
l.pat.each_binding_or_first(&mut |_, _, _, name| {
137-
if names.get(i).is_some_and(|&(_, n)| n == name.name) {
138-
i += 1;
139-
} else {
140-
res = false;
141-
}
142-
});
143-
res && i == names.len()
144-
} else {
145-
false
133+
fn eq_binding_names(cx: &LateContext<'_>, s: &Stmt<'_>, names: &[(HirId, Symbol)]) -> bool {
134+
match s.kind {
135+
StmtKind::Let(l) => {
136+
let mut i = 0usize;
137+
let mut res = true;
138+
l.pat.each_binding_or_first(&mut |_, _, _, name| {
139+
if names.get(i).is_some_and(|&(_, n)| n == name.name) {
140+
i += 1;
141+
} else {
142+
res = false;
143+
}
144+
});
145+
res && i == names.len()
146+
},
147+
StmtKind::Item(item_id)
148+
if let [(_, name)] = names
149+
&& let item = cx.tcx.hir_item(item_id)
150+
&& let ItemKind::Static(_, ident, ..)
151+
| ItemKind::Const(ident, ..)
152+
| ItemKind::Fn { ident, .. }
153+
| ItemKind::TyAlias(ident, ..)
154+
| ItemKind::Use(_, UseKind::Single(ident))
155+
| ItemKind::Mod(ident, _) = item.kind =>
156+
{
157+
*name == ident.name
158+
},
159+
_ => false,
146160
}
147161
}
148162

@@ -164,6 +178,7 @@ fn modifies_any_local<'tcx>(cx: &LateContext<'tcx>, s: &'tcx Stmt<'_>, locals: &
164178
/// Checks if the given statement should be considered equal to the statement in the same
165179
/// position for each block.
166180
fn eq_stmts(
181+
cx: &LateContext<'_>,
167182
stmt: &Stmt<'_>,
168183
blocks: &[&Block<'_>],
169184
get_stmt: impl for<'a> Fn(&'a Block<'a>) -> Option<&'a Stmt<'a>>,
@@ -178,7 +193,7 @@ fn eq_stmts(
178193
let new_bindings = &moved_bindings[old_count..];
179194
blocks
180195
.iter()
181-
.all(|b| get_stmt(b).is_some_and(|s| eq_binding_names(s, new_bindings)))
196+
.all(|b| get_stmt(b).is_some_and(|s| eq_binding_names(cx, s, new_bindings)))
182197
} else {
183198
true
184199
}) && blocks.iter().all(|b| get_stmt(b).is_some_and(|s| eq.eq_stmt(s, stmt)))
@@ -218,7 +233,7 @@ fn scan_block_for_eq<'tcx>(
218233
return true;
219234
}
220235
modifies_any_local(cx, stmt, &cond_locals)
221-
|| !eq_stmts(stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals)
236+
|| !eq_stmts(cx, stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals)
222237
})
223238
.map_or(block.stmts.len(), |(i, stmt)| {
224239
adjust_by_closest_callsite(i, stmt, block.stmts[..i].iter().enumerate().rev())
@@ -279,6 +294,7 @@ fn scan_block_for_eq<'tcx>(
279294
}))
280295
.fold(end_search_start, |init, (stmt, offset)| {
281296
if eq_stmts(
297+
cx,
282298
stmt,
283299
blocks,
284300
|b| b.stmts.get(b.stmts.len() - offset),
@@ -290,11 +306,26 @@ fn scan_block_for_eq<'tcx>(
290306
// Clear out all locals seen at the end so far. None of them can be moved.
291307
let stmts = &blocks[0].stmts;
292308
for stmt in &stmts[stmts.len() - init..=stmts.len() - offset] {
293-
if let StmtKind::Let(l) = stmt.kind {
294-
l.pat.each_binding_or_first(&mut |_, id, _, _| {
295-
// FIXME(rust/#120456) - is `swap_remove` correct?
296-
eq.locals.swap_remove(&id);
297-
});
309+
match stmt.kind {
310+
StmtKind::Let(l) => {
311+
l.pat.each_binding_or_first(&mut |_, id, _, _| {
312+
// FIXME(rust/#120456) - is `swap_remove` correct?
313+
eq.locals.swap_remove(&id);
314+
});
315+
},
316+
StmtKind::Item(item_id) => {
317+
let item = cx.tcx.hir_item(item_id);
318+
if let ItemKind::Static(..)
319+
| ItemKind::Const(..)
320+
| ItemKind::Fn { .. }
321+
| ItemKind::TyAlias(..)
322+
| ItemKind::Use(..)
323+
| ItemKind::Mod(..) = item.kind
324+
{
325+
eq.local_items.swap_remove(&item.owner_id.to_def_id());
326+
}
327+
},
328+
_ => {},
298329
}
299330
}
300331
moved_locals.truncate(moved_locals_at_start);

clippy_utils/src/hir_utils.rs

Lines changed: 204 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,17 @@ use crate::source::{SpanRange, SpanRangeExt, walk_span_to_context};
44
use crate::tokenize_with_text;
55
use rustc_ast::ast;
66
use rustc_ast::ast::InlineAsmTemplatePiece;
7-
use rustc_data_structures::fx::FxHasher;
7+
use rustc_data_structures::fx::{FxHasher, FxIndexMap};
88
use rustc_hir::MatchSource::TryDesugar;
99
use rustc_hir::def::{DefKind, Res};
10+
use rustc_hir::def_id::DefId;
1011
use rustc_hir::{
11-
AssocItemConstraint, BinOpKind, BindingMode, Block, BodyId, ByRef, Closure, ConstArg, ConstArgKind, Expr,
12-
ExprField, ExprKind, FnRetTy, GenericArg, GenericArgs, HirId, HirIdMap, InlineAsmOperand, LetExpr, Lifetime,
13-
LifetimeKind, Node, Pat, PatExpr, PatExprKind, PatField, PatKind, Path, PathSegment, PrimTy, QPath, Stmt, StmtKind,
14-
StructTailExpr, TraitBoundModifiers, Ty, TyKind, TyPat, TyPatKind,
12+
AssocItemConstraint, BinOpKind, BindingMode, Block, BodyId, ByRef, Closure, ConstArg, ConstArgKind, ConstItemRhs,
13+
Expr, ExprField, ExprKind, FnDecl, FnRetTy, FnSig, GenericArg, GenericArgs, GenericBound, GenericBounds,
14+
GenericParam, GenericParamKind, GenericParamSource, Generics, HirId, HirIdMap, InlineAsmOperand, ItemId, ItemKind,
15+
LetExpr, Lifetime, LifetimeKind, LifetimeParamKind, Node, ParamName, Pat, PatExpr, PatExprKind, PatField, PatKind,
16+
Path, PathSegment, PreciseCapturingArgKind, PrimTy, QPath, Stmt, StmtKind, StructTailExpr, TraitBoundModifiers, Ty,
17+
TyKind, TyPat, TyPatKind, UseKind, WherePredicate, WherePredicateKind,
1518
};
1619
use rustc_lexer::{FrontmatterAllowed, TokenKind, tokenize};
1720
use rustc_lint::LateContext;
@@ -106,6 +109,7 @@ impl<'a, 'tcx> SpanlessEq<'a, 'tcx> {
106109
left_ctxt: SyntaxContext::root(),
107110
right_ctxt: SyntaxContext::root(),
108111
locals: HirIdMap::default(),
112+
local_items: FxIndexMap::default(),
109113
}
110114
}
111115

@@ -144,6 +148,7 @@ pub struct HirEqInterExpr<'a, 'b, 'tcx> {
144148
// right. For example, when comparing `{ let x = 1; x + 2 }` and `{ let y = 1; y + 2 }`,
145149
// these blocks are considered equal since `x` is mapped to `y`.
146150
pub locals: HirIdMap<HirId>,
151+
pub local_items: FxIndexMap<DefId, DefId>,
147152
}
148153

149154
impl HirEqInterExpr<'_, '_, '_> {
@@ -168,6 +173,189 @@ impl HirEqInterExpr<'_, '_, '_> {
168173
&& self.eq_pat(l.pat, r.pat)
169174
},
170175
(StmtKind::Expr(l), StmtKind::Expr(r)) | (StmtKind::Semi(l), StmtKind::Semi(r)) => self.eq_expr(l, r),
176+
(StmtKind::Item(l), StmtKind::Item(r)) => self.eq_item(*l, *r),
177+
_ => false,
178+
}
179+
}
180+
181+
pub fn eq_item(&mut self, l: ItemId, r: ItemId) -> bool {
182+
let left = self.inner.cx.tcx.hir_item(l);
183+
let right = self.inner.cx.tcx.hir_item(r);
184+
let eq = match (left.kind, right.kind) {
185+
(
186+
ItemKind::Const(l_ident, l_generics, l_ty, ConstItemRhs::Body(l_body)),
187+
ItemKind::Const(r_ident, r_generics, r_ty, ConstItemRhs::Body(r_body)),
188+
) => {
189+
l_ident.name == r_ident.name
190+
&& self.eq_generics(l_generics, r_generics)
191+
&& self.eq_ty(l_ty, r_ty)
192+
&& self.eq_body(l_body, r_body)
193+
},
194+
(ItemKind::Static(l_mut, l_ident, l_ty, l_body), ItemKind::Static(r_mut, r_ident, r_ty, r_body)) => {
195+
l_mut == r_mut && l_ident.name == r_ident.name && self.eq_ty(l_ty, r_ty) && self.eq_body(l_body, r_body)
196+
},
197+
(
198+
ItemKind::Fn {
199+
sig: l_sig,
200+
ident: l_ident,
201+
generics: l_generics,
202+
body: l_body,
203+
has_body: l_has_body,
204+
},
205+
ItemKind::Fn {
206+
sig: r_sig,
207+
ident: r_ident,
208+
generics: r_generics,
209+
body: r_body,
210+
has_body: r_has_body,
211+
},
212+
) => {
213+
l_ident.name == r_ident.name
214+
&& (l_has_body == r_has_body)
215+
&& self.eq_fn_sig(&l_sig, &r_sig)
216+
&& self.eq_generics(l_generics, r_generics)
217+
&& self.eq_body(l_body, r_body)
218+
},
219+
(ItemKind::TyAlias(l_ident, l_generics, l_ty), ItemKind::TyAlias(r_ident, r_generics, r_ty)) => {
220+
l_ident.name == r_ident.name && self.eq_generics(l_generics, r_generics) && self.eq_ty(l_ty, r_ty)
221+
},
222+
(ItemKind::Use(l_path, l_kind), ItemKind::Use(r_path, r_kind)) => {
223+
self.eq_path_segments(l_path.segments, r_path.segments)
224+
&& match (l_kind, r_kind) {
225+
(UseKind::Single(l_ident), UseKind::Single(r_ident)) => l_ident.name == r_ident.name,
226+
(UseKind::Glob, UseKind::Glob) | (UseKind::ListStem, UseKind::ListStem) => true,
227+
_ => false,
228+
}
229+
},
230+
(ItemKind::Mod(l_ident, l_mod), ItemKind::Mod(r_ident, r_mod)) => {
231+
l_ident.name == r_ident.name && over(l_mod.item_ids, r_mod.item_ids, |l, r| self.eq_item(*l, *r))
232+
},
233+
_ => false,
234+
};
235+
if eq {
236+
self.local_items.insert(l.owner_id.to_def_id(), r.owner_id.to_def_id());
237+
}
238+
eq
239+
}
240+
241+
fn eq_fn_sig(&mut self, left: &FnSig<'_>, right: &FnSig<'_>) -> bool {
242+
left.header.safety == right.header.safety
243+
&& left.header.constness == right.header.constness
244+
&& left.header.asyncness == right.header.asyncness
245+
&& left.header.abi == right.header.abi
246+
&& self.eq_fn_decl(left.decl, right.decl)
247+
}
248+
249+
fn eq_fn_decl(&mut self, left: &FnDecl<'_>, right: &FnDecl<'_>) -> bool {
250+
over(left.inputs, right.inputs, |l, r| self.eq_ty(l, r))
251+
&& (match (left.output, right.output) {
252+
(FnRetTy::DefaultReturn(_), FnRetTy::DefaultReturn(_)) => true,
253+
(FnRetTy::Return(l_ty), FnRetTy::Return(r_ty)) => self.eq_ty(l_ty, r_ty),
254+
_ => false,
255+
})
256+
&& left.c_variadic == right.c_variadic
257+
&& left.implicit_self == right.implicit_self
258+
&& left.lifetime_elision_allowed == right.lifetime_elision_allowed
259+
}
260+
261+
fn eq_generics(&mut self, left: &Generics<'_>, right: &Generics<'_>) -> bool {
262+
self.eq_generics_param(left.params, right.params)
263+
&& self.eq_generics_predicate(left.predicates, right.predicates)
264+
}
265+
266+
fn eq_generics_predicate(&mut self, left: &[WherePredicate<'_>], right: &[WherePredicate<'_>]) -> bool {
267+
over(left, right, |l, r| match (l.kind, r.kind) {
268+
(WherePredicateKind::BoundPredicate(l_bound), WherePredicateKind::BoundPredicate(r_bound)) => {
269+
l_bound.origin == r_bound.origin
270+
&& self.eq_ty(l_bound.bounded_ty, r_bound.bounded_ty)
271+
&& self.eq_generics_param(l_bound.bound_generic_params, r_bound.bound_generic_params)
272+
&& self.eq_generics_bound(l_bound.bounds, r_bound.bounds)
273+
},
274+
(WherePredicateKind::RegionPredicate(l_region), WherePredicateKind::RegionPredicate(r_region)) => {
275+
Self::eq_lifetime(l_region.lifetime, r_region.lifetime)
276+
&& self.eq_generics_bound(l_region.bounds, r_region.bounds)
277+
},
278+
(WherePredicateKind::EqPredicate(l_eq), WherePredicateKind::EqPredicate(r_eq)) => {
279+
self.eq_ty(l_eq.lhs_ty, r_eq.lhs_ty)
280+
},
281+
_ => false,
282+
})
283+
}
284+
285+
fn eq_generics_bound(&mut self, left: GenericBounds<'_>, right: GenericBounds<'_>) -> bool {
286+
over(left, right, |l, r| match (l, r) {
287+
(GenericBound::Trait(l_trait), GenericBound::Trait(r_trait)) => {
288+
l_trait.modifiers == r_trait.modifiers
289+
&& self.eq_path(l_trait.trait_ref.path, r_trait.trait_ref.path)
290+
&& self.eq_generics_param(l_trait.bound_generic_params, r_trait.bound_generic_params)
291+
},
292+
(GenericBound::Outlives(l_lifetime), GenericBound::Outlives(r_lifetime)) => {
293+
Self::eq_lifetime(l_lifetime, r_lifetime)
294+
},
295+
(GenericBound::Use(l_capture, _), GenericBound::Use(r_capture, _)) => {
296+
over(l_capture, r_capture, |l, r| match (l, r) {
297+
(PreciseCapturingArgKind::Lifetime(l_lifetime), PreciseCapturingArgKind::Lifetime(r_lifetime)) => {
298+
Self::eq_lifetime(l_lifetime, r_lifetime)
299+
},
300+
(PreciseCapturingArgKind::Param(l_param), PreciseCapturingArgKind::Param(r_param)) => {
301+
l_param.ident == r_param.ident && l_param.res == r_param.res
302+
},
303+
_ => false,
304+
})
305+
},
306+
_ => false,
307+
})
308+
}
309+
310+
fn eq_generics_param(&mut self, left: &[GenericParam<'_>], right: &[GenericParam<'_>]) -> bool {
311+
over(left, right, |l, r| {
312+
(match (l.name, r.name) {
313+
(ParamName::Plain(l_ident), ParamName::Plain(r_ident))
314+
| (ParamName::Error(l_ident), ParamName::Error(r_ident)) => l_ident.name == r_ident.name,
315+
(ParamName::Fresh, ParamName::Fresh) => true,
316+
_ => false,
317+
}) && l.pure_wrt_drop == r.pure_wrt_drop
318+
&& self.eq_generics_param_kind(&l.kind, &r.kind)
319+
&& (matches!(
320+
(l.source, r.source),
321+
(GenericParamSource::Generics, GenericParamSource::Generics)
322+
| (GenericParamSource::Binder, GenericParamSource::Binder)
323+
))
324+
})
325+
}
326+
327+
fn eq_generics_param_kind(&mut self, left: &GenericParamKind<'_>, right: &GenericParamKind<'_>) -> bool {
328+
match (left, right) {
329+
(GenericParamKind::Lifetime { kind: l_kind }, GenericParamKind::Lifetime { kind: r_kind }) => {
330+
match (l_kind, r_kind) {
331+
(LifetimeParamKind::Explicit, LifetimeParamKind::Explicit)
332+
| (LifetimeParamKind::Error, LifetimeParamKind::Error) => true,
333+
(LifetimeParamKind::Elided(l_lifetime_kind), LifetimeParamKind::Elided(r_lifetime_kind)) => {
334+
l_lifetime_kind == r_lifetime_kind
335+
},
336+
_ => false,
337+
}
338+
},
339+
(
340+
GenericParamKind::Type {
341+
default: l_default,
342+
synthetic: l_synthetic,
343+
},
344+
GenericParamKind::Type {
345+
default: r_default,
346+
synthetic: r_synthetic,
347+
},
348+
) => both(*l_default, *r_default, |l, r| self.eq_ty(l, r)) && l_synthetic == r_synthetic,
349+
(
350+
GenericParamKind::Const {
351+
ty: l_ty,
352+
default: l_default,
353+
},
354+
GenericParamKind::Const {
355+
ty: r_ty,
356+
default: r_default,
357+
},
358+
) => self.eq_ty(l_ty, r_ty) && both(*l_default, *r_default, |l, r| self.eq_const_arg(l, r)),
171359
_ => false,
172360
}
173361
}
@@ -563,6 +751,17 @@ impl HirEqInterExpr<'_, '_, '_> {
563751
match (left.res, right.res) {
564752
(Res::Local(l), Res::Local(r)) => l == r || self.locals.get(&l) == Some(&r),
565753
(Res::Local(_), _) | (_, Res::Local(_)) => false,
754+
(Res::Def(l_kind, l), Res::Def(r_kind, r))
755+
if l_kind == r_kind
756+
&& let DefKind::Const
757+
| DefKind::Static { .. }
758+
| DefKind::Fn
759+
| DefKind::TyAlias
760+
| DefKind::Use
761+
| DefKind::Mod = l_kind =>
762+
{
763+
(l == r || self.local_items.get(&l) == Some(&r)) && self.eq_path_segments(left.segments, right.segments)
764+
},
566765
_ => self.eq_path_segments(left.segments, right.segments),
567766
}
568767
}

0 commit comments

Comments
 (0)