Skip to content

[Diagnostics] Diagnose comparisons with '.nan' and suggest using '.isNan' instead #33860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3349,6 +3349,16 @@ ERROR(unordered_adjacent_operators,none,
ERROR(missing_builtin_precedence_group,none,
"broken standard library: missing builtin precedence group %0",
(Identifier))
WARNING(nan_comparison, none,
"comparison with '.nan' using %0 is always %select{false|true}1, use "
"'%2.isNaN' to check if '%3' %select{is not a number|is a number}1",
(Identifier, bool, StringRef, StringRef))
WARNING(nan_comparison_without_isnan, none,
"comparison with '.nan' using %0 is always %select{false|true}1",
(Identifier, bool))
WARNING(nan_comparison_both_nan, none,
"'.nan' %0 '.nan' is always %select{false|true}1",
(StringRef, bool))

// If you change this, also change enum TryKindForDiagnostics.
#define TRY_KIND_SELECT(SUB) "%select{try|try!|try?|await}" #SUB
Expand Down
9 changes: 8 additions & 1 deletion include/swift/AST/Identifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,14 @@ class Identifier {
// Handle the high unicode case out of line.
return isOperatorSlow();
}


// Returns whether this is a standard comparison operator,
// such as '==', '>=' or '!=='.
bool isStandardComparisonOperator() const {
return is("==") || is("!=") || is("===") || is("!==") || is("<") ||
is(">") || is("<=") || is(">=");
}

/// isOperatorStartCodePoint - Return true if the specified code point is a
/// valid start of an operator.
static bool isOperatorStartCodePoint(uint32_t C) {
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ PROTOCOL(StringInterpolationProtocol)
PROTOCOL(AdditiveArithmetic)
PROTOCOL(Differentiable)

PROTOCOL(FloatingPoint)

EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByArrayLiteral, "Array", false)
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByBooleanLiteral, "BooleanLiteralType", true)
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByDictionaryLiteral, "Dictionary", false)
Expand Down
1 change: 1 addition & 0 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5044,6 +5044,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::StringInterpolationProtocol:
case KnownProtocolKind::AdditiveArithmetic:
case KnownProtocolKind::Differentiable:
case KnownProtocolKind::FloatingPoint:
return SpecialProtocol::None;
}

Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/CSDiagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ Type FailureDiagnostic::restoreGenericParameters(
bool FailureDiagnostic::conformsToKnownProtocol(
Type type, KnownProtocolKind protocol) const {
auto &cs = getConstraintSystem();
return constraints::conformsToKnownProtocol(cs, type, protocol);
return constraints::conformsToKnownProtocol(cs.DC, type, protocol);
}

Type RequirementFailure::getOwnerType() const {
Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1767,11 +1767,11 @@ namespace {

auto type = contextualType->lookThroughAllOptionalTypes();
if (conformsToKnownProtocol(
CS, type, KnownProtocolKind::ExpressibleByArrayLiteral))
CS.DC, type, KnownProtocolKind::ExpressibleByArrayLiteral))
return false;

return conformsToKnownProtocol(
CS, type, KnownProtocolKind::ExpressibleByDictionaryLiteral);
CS.DC, type, KnownProtocolKind::ExpressibleByDictionaryLiteral);
};

if (isDictionaryContextualType(contextualType)) {
Expand Down
13 changes: 6 additions & 7 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3930,11 +3930,11 @@ bool constraints::hasAppliedSelf(const OverloadChoice &choice,
doesMemberRefApplyCurriedSelf(baseType, decl);
}

bool constraints::conformsToKnownProtocol(ConstraintSystem &cs, Type type,
bool constraints::conformsToKnownProtocol(DeclContext *dc, Type type,
KnownProtocolKind protocol) {
if (auto *proto =
TypeChecker::getProtocol(cs.getASTContext(), SourceLoc(), protocol))
return (bool)TypeChecker::conformsToProtocol(type, proto, cs.DC);
TypeChecker::getProtocol(dc->getASTContext(), SourceLoc(), protocol))
return (bool)TypeChecker::conformsToProtocol(type, proto, dc);
return false;
}

Expand All @@ -3959,7 +3959,8 @@ Type constraints::isRawRepresentable(
ConstraintSystem &cs, Type type,
KnownProtocolKind rawRepresentableProtocol) {
Type rawTy = isRawRepresentable(cs, type);
if (!rawTy || !conformsToKnownProtocol(cs, rawTy, rawRepresentableProtocol))
if (!rawTy ||
!conformsToKnownProtocol(cs.DC, rawTy, rawRepresentableProtocol))
return Type();

return rawTy;
Expand Down Expand Up @@ -4253,9 +4254,7 @@ bool constraints::isStandardComparisonOperator(ASTNode node) {
if (!expr) return false;

if (auto opName = getOperatorName(expr)) {
return opName->is("==") || opName->is("!=") || opName->is("===") ||
opName->is("!==") || opName->is("<") || opName->is(">") ||
opName->is("<=") || opName->is(">=");
return opName->isStandardComparisonOperator();
}
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -5505,7 +5505,7 @@ bool hasAppliedSelf(const OverloadChoice &choice,
llvm::function_ref<Type(Type)> getFixedType);

/// Check whether type conforms to a given known protocol.
bool conformsToKnownProtocol(ConstraintSystem &cs, Type type,
bool conformsToKnownProtocol(DeclContext *dc, Type type,
KnownProtocolKind protocol);

/// Check whether given type conforms to `RawPepresentable` protocol
Expand Down
130 changes: 129 additions & 1 deletion lib/Sema/MiscDiagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
//===----------------------------------------------------------------------===//

#include "MiscDiagnostics.h"
#include "TypeChecker.h"
#include "ConstraintSystem.h"
#include "TypeCheckAvailability.h"
#include "TypeChecker.h"
#include "swift/AST/ASTWalker.h"
#include "swift/AST/NameLookup.h"
#include "swift/AST/NameLookupRequests.h"
Expand All @@ -34,6 +35,7 @@

#define DEBUG_TYPE "Sema"
using namespace swift;
using namespace constraints;

/// Return true if this expression is an implicit promotion from T to T?.
static Expr *isImplicitPromotionToOptional(Expr *E) {
Expand Down Expand Up @@ -4438,6 +4440,131 @@ static void diagnoseExplicitUseOfLazyVariableStorage(const Expr *E,
const_cast<Expr *>(E)->walk(Walker);
}

static void diagnoseComparisonWithNaN(const Expr *E, const DeclContext *DC) {
class ComparisonWithNaNFinder : public ASTWalker {
const ASTContext &C;
const DeclContext *DC;

public:
ComparisonWithNaNFinder(const DeclContext *dc)
: C(dc->getASTContext()), DC(dc) {}

void tryDiagnoseComparisonWithNaN(BinaryExpr *BE) {
ValueDecl *comparisonDecl = nullptr;

// Comparison functions like == or <= take two arguments.
if (BE->getArg()->getNumElements() != 2) {
return;
}

// Dig out the function declaration.
if (auto Fn = BE->getFn()) {
if (auto DSCE = dyn_cast<DotSyntaxCallExpr>(Fn)) {
comparisonDecl = DSCE->getCalledValue();
} else {
comparisonDecl = BE->getCalledValue();
}
}

// Bail out if it isn't a function.
if (!comparisonDecl || !isa<FuncDecl>(comparisonDecl)) {
return;
}

// We're only interested in comparison functions like == or <=.
auto comparisonDeclName = comparisonDecl->getBaseIdentifier();
if (!comparisonDeclName.isStandardComparisonOperator()) {
return;
}

auto firstArg = BE->getArg()->getElement(0);
auto secondArg = BE->getArg()->getElement(1);

// Both arguments must conform to FloatingPoint protocol.
if (!conformsToKnownProtocol(const_cast<DeclContext *>(DC),
firstArg->getType(),
KnownProtocolKind::FloatingPoint) ||
!conformsToKnownProtocol(const_cast<DeclContext *>(DC),
secondArg->getType(),
KnownProtocolKind::FloatingPoint)) {
return;
}

// Convenience utility to extract argument decl.
auto extractArgumentDecl = [&](Expr *arg) -> ValueDecl * {
if (auto DRE = dyn_cast<DeclRefExpr>(arg)) {
return DRE->getDecl();
} else if (auto MRE = dyn_cast<MemberRefExpr>(arg)) {
return MRE->getMember().getDecl();
}
return nullptr;
};

// Dig out the declarations for the arguments.
auto *firstVal = extractArgumentDecl(firstArg);
auto *secondVal = extractArgumentDecl(secondArg);

// If we can't find declarations for both arguments, bail out,
// because one of them has to be '.nan'.
if (!firstArg && !secondArg) {
return;
}

// Convenience utility to check if this is a 'nan' variable.
auto isNanDecl = [&](ValueDecl *VD) {
return VD && isa<VarDecl>(VD) && VD->getBaseIdentifier().is("nan");
};

// Diagnose comparison with '.nan'.
//
// If the comparison is done using '<=', '<', '==', '>', '>=', then
// the result is always false. If the comparison is done using '!=',
// then the result is always true.
//
// Emit a different diagnostic which doesn't mention using '.isNaN' if
// the comparison isn't done using '==' or '!=' or if both sides are
// '.nan'.
if (isNanDecl(firstVal) && isNanDecl(secondVal)) {
C.Diags.diagnose(BE->getLoc(), diag::nan_comparison_both_nan,
comparisonDeclName.str(), comparisonDeclName.is("!="));
} else if (isNanDecl(firstVal) || isNanDecl(secondVal)) {
if (comparisonDeclName.is("==") || comparisonDeclName.is("!=")) {
auto exprStr =
C.SourceMgr
.extractText(Lexer::getCharSourceRangeFromSourceRange(
C.SourceMgr, firstArg->getSourceRange()))
.str();
auto prefix = exprStr;
if (comparisonDeclName.is("!=")) {
prefix = "!" + prefix;
}
C.Diags.diagnose(BE->getLoc(), diag::nan_comparison,
comparisonDeclName, comparisonDeclName.is("!="),
prefix, exprStr);
} else {
C.Diags.diagnose(BE->getLoc(), diag::nan_comparison_without_isnan,
comparisonDeclName, comparisonDeclName.is("!="));
}
}
}

std::pair<bool, Expr *> walkToExprPre(Expr *E) override {
if (!E || isa<ErrorExpr>(E) || !E->getType())
return {false, E};

if (auto *BE = dyn_cast<BinaryExpr>(E)) {
tryDiagnoseComparisonWithNaN(BE);
return {false, E};
}

return {true, E};
}
};

ComparisonWithNaNFinder Walker(DC);
const_cast<Expr *>(E)->walk(Walker);
}

//===----------------------------------------------------------------------===//
// High-level entry points.
//===----------------------------------------------------------------------===//
Expand All @@ -4454,6 +4581,7 @@ void swift::performSyntacticExprDiagnostics(const Expr *E,
diagnoseUnintendedOptionalBehavior(E, DC);
maybeDiagnoseCallToKeyValueObserveMethod(E, DC);
diagnoseExplicitUseOfLazyVariableStorage(E, DC);
diagnoseComparisonWithNaN(E, DC);
if (!ctx.isSwiftVersionAtLeast(5))
diagnoseDeprecatedWritableKeyPath(E, DC);
if (!ctx.LangOpts.DisableAvailabilityChecking)
Expand Down
31 changes: 31 additions & 0 deletions test/decl/var/nan_comparisons.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: %target-typecheck-verify-swift

//////////////////////////////////////////////////////////////////////////////////////////////////
/////// Comparison with '.nan' static property instead of using '.isNaN' instance property ///////
//////////////////////////////////////////////////////////////////////////////////////////////////

// One side is '.nan' and the other isn't.
// Using '==' or '!=' for comparison should suggest using '.isNaN'.

let double: Double = 0.0
_ = double == .nan // expected-warning {{comparison with '.nan' using '==' is always false, use 'double.isNaN' to check if 'double' is not a number}}
_ = double != .nan // expected-warning {{comparison with '.nan' using '!=' is always true, use '!double.isNaN' to check if 'double' is a number}}
_ = 0.0 == .nan // // expected-warning {{comparison with '.nan' using '==' is always false, use '0.0.isNaN' to check if '0.0' is not a number}}

// One side is '.nan' and the other isn't. Using '>=', '>', '<', '<=' for comparison:
// We can't suggest using '.isNaN' here.

_ = 0.0 >= .nan // expected-warning {{comparison with '.nan' using '>=' is always false}}
_ = .nan > 1.1 // expected-warning {{comparison with '.nan' using '>' is always false}}
_ = .nan < 2.2 // expected-warning {{comparison with '.nan' using '<' is always false}}
_ = 3.3 <= .nan // expected-warning {{comparison with '.nan' using '<=' is always false}}

// Both sides are '.nan':
// We can't suggest using '.isNaN' here.

_ = Double.nan == Double.nan // expected-warning {{'.nan' == '.nan' is always false}}
_ = Double.nan != Double.nan // expected-warning {{'.nan' != '.nan' is always true}}
_ = Double.nan < Double.nan // expected-warning {{'.nan' < '.nan' is always false}}
_ = Double.nan <= Double.nan // expected-warning {{'.nan' <= '.nan' is always false}}
_ = Double.nan > Double.nan // expected-warning {{'.nan' > '.nan' is always false}}
_ = Double.nan >= Double.nan // expected-warning {{'.nan' >= '.nan' is always false}}