diff --git a/.github/workflows/new-release.yml b/.github/workflows/new-release.yml index 2dee326d..ba5ffe3b 100644 --- a/.github/workflows/new-release.yml +++ b/.github/workflows/new-release.yml @@ -18,8 +18,8 @@ jobs: steps: - name: validate version format run: | - if [[ ! "${{ github.event.inputs.version }}" == *"."* ]]; then - echo "Error: Version must contain a '.'" + if [[ ! "${{ github.event.inputs.version }}" =~ ^0\.[0-9]{3,}$ ]]; then + echo "Error: Version must be in the format 0.XYZ" exit 1 fi - name: create release diff --git a/Analysis/include/Luau/AstUtils.h b/Analysis/include/Luau/AstUtils.h index 0be03925..b914cedc 100644 --- a/Analysis/include/Luau/AstUtils.h +++ b/Analysis/include/Luau/AstUtils.h @@ -9,8 +9,8 @@ namespace Luau { -// Search through the expression 'expr' for types that are known to represent -// uniquely held references. Append these types to 'uniqueTypes'. +// Search through the expression 'expr' for typeArguments that are known to represent +// uniquely held references. Append these typeArguments to 'uniqueTypes'. void findUniqueTypes(NotNull> uniqueTypes, AstExpr* expr, NotNull> astTypes); void findUniqueTypes( diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 5b129dd9..97dcc139 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -99,6 +99,9 @@ struct FunctionCallConstraint class AstExprCall* callSite = nullptr; std::vector> discriminantTypes; + std::vector typeArguments; + std::vector typePackArguments; + // When we dispatch this constraint, we update the key at this map to record // the overload that we selected. DenseHashMap* astOverloadResolvedTypes = nullptr; @@ -292,6 +295,16 @@ struct PushFunctionTypeConstraint bool isSelf; }; +// Binds the function to a set of explicitly specified types, +// for f<>. +struct TypeInstantiationConstraint +{ + TypeId functionType; + TypeId placeholderType; + std::vector typeArguments; + std::vector typePackArguments; +}; + struct PushTypeConstraint { TypeId expectedType; @@ -321,7 +334,8 @@ using ConstraintV = Variant< EqualityConstraint, SimplifyConstraint, PushFunctionTypeConstraint, - PushTypeConstraint>; + PushTypeConstraint, + TypeInstantiationConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index 8394de6c..f88d488a 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -333,6 +333,7 @@ struct ConstraintGenerator Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); Inference check(const ScopePtr& scope, AstExprInterpString* interpString); + Inference check(const ScopePtr& scope, AstExprInstantiate* explicitTypeInstantiation); Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); std::tuple checkBinary( const ScopePtr& scope, @@ -482,6 +483,8 @@ struct ConstraintGenerator void fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block); + std::pair, std::vector> resolveTypeArguments(const ScopePtr& scope, const AstArray& typeArguments); + /** Given a function type annotation, return a vector describing the expected types of the calls to the function * For example, calling a function with annotation ((number) -> string & ((string) -> number)) * yields a vector of size 1, with value: [number | string] diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 2cdfe584..6f9738ec 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -241,7 +241,7 @@ struct ConstraintSolver bool tryDispatch(const FunctionCheckConstraint& c, NotNull constraint, bool force); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); - + bool tryDispatch(const TypeInstantiationConstraint& c, NotNull constraint); bool tryDispatchHasIndexer( int& recursionDepth, @@ -470,6 +470,14 @@ struct ConstraintSolver TypeId simplifyUnion(NotNull scope, Location location, TypeId left, TypeId right); + TypeId instantiateFunctionType( + TypeId functionTypeId, + const std::vector& typeArguments, + const std::vector& typePackArguments, + NotNull scope, + const Location& location + ); + TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp); void throwTimeLimitError() const; diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index 6d544c16..7a2435a2 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -189,6 +189,7 @@ struct DataFlowGraphBuilder DataFlowResult visitExpr(AstExprTypeAssertion* t); DataFlowResult visitExpr(AstExprIfElse* i); DataFlowResult visitExpr(AstExprInterpString* i); + DataFlowResult visitExpr(AstExprInstantiate* i); DataFlowResult visitExpr(AstExprError* error); void visitLValue(AstExpr* e, DefId incomingDef); diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 62fad5d9..3952043f 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -535,6 +535,36 @@ struct GenericBoundsMismatch bool operator==(const GenericBoundsMismatch& rhs) const; }; +// Used `f<>` where f is not a function +struct InstantiateGenericsOnNonFunction +{ + enum class InterestingEdgeCase + { + None, + MetatableCall, + Intersection, + }; + + InterestingEdgeCase interestingEdgeCase; + + bool operator==(const InstantiateGenericsOnNonFunction&) const; +}; + +// Provided too many generics inside `f<>` +struct TypeInstantiationCountMismatch +{ + std::optional functionName; + TypeId functionType; + + size_t providedTypes = 0; + size_t maximumTypes = 0; + + size_t providedTypePacks = 0; + size_t maximumTypePacks = 0; + + bool operator==(const TypeInstantiationCountMismatch&) const; +}; + // Error when referencing a type function without providing explicit generics. // // type function create_table_with_key() @@ -609,8 +639,9 @@ using TypeErrorData = Variant< MultipleNonviableOverloads, RecursiveRestraintViolation, GenericBoundsMismatch, - UnappliedTypeFunction>; - + UnappliedTypeFunction, + InstantiateGenericsOnNonFunction, + TypeInstantiationCountMismatch>; struct TypeErrorSummary { diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index dceecca9..128f4c17 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -140,6 +140,7 @@ std::string dump(const std::optional& ty); std::string dump(TypePackId ty); std::string dump(const std::optional& ty); std::string dump(const std::vector& types); +std::string dump(const std::vector& types); std::string dump(DenseHashMap& types); std::string dump(DenseHashMap& types); @@ -163,4 +164,5 @@ inline std::string toString(const TypeOrPack& tyOrTp) std::string dump(const TypeOrPack& tyOrTp); std::string toStringVector(const std::vector& types, ToStringOptions& opts); +std::string toStringVector(const std::vector& typePacks, ToStringOptions& opts); } // namespace Luau diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index 96cc18d5..352c93e7 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -173,6 +173,7 @@ struct TypeChecker2 void visit(AstExprTypeAssertion* expr); void visit(AstExprIfElse* expr); void visit(AstExprInterpString* interpString); + void visit(AstExprInstantiate* explicitTypeInstantiation); void visit(AstExprError* expr); TypeId flattenPack(TypePackId pack); void visitGenerics(AstArray generics, AstArray genericPacks); @@ -233,6 +234,8 @@ struct TypeChecker2 void suggestAnnotations(AstExprFunction* expr, TypeId ty); + void checkTypeInstantiation(AstExpr* baseFunctionExpr, TypeId fnType, const Location& location, const AstArray& typeArguments); + void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const; bool isErrorSuppressing(Location loc, TypeId ty); bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2); diff --git a/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h b/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h index 4dc57c69..c1315135 100644 --- a/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h +++ b/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h @@ -31,6 +31,9 @@ struct TypeFunctionRuntimeBuilderState }; TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state); +TypeFunctionTypePackId serialize(TypePackId tp, TypeFunctionRuntimeBuilderState* state); + TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state); +TypePackId deserialize(TypeFunctionTypePackId tp, TypeFunctionRuntimeBuilderState* state); } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 6628df52..a1b45590 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -134,6 +134,7 @@ struct TypeChecker WithPredicate checkExpr(const ScopePtr& scope, const AstExprError& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); WithPredicate checkExpr(const ScopePtr& scope, const AstExprInterpString& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprInstantiate& expr); TypeId checkExprTable( const ScopePtr& scope, @@ -227,6 +228,14 @@ struct TypeChecker const std::vector>& expectedTypes = {} ); + TypeId instantiateTypeParameters( + const ScopePtr& scope, + TypeId baseType, + const AstArray& explicitTypes, + const AstExpr* functionExpr, + const Location& location + ); + static std::optional matchRequire(const AstExprCall& call); TypeId checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location); diff --git a/Analysis/src/BuiltinTypeFunctions.cpp b/Analysis/src/BuiltinTypeFunctions.cpp index a99d0a7a..202d2021 100644 --- a/Analysis/src/BuiltinTypeFunctions.cpp +++ b/Analysis/src/BuiltinTypeFunctions.cpp @@ -22,7 +22,7 @@ LUAU_DYNAMIC_FASTINT(LuauTypeFamilyApplicationCartesianProductLimit) LUAU_DYNAMIC_FASTINTVARIABLE(LuauStepRefineRecursionLimit, 64) LUAU_FASTFLAG(LuauReduceSetTypeStackPressure) -LUAU_FASTFLAGVARIABLE(LuauRefineNoRefineAlways) +LUAU_FASTFLAGVARIABLE(LuauRefineNoRefineAlways2) LUAU_FASTFLAGVARIABLE(LuauRefineDistributesOverUnions) LUAU_FASTFLAG(LuauEGFixGenericsList) LUAU_FASTFLAG(LuauNoMoreComparisonTypeFunctions) @@ -1304,28 +1304,38 @@ TypeFunctionReductionResult refineTypeFunction( } std::vector discriminantTypes; - for (size_t i = 1; i < typeParams.size(); i++) - discriminantTypes.push_back(follow(typeParams.at(i))); - - if (FFlag::LuauRefineNoRefineAlways) + if (FFlag::LuauRefineNoRefineAlways2) { - bool hasAnyRealRefinements = false; - for (auto discriminant : discriminantTypes) + for (size_t i = 1; i < typeParams.size(); i++) { + auto discriminant = follow(typeParams[i]); + + // Filter out any top level types that are meaningless to refine + // against. + if (is(discriminant)) + continue; + // If the discriminant type is only: - // - The `*no-refine*` type or, - // - tables, metatables, unions, intersections, functions, or negations _containing_ `*no-refine*`. + // - The `*no-refine*` type (covered above) or; + // - tables, metatables, unions, intersections, functions, or + // negations containing `*no-refine*` (covered below). // There's no point in refining against it. ContainsRefinableType crt; crt.traverse(discriminant); - hasAnyRealRefinements = hasAnyRealRefinements || crt.found; + if (crt.found) + discriminantTypes.push_back(discriminant); } // if we don't have any real refinements, i.e. they're all `*no-refine*`, then we can reduce immediately. - if (!hasAnyRealRefinements) + if (discriminantTypes.empty()) return {targetTy, {}}; } + else + { + for (size_t i = 1; i < typeParams.size(); i++) + discriminantTypes.push_back(follow(typeParams.at(i))); + } const bool targetIsPending = isBlockedOrUnsolvedType(targetTy); @@ -1385,8 +1395,8 @@ TypeFunctionReductionResult refineTypeFunction( } else { - // FFlag::LuauRefineNoRefineAlways moves this check upwards so that it runs even if the thing being refined is pending. - if (!FFlag::LuauRefineNoRefineAlways) + // FFlag::LuauRefineNoRefineAlways2 moves this check upwards so that it runs even if the thing being refined is pending. + if (!FFlag::LuauRefineNoRefineAlways2) { // If the discriminant type is only: // - The `*no-refine*` type or, diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 6810e848..46c2a519 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -40,6 +40,7 @@ LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTINTVARIABLE(LuauPrimitiveInferenceInTableLimit, 500) LUAU_FASTFLAG(LuauEmplaceNotPushBack) LUAU_FASTFLAG(LuauReduceSetTypeStackPressure) +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) LUAU_FASTFLAG(DebugLuauStringSingletonBasedOnQuotes) LUAU_FASTFLAGVARIABLE(LuauPushTypeConstraint2) LUAU_FASTFLAGVARIABLE(LuauEGFixGenericsList) @@ -440,7 +441,7 @@ std::optional ConstraintGenerator::lookup(const ScopePtr& scope, Locatio for (DefId operand : phi->operands) { // `scope->lookup(operand)` may return nothing because we only bind a type to that operand - // once we've seen that particular `DefId`. In this case, we need to prototype those types + // once we've seen that particular `DefId`. In this case, we need to prototype those typeArguments // and use those at a later time. std::optional ty = lookup(scope, location, operand, /*prototype*/ false); if (!ty) @@ -608,8 +609,8 @@ namespace /* * Constraint generation may be called upon to simplify an intersection or union - * of types that are not sufficiently solved yet. We use - * FindSimplificationBlockers to recognize these types and defer the + * of typeArguments that are not sufficiently solved yet. We use + * FindSimplificationBlockers to recognize these typeArguments and defer the * simplification until constraint solution. */ struct FindSimplificationBlockers : TypeOnceVisitor @@ -645,7 +646,7 @@ struct FindSimplificationBlockers : TypeOnceVisitor } // We do not need to know anything at all about a function's argument or - // return types in order to simplify it in an intersection or union. + // return typeArguments in order to simplify it in an intersection or union. bool visit(TypeId, const FunctionType&) override { return false; @@ -723,8 +724,8 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat // IntersectConstraint. // For each discriminant ty, we accumulated it onto ty, creating a longer and longer // sequence of refine constraints. On every loop of this we called mustDeferIntersection. - // For sufficiently large types, we would blow the stack. - // Instead, we record all the discriminant types in sequence + // For sufficiently large typeArguments, we would blow the stack. + // Instead, we record all the discriminant typeArguments in sequence // and then dispatch a single refine constraint with multiple arguments. This helps us avoid // the potentially expensive check on mustDeferIntersection std::vector discriminants; @@ -1293,7 +1294,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat scope->importedTypeBindings[name] = module->exportedTypeBindings; scope->importedModules[name] = moduleInfo->name; - // Imported types of requires that transitively refer to current module have to be replaced with 'any' + // Imported typeArguments of requires that transitively refer to current module have to be replaced with 'any' for (const auto& [location, path] : requireCycles) { if (path.empty() || path.front() != moduleInfo->name) @@ -1730,14 +1731,14 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* ass if (head.size() >= assign->vars.size) { // If the resultPack is definitely long enough for each variable, we can - // skip the UnpackConstraint and use the result types directly. + // skip the UnpackConstraint and use the result typeArguments directly. for (size_t i = 0; i < assign->vars.size; ++i) valueTypes.push_back(head[i]); } else { - // We're not sure how many types are produced by the right-side + // We're not sure how many typeArguments are produced by the right-side // expressions. We'll use an UnpackConstraint to defer this until // later. for (size_t i = 0; i < assign->vars.size; ++i) @@ -2048,7 +2049,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareExte return ControlFlow::None; } - // We don't have generic extern types, so this assertion _should_ never be hit. + // We don't have generic extern typeArguments, so this assertion _should_ never be hit. LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); superTy = follow(lookupType->type); @@ -2109,7 +2110,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareExte bool assignToMetatable = isMetamethod(propName); - // Function types always take 'self', but this isn't reflected in the + // Function typeArguments always take 'self', but this isn't reflected in the // parsed annotation. Add it here. if (prop.isMethod) { @@ -2144,7 +2145,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareExte if (auto readTy = prop.readTy) { // We special-case this logic to keep the intersection flat; otherwise we - // would create a ton of nested intersection types. + // would create a ton of nested intersection typeArguments. if (const IntersectionType* itv = get(*readTy)) { std::vector options = itv->parts; @@ -2171,7 +2172,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareExte if (auto writeTy = prop.writeTy) { // We special-case this logic to keep the intersection flat; otherwise we - // would create a ton of nested intersection types. + // would create a ton of nested intersection typeArguments. if (const IntersectionType* itv = get(*writeTy)) { std::vector options = itv->parts; @@ -2582,12 +2583,16 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* TypePackId argPack = addTypePack(std::move(args), argTail); FunctionType ftv(TypeLevel{}, argPack, rets, std::nullopt, call->self); + auto [explicitTypeIds, explicitTypePackIds] = FFlag::LuauExplicitTypeExpressionInstantiation && call->typeArguments.size + ? resolveTypeArguments(scope, call->typeArguments) + : std::pair, std::vector>(); + /* * To make bidirectional type checking work, we need to solve these constraints in a particular order: * * 1. Solve the function type - * 2. Propagate type information from the function type to the argument types - * 3. Solve the argument types + * 2. Propagate type information from the function type to the argument typeArguments + * 3. Solve the argument typeArguments * 4. Solve the call */ @@ -2614,6 +2619,8 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* rets, call, std::move(discriminantTypes), + std::move(explicitTypeIds), + std::move(explicitTypePackIds), &module->astOverloadResolvedTypes, } ); @@ -2705,6 +2712,11 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExpr* expr, std:: result = check(scope, typeAssert); else if (auto interpString = expr->as()) result = check(scope, interpString); + else if (auto explicitTypeInstantiation = expr->as()) + { + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + result = check(scope, explicitTypeInstantiation); + } else if (auto err = expr->as()) { // Open question: Should we traverse into this? @@ -3342,6 +3354,61 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprInterpString* return Inference{builtinTypes->stringType}; } +Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprInstantiate* explicitTypeInstantiation) +{ + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + + TypeId functionType = check(scope, explicitTypeInstantiation->expr, std::nullopt).ty; + + auto [explicitTypeIds, explicitTypePackIds] = resolveTypeArguments(scope, explicitTypeInstantiation->typeArguments); + + TypeId placeholderType = arena->addType(BlockedType{}); + + NotNull constraint = addConstraint( + scope, + explicitTypeInstantiation->location, + TypeInstantiationConstraint{functionType, placeholderType, std::move(explicitTypeIds), std::move(explicitTypePackIds)} + ); + + getMutable(placeholderType)->setOwner(constraint); + + return Inference{placeholderType}; +} + +std::pair, std::vector> ConstraintGenerator::resolveTypeArguments( + const ScopePtr& scope, + const AstArray& typeArguments +) +{ + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + + std::vector resolvedTypeArguments; + std::vector resolvedTypePackArguments; + + for (const AstTypeOrPack& typeOrPack : typeArguments) + { + if (typeOrPack.type) + { + resolvedTypeArguments.push_back(resolveType( + scope, + typeOrPack.type, + /* inTypeArguments = */ false + )); + } + else + { + LUAU_ASSERT(typeOrPack.typePack); + resolvedTypePackArguments.push_back(resolveTypePack( + scope, + typeOrPack.typePack, + /* inTypeArguments = */ false + )); + } + } + + return {std::move(resolvedTypeArguments), std::move(resolvedTypePackArguments)}; +} + std::tuple ConstraintGenerator::checkBinary( const ScopePtr& scope, AstExprBinary::Op op, @@ -3626,10 +3693,10 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, for (const AstExprTable::Item& item : expr->items) { - // Expected types are threaded through table literals separately via the + // Expected typeArguments are threaded through table literals separately via the // function matchLiteralType. - // generalize is false here as we want to be able to push types into lambdas in a situation like: + // generalize is false here as we want to be able to push typeArguments into lambdas in a situation like: // // type Callback = (string) -> () // @@ -3774,7 +3841,7 @@ ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignatu std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); // We do not support default values on function generics, so we only - // care about the types involved. + // care about the typeArguments involved. for (const auto& [name, g] : genericDefinitions) { genericTypes.push_back(g.ty); @@ -3924,7 +3991,7 @@ ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignatu } else { - // Some of the types in argTypes will eventually be generics, and some + // Some of the typeArguments in argTypes will eventually be generics, and some // will not. The ones that are not generic will be pruned when // GeneralizationConstraint dispatches. genericTypes.insert(genericTypes.begin(), argTypes.begin(), argTypes.end()); @@ -3943,7 +4010,7 @@ ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignatu TypePackId annotatedRetType = resolveTypePack(signatureScope, fn->returnAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); // We bind the annotated type directly here so that, when we need to - // generate constraints for return types, we have a guarantee that we + // generate constraints for return typeArguments, we have a guarantee that we // know the annotated return type already, if one was provided. LUAU_ASSERT(get(returnType)); emplaceTypePack(asMutable(returnType), annotatedRetType); @@ -4050,7 +4117,7 @@ TypeId ConstraintGenerator::resolveReferenceType( for (const AstTypeOrPack& p : ref->parameters) { - // We do not enforce the ordering of types vs. type packs here; + // We do not enforce the ordering of typeArguments vs. type packs here; // that is done in the parser. if (p.type) { @@ -4060,7 +4127,7 @@ TypeId ConstraintGenerator::resolveReferenceType( { TypePackId tp = resolveTypePack_(scope, p.typePack, /*inTypeArguments*/ true); - // If we need more regular types, we can use single element type packs to fill those in + // If we need more regular typeArguments, we can use single element type packs to fill those in if (parameters.size() < alias->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) parameters.push_back(*first(tp)); else @@ -4779,7 +4846,7 @@ std::vector> ConstraintGenerator::getExpectedCallTypesForF } } - // TODO vvijay Feb 24, 2023 apparently we have to demote the types here? + // TODO vvijay Feb 24, 2023 apparently we have to demote the typeArguments here? return expectedTypes; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 704841f0..eb98e8fa 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -45,14 +45,15 @@ LUAU_FASTFLAGVARIABLE(LuauDontDynamicallyCreateRedundantSubtypeConstraints) LUAU_FASTFLAGVARIABLE(LuauExtendSealedTableUpperBounds) LUAU_FASTFLAG(LuauReduceSetTypeStackPressure) LUAU_FASTFLAG(DebugLuauStringSingletonBasedOnQuotes) +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) LUAU_FASTFLAG(LuauPushTypeConstraint2) -LUAU_FASTFLAGVARIABLE(LuauScopedSeenSetInLookupTableProp) LUAU_FASTFLAGVARIABLE(LuauIterableBindNotUnify) LUAU_FASTFLAGVARIABLE(LuauAvoidOverloadSelectionForFunctionType) LUAU_FASTFLAG(LuauSimplifyIntersectionNoTreeSet) LUAU_FASTFLAG(LuauInstantiationUsesGenericPolarity) LUAU_FASTFLAG(LuauPushTypeConstraintLambdas2) LUAU_FASTFLAGVARIABLE(LuauPushTypeConstriantAlwaysCompletes) +LUAU_FASTFLAG(LuauMarkUnscopedGenericsAsSolved) namespace Luau { @@ -881,6 +882,11 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*sc, constraint, force); else if (auto pftc = get(*constraint)) success = tryDispatch(*pftc, constraint); + else if (auto esgc = get(*constraint)) + { + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + success = tryDispatch(*esgc, constraint); + } else if (auto ptc = get(*constraint)) success = tryDispatch(*ptc, constraint, force); else @@ -1534,6 +1540,14 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, constraint->location); + } + } + fillInDiscriminantTypes(constraint, c.discriminantTypes); OverloadResolver resolver{ @@ -2658,7 +2672,11 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNulllocation); for (TypeId ity : result.irreducibleTypes) + { uninhabitedTypeFunctions.insert(ity); + if (FFlag::LuauMarkUnscopedGenericsAsSolved) + unblock(ity, constraint->location); + } bool reductionFinished = result.blockedTypes.empty() && result.blockedPacks.empty(); @@ -2886,6 +2904,68 @@ bool ConstraintSolver::tryDispatch(const PushFunctionTypeConstraint& c, NotNull< return true; } +bool ConstraintSolver::tryDispatch(const TypeInstantiationConstraint& c, NotNull constraint) +{ + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + + bind( + constraint, + c.placeholderType, + instantiateFunctionType(c.functionType, c.typeArguments, c.typePackArguments, constraint->scope, constraint->location) + ); + + return true; +} + +TypeId ConstraintSolver::instantiateFunctionType( + TypeId functionTypeId, + const std::vector& typeArguments, + const std::vector& typePackArguments, + NotNull scope, + const Location& location +) +{ + const FunctionType* ftv = get(follow(functionTypeId)); + if (!ftv) + { + return functionTypeId; + } + + DenseHashMap replacements{nullptr}; + auto typeParametersIter = ftv->generics.begin(); + + for (const TypeId typeArgument : typeArguments) + { + if (typeParametersIter == ftv->generics.end()) + { + break; + } + + replacements[*typeParametersIter++] = typeArgument; + } + + while (typeParametersIter != ftv->generics.end()) + { + replacements[*typeParametersIter++] = freshType(arena, builtinTypes, scope, Polarity::Mixed); + } + + DenseHashMap replacementPacks{nullptr}; + auto typePackParametersIter = ftv->genericPacks.begin(); + + for (const TypePackId typePackArgument : typePackArguments) + { + if (typePackParametersIter == ftv->genericPacks.end()) + { + break; + } + + replacementPacks[*typePackParametersIter++] = typePackArgument; + } + + Replacer replacer{arena, std::move(replacements), std::move(replacementPacks)}; + return replacer.substitute(functionTypeId).value_or(builtinTypes->errorType); +} + bool ConstraintSolver::tryDispatch(const PushTypeConstraint& c, NotNull constraint, bool force) { LUAU_ASSERT(FFlag::LuauPushTypeConstraint2); @@ -3152,12 +3232,7 @@ TablePropLookupResult ConstraintSolver::lookupTableProp( if (seen.contains(subjectType)) return {}; - std::optional, TypeId>> ss; // This won't be needed once LuauScopedSeenSetInLookupTableProp is clipped. - - if (FFlag::LuauScopedSeenSetInLookupTableProp) - ss.emplace(seen, subjectType); - else - seen.insert(subjectType); + ScopedSeenSet, TypeId> ss{seen, subjectType}; subjectType = follow(subjectType); diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index 567ae624..7cdf3d9c 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -13,6 +13,7 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) namespace Luau { @@ -849,6 +850,11 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExpr* e) return visitExpr(i); else if (auto i = e->as()) return visitExpr(i); + else if (auto i = e->as()) + { + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + return visitExpr(i); + } else if (auto error = e->as()) return visitExpr(error); else @@ -1066,6 +1072,27 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprInterpString* i) return {defArena->freshCell(Symbol{}, i->location), nullptr}; } +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprInstantiate* i) +{ + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + + for (const AstTypeOrPack& typeOrPack : i->typeArguments) + { + if (typeOrPack.type) + { + visitType(typeOrPack.type); + } + else + { + LUAU_ASSERT(typeOrPack.typePack); + visitTypePack(typeOrPack.typePack); + } + } + + return visitExpr(i->expr); +} + + DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprError* error) { DfgScope* unreachable = makeChildScope(); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index be846afe..a4c568df 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -2,6 +2,7 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAGVARIABLE(LuauTypeCheckerVectorLerp2) +LUAU_FASTFLAGVARIABLE(LuauTypeCheckerMathIsNanInfFinite) LUAU_FASTFLAGVARIABLE(LuauUseTopTableForTableClearAndIsFrozen) LUAU_FASTFLAGVARIABLE(LuauMorePermissiveNewtableType) @@ -86,6 +87,61 @@ declare bit32: { static constexpr const char* kBuiltinDefinitionMathSrc = R"BUILTIN_SRC( +declare math: { + frexp: @checked (n: number) -> (number, number), + ldexp: @checked (s: number, e: number) -> number, + fmod: @checked (x: number, y: number) -> number, + modf: @checked (n: number) -> (number, number), + pow: @checked (x: number, y: number) -> number, + exp: @checked (n: number) -> number, + + ceil: @checked (n: number) -> number, + floor: @checked (n: number) -> number, + abs: @checked (n: number) -> number, + sqrt: @checked (n: number) -> number, + + log: @checked (n: number, base: number?) -> number, + log10: @checked (n: number) -> number, + + rad: @checked (n: number) -> number, + deg: @checked (n: number) -> number, + + sin: @checked (n: number) -> number, + cos: @checked (n: number) -> number, + tan: @checked (n: number) -> number, + sinh: @checked (n: number) -> number, + cosh: @checked (n: number) -> number, + tanh: @checked (n: number) -> number, + atan: @checked (n: number) -> number, + acos: @checked (n: number) -> number, + asin: @checked (n: number) -> number, + atan2: @checked (y: number, x: number) -> number, + + min: @checked (number, ...number) -> number, + max: @checked (number, ...number) -> number, + + pi: number, + huge: number, + + randomseed: @checked (seed: number) -> (), + random: @checked (number?, number?) -> number, + + sign: @checked (n: number) -> number, + clamp: @checked (n: number, min: number, max: number) -> number, + noise: @checked (x: number, y: number?, z: number?) -> number, + round: @checked (n: number) -> number, + map: @checked (x: number, inmin: number, inmax: number, outmin: number, outmax: number) -> number, + lerp: @checked (a: number, b: number, t: number) -> number, + + isnan: @checked (x: number) -> boolean, + isinf: @checked (x: number) -> boolean, + isfinite: @checked (x: number) -> boolean, +} + +)BUILTIN_SRC"; + +static constexpr const char* kBuiltinDefinitionMathSrc_DEPRECATED = R"BUILTIN_SRC( + declare math: { frexp: @checked (n: number) -> (number, number), ldexp: @checked (s: number, e: number) -> number, @@ -355,7 +411,14 @@ std::string getBuiltinDefinitionSource() std::string result = kBuiltinDefinitionBaseSrc; result += kBuiltinDefinitionBit32Src; - result += kBuiltinDefinitionMathSrc; + if (FFlag::LuauTypeCheckerMathIsNanInfFinite) + { + result += kBuiltinDefinitionMathSrc; + } + else + { + result += kBuiltinDefinitionMathSrc_DEPRECATED; + } result += kBuiltinDefinitionOsSrc; result += kBuiltinDefinitionCoroutineSrc; if (FFlag::LuauUseTopTableForTableClearAndIsFrozen) diff --git a/Analysis/src/EqSatSimplification.cpp b/Analysis/src/EqSatSimplification.cpp index beae70f0..9204f186 100644 --- a/Analysis/src/EqSatSimplification.cpp +++ b/Analysis/src/EqSatSimplification.cpp @@ -365,7 +365,7 @@ Id toId( return res; }; - if (auto tt = get(ty)) + if (get(ty)) return egraph.add(TImportedTable{ty}); else if (get(ty)) return egraph.add(TOpaque{ty}); diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index dfe0289e..b3c1a92c 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -919,6 +919,84 @@ struct ErrorConverter "\nbut these types are not compatible with one another."; } + std::string operator()(const InstantiateGenericsOnNonFunction& e) const + { + switch (e.interestingEdgeCase) + { + case InstantiateGenericsOnNonFunction::InterestingEdgeCase::None: + return "Cannot instantiate type parameters on something without type parameters."; + case InstantiateGenericsOnNonFunction::InterestingEdgeCase::MetatableCall: + // `__call` is complicated because `f<>()` is interpreted as `f<>` as its own expression that is then called. + // This is so that you can write code like `local f2 = f<>`, and then call `f2()`. + // With metatables, it's not so obvious what this would result in. + return "Luau does not currently support explicitly instantiating a table with a `__call` metamethod. \ + You may be able to work around this by creating a function that calls the table, and using that instead."; + case InstantiateGenericsOnNonFunction::InterestingEdgeCase::Intersection: + return "Luau does not currently support explicitly instantiating an overloaded function type."; + default: + LUAU_ASSERT(false); + return ""; // MSVC exhaustive + } + } + + std::string operator()(const TypeInstantiationCountMismatch& e) const + { + LUAU_ASSERT(e.providedTypes > e.maximumTypes || e.providedTypePacks > e.maximumTypePacks); + + std::string result = "Too many type parameters passed to "; + + if (e.functionName) + { + result += "'"; + result += *e.functionName; + result += "', which is typed as "; + } + else + { + result += "function typed as "; + } + + result += toString(e.functionType); + result += ". Expected "; + + if (e.providedTypes > e.maximumTypes) + { + result += "at most "; + result += std::to_string(e.maximumTypes); + result += " type parameter"; + if (e.maximumTypes != 1) + { + result += "s"; + } + result += ", but "; + result += std::to_string(e.providedTypes); + result += " provided"; + + if (e.providedTypePacks > e.maximumTypePacks) + { + result += ". Also expected "; + } + } + + if (e.providedTypePacks > e.maximumTypePacks) + { + result += "at most "; + result += std::to_string(e.maximumTypePacks); + result += " type pack"; + if (e.maximumTypePacks != 1) + { + result += "s"; + } + result += ", but "; + result += std::to_string(e.providedTypePacks); + result += " provided"; + } + + result += "."; + + return result; + } + std::string operator()(const UnappliedTypeFunction&) const { return "Type functions always require `<>` when referenced."; @@ -1331,6 +1409,17 @@ bool MultipleNonviableOverloads::operator==(const MultipleNonviableOverloads& rh return attemptedArgCount == rhs.attemptedArgCount; } +bool InstantiateGenericsOnNonFunction::operator==(const InstantiateGenericsOnNonFunction& rhs) const +{ + return interestingEdgeCase == rhs.interestingEdgeCase; +} + +bool TypeInstantiationCountMismatch::operator==(const TypeInstantiationCountMismatch& rhs) const +{ + return functionName == rhs.functionName && functionType == rhs.functionType && providedTypes == rhs.providedTypes && + maximumTypes == rhs.maximumTypes && providedTypePacks == rhs.providedTypePacks && maximumTypePacks == rhs.maximumTypePacks; +} + GenericBoundsMismatch::GenericBoundsMismatch(const std::string_view genericName, TypeIds lowerBoundSet, TypeIds upperBoundSet) : genericName(genericName) , lowerBounds(lowerBoundSet.take()) @@ -1589,6 +1678,13 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState) for (auto& upperBound : e.upperBounds) upperBound = clone(upperBound); } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.functionType = clone(e.functionType); + } else if constexpr (std::is_same_v) { } diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 7f4c5547..a006e4dd 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -293,6 +293,13 @@ static void errorToString(std::ostream& stream, const T& err) } stream << "] }"; } + else if constexpr (std::is_same_v) + stream << "InstantiateGenericsOnNonFunctionInstantiateGenericsOnNonFunction { interestingEdgeCase = " << err.interestingEdgeCase << " }"; + else if constexpr (std::is_same_v) + stream << "TypeInstantiationCountMismatch { functionName = " << err.functionName.value_or("") + << ", functionType = " << toString(err.functionType) << ", providedTypes = " << err.providedTypes + << ", maximumTypes = " << err.maximumTypes << ", providedTypePacks = " << err.providedTypePacks + << ", maximumTypePacks = " << err.maximumTypePacks << " }"; else if constexpr (std::is_same_v) stream << "UnappliedTypeFunction {}"; else @@ -310,6 +317,22 @@ std::ostream& operator<<(std::ostream& stream, const CannotAssignToNever::Reason } } +std::ostream& operator<<(std::ostream& stream, const InstantiateGenericsOnNonFunction::InterestingEdgeCase& edgeCase) +{ + switch (edgeCase) + { + case InstantiateGenericsOnNonFunction::InterestingEdgeCase::None: + return stream << "None"; + case InstantiateGenericsOnNonFunction::InterestingEdgeCase::MetatableCall: + return stream << "MetatableCall"; + case InstantiateGenericsOnNonFunction::InterestingEdgeCase::Intersection: + return stream << "Intersection"; + default: + LUAU_ASSERT(false); + return stream << "Unknown"; + } +} + std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data) { auto cb = [&](const auto& e) diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index e5e57baf..af18f9b9 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -18,6 +18,8 @@ LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAGVARIABLE(LuauUnknownGlobalFixSuggestion) +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) + namespace Luau { @@ -190,6 +192,11 @@ static bool similar(AstExpr* lhs, AstExpr* rhs) return true; } + CASE(AstExprInstantiate) + { + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + return similar(le->expr, re->expr); + } else { LUAU_ASSERT(!"Unknown expression type"); diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp index a09b28c0..5d74f776 100644 --- a/Analysis/src/NonStrictTypeChecker.cpp +++ b/Analysis/src/NonStrictTypeChecker.cpp @@ -4,25 +4,28 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Common.h" +#include "Luau/Def.h" +#include "Luau/Error.h" +#include "Luau/Normalize.h" +#include "Luau/RecursionCounter.h" #include "Luau/Simplify.h" -#include "Luau/Type.h" #include "Luau/Subtyping.h" -#include "Luau/Normalize.h" -#include "Luau/Error.h" #include "Luau/TimeTrace.h" +#include "Luau/ToString.h" +#include "Luau/Type.h" #include "Luau/TypeArena.h" #include "Luau/TypeFunction.h" -#include "Luau/Def.h" -#include "Luau/ToString.h" #include "Luau/TypeUtils.h" #include LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTINTVARIABLE(LuauNonStrictTypeCheckerRecursionLimit, 300) LUAU_FASTFLAG(LuauEmplaceNotPushBack) LUAU_FASTFLAGVARIABLE(LuauUnreducedTypeFunctionsDontTriggerWarnings) LUAU_FASTFLAGVARIABLE(LuauNonStrictFetchScopeOnce) +LUAU_FASTFLAGVARIABLE(LuauAddRecursionCounterToNonStrictTypeChecker) namespace Luau { @@ -318,6 +321,14 @@ struct NonStrictTypeChecker NonStrictContext visit(AstStatBlock* block) { + std::optional _rc; + if (FFlag::LuauAddRecursionCounterToNonStrictTypeChecker) + { + _rc.emplace(&nonStrictRecursionCount); + if (FInt::LuauNonStrictTypeCheckerRecursionLimit > 0 && nonStrictRecursionCount >= FInt::LuauNonStrictTypeCheckerRecursionLimit) + return {}; + } + auto StackPusher = pushStack(block); NonStrictContext ctx; @@ -511,6 +522,14 @@ struct NonStrictTypeChecker NonStrictContext visit(AstExpr* expr, ValueContext context) { + std::optional _rc; + if (FFlag::LuauAddRecursionCounterToNonStrictTypeChecker) + { + _rc.emplace(&nonStrictRecursionCount); + if (FInt::LuauNonStrictTypeCheckerRecursionLimit > 0 && nonStrictRecursionCount >= FInt::LuauNonStrictTypeCheckerRecursionLimit) + return {}; + } + auto pusher = pushStack(expr); if (auto e = expr->as()) return visit(e, context); @@ -809,6 +828,14 @@ struct NonStrictTypeChecker NonStrictContext visit(AstExprTable* table) { + std::optional _rc; + if (FFlag::LuauAddRecursionCounterToNonStrictTypeChecker) + { + _rc.emplace(&nonStrictRecursionCount); + if (FInt::LuauNonStrictTypeCheckerRecursionLimit > 0 && nonStrictRecursionCount >= FInt::LuauNonStrictTypeCheckerRecursionLimit) + return {}; + } + for (auto [_, key, value] : table->items) { if (key) @@ -1291,6 +1318,8 @@ struct NonStrictTypeChecker } private: + int nonStrictRecursionCount = 0; + TypeId getOrCreateNegation(TypeId baseType) { TypeId& cachedResult = cachedNegations[baseType]; @@ -1326,7 +1355,7 @@ void checkNonStrict( typeChecker.visit(sourceModule.root); unfreeze(module->interfaceTypes); copyErrors(module->errors, module->interfaceTypes, builtinTypes); - + module->errors.erase( std::remove_if( module->errors.begin(), diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index 636b0ebf..4ce1fb96 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -27,7 +27,6 @@ LUAU_FASTINTVARIABLE(LuauSubtypingReasoningLimit, 100) LUAU_FASTFLAG(LuauEmplaceNotPushBack) LUAU_FASTFLAGVARIABLE(LuauSubtypingReportGenericBoundMismatches2) LUAU_FASTFLAGVARIABLE(LuauTrackUniqueness) -LUAU_FASTFLAGVARIABLE(LuauSubtypingUnionsAndIntersectionsInGenericBounds) LUAU_FASTFLAGVARIABLE(LuauIndexInMetatableSubtyping) LUAU_FASTFLAGVARIABLE(LuauSubtypingPackRecursionLimits) LUAU_FASTFLAGVARIABLE(LuauSubtypingPrimitiveAndGenericTableTypes) @@ -865,23 +864,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub SubtypingResult result; - if (auto subUnion = get(subTy); subUnion && !FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds) - result = isCovariantWith(env, subUnion, superTy, scope); - else if (auto superUnion = get(superTy); superUnion && !FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds) - { - result = isCovariantWith(env, subTy, superUnion, scope); - if (!result.isSubtype && !result.normalizationTooComplex) - result = trySemanticSubtyping(env, subTy, superTy, scope, result); - } - else if (auto superIntersection = get(superTy); superIntersection && !FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds) - result = isCovariantWith(env, subTy, superIntersection, scope); - else if (auto subIntersection = get(subTy); subIntersection && !FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds) - { - result = isCovariantWith(env, subIntersection, superTy, scope); - if (!result.isSubtype && !result.normalizationTooComplex) - result = trySemanticSubtyping(env, subTy, superTy, scope, result); - } - else if (get(superTy)) + if (get(superTy)) result = {true}; // We have added this as an exception - the set of inhabitants of any is exactly the set of inhabitants of unknown (since error has no @@ -895,8 +878,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub result = isCovariantWith(env, builtinTypes->unknownType, superTy, scope).andAlso(isCovariantWith(env, builtinTypes->errorType, superTy, scope)); } - else if (get(superTy) && // flag delays recursing into unions and inters, so only handle this case if subTy isn't a union or inter - (FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds ? !get(subTy) && !get(subTy) : true)) + else if (get(superTy) && !get(subTy) && !get(subTy)) { LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. @@ -951,25 +933,17 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub } } else if (auto subUnion = get(subTy)) - { - LUAU_ASSERT(FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds); result = isCovariantWith(env, subUnion, superTy, scope); - } else if (auto superUnion = get(superTy)) { - LUAU_ASSERT(FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds); result = isCovariantWith(env, subTy, superUnion, scope); if (!result.isSubtype && !result.normalizationTooComplex) result = trySemanticSubtyping(env, subTy, superTy, scope, result); } else if (auto superIntersection = get(superTy)) - { - LUAU_ASSERT(FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds); result = isCovariantWith(env, subTy, superIntersection, scope); - } else if (auto subIntersection = get(subTy)) { - LUAU_ASSERT(FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds); result = isCovariantWith(env, subIntersection, superTy, scope); if (!result.isSubtype && !result.normalizationTooComplex) result = trySemanticSubtyping(env, subTy, superTy, scope, result); @@ -2797,26 +2771,21 @@ SubtypingResult Subtyping::checkGenericBounds( result.genericBoundsMismatches.emplace_back(genericName, bounds.lowerBound, bounds.upperBound); else if (!boundsResult.isSubtype) { - if (FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds) + // Check if the bounds are error suppressing before reporting a mismatch + switch (shouldSuppressErrors(normalizer, lowerBound).orElse(shouldSuppressErrors(normalizer, upperBound))) { - // Check if the bounds are error suppressing before reporting a mismatch - switch (shouldSuppressErrors(normalizer, lowerBound).orElse(shouldSuppressErrors(normalizer, upperBound))) - { - case ErrorSuppression::Suppress: - break; - case ErrorSuppression::NormalizationFailed: - // intentionally fallthrough here since we couldn't prove this was error-suppressing - [[fallthrough]]; - case ErrorSuppression::DoNotSuppress: - result.genericBoundsMismatches.emplace_back(genericName, bounds.lowerBound, bounds.upperBound); - break; - default: - LUAU_ASSERT(0); - break; - } - } - else + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + // intentionally fallthrough here since we couldn't prove this was error-suppressing + [[fallthrough]]; + case ErrorSuppression::DoNotSuppress: result.genericBoundsMismatches.emplace_back(genericName, bounds.lowerBound, bounds.upperBound); + break; + default: + LUAU_ASSERT(0); + break; + } } result.andAlso(boundsResult); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 7fd0225a..f3acdd48 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1775,6 +1775,11 @@ std::string dump(const std::vector& types) return toStringVector(types, dumpOptions()); } +std::string dump(const std::vector& typePacks) +{ + return toStringVector(typePacks, dumpOptions()); +} + std::string dump(DenseHashMap& types) { std::string s = "{"; @@ -1839,6 +1844,18 @@ std::string toStringVector(const std::vector& types, ToStringOptions& op return s; } +std::string toStringVector(const std::vector& typePacks, ToStringOptions& opts) +{ + std::string s; + for (TypePackId typePack : typePacks) + { + if (!s.empty()) + s += ", "; + s += toString(typePack, opts); + } + return s; +} + std::string toString(const Constraint& constraint, ToStringOptions& opts) { auto go = [&opts](auto&& c) -> std::string @@ -1929,6 +1946,9 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) return "simplify " + tos(c.ty); else if constexpr (std::is_same_v) return "push_function_type " + tos(c.expectedFunctionType) + " => " + tos(c.functionType); + else if constexpr (std::is_same_v) + return "explicitly_specified_constraints " + tos(c.functionType) + " (typeArguments = " + dump(c.typeArguments) + + "), (typePackArguments = " + dump(c.typePackArguments) + ")"; else if constexpr (std::is_same_v) return "push_type " + tos(c.expectedType) + " => " + tos(c.targetType); else diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index cee6427f..100a3bc0 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -9,8 +9,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauOccursCheckInCommit) - namespace Luau { @@ -228,12 +226,7 @@ void TxnLog::commit() { const TypeId unfollowed = &rep.get()->pending; - if (FFlag::LuauOccursCheckInCommit) - { - if (!occurs(*this, unfollowed, ty)) - asMutable(ty)->reassign(*unfollowed); - } - else + if (!occurs(*this, unfollowed, ty)) asMutable(ty)->reassign(*unfollowed); } } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 9919f483..9db062ce 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -34,6 +34,7 @@ LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) LUAU_FASTFLAG(LuauNoMoreComparisonTypeFunctions) LUAU_FASTFLAG(LuauTrackUniqueness) @@ -44,7 +45,6 @@ LUAU_FASTFLAGVARIABLE(LuauNewOverloadResolver) LUAU_FASTFLAG(LuauPassBindableGenericsByReference) LUAU_FASTFLAG(LuauSimplifyIntersectionNoTreeSet) LUAU_FASTFLAG(LuauAddRefinementToAssertions) -LUAU_FASTFLAGVARIABLE(LuauNoCustomHandlingOfReasonsingsForForIn) LUAU_FASTFLAGVARIABLE(LuauSuppressIndexingIntoError) namespace Luau @@ -92,10 +92,10 @@ struct StackPusher struct PropertyTypes { - // a vector of all the types assigned to the given property. + // a vector of all the typeArguments assigned to the given property. std::vector typesOfProp; - // a vector of all the types that are missing the given property. + // a vector of all the typeArguments that are missing the given property. std::vector missingProp; bool foundOneProp() const @@ -549,7 +549,7 @@ TypeId TypeChecker2::lookupAnnotation(AstType* annotation) { if (auto ann = ref->parameters.data[0].type) { - TypeId argTy = lookupAnnotation(ref->parameters.data[0].type); + TypeId argTy = lookupAnnotation(ann); luauPrintLine( format("_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str()) ); @@ -752,7 +752,7 @@ void TypeChecker2::visit(AstStatReturn* ret) // // return E0, E1, E2, ... , EN // - // All expressions *except* the last will be types, and the last can + // All expressions *except* the last will be typeArguments, and the last can // potentially be a pack. However, if the last expression is a function // call or varargs (`...`), then we _could_ have a pack in the final // position. Additionally, if we have an argument overflow, then we can't @@ -936,7 +936,7 @@ void TypeChecker2::visit(AstStatForIn* forInStatement) valueTypes.emplace_back(lookupType(firstValue)); } - // if the initial and expected types from the iterator unified during constraint solving, + // if the initial and expected typeArguments from the iterator unified during constraint solving, // we'll have a resolved type to use here, but we'll only use it if either the iterator is // directly present in the for-in statement or if we have an iterator state constraining us TypeId* resolvedTy = module->astForInNextTypes.find(firstValue); @@ -1005,7 +1005,7 @@ void TypeChecker2::visit(AstStatForIn* forInStatement) // first. // It may be invoked with 0 or 1 argument on the first iteration. - // This depends on the types in iterateePack and therefore + // This depends on the typeArguments in iterateePack and therefore // iteratorTypes. // If the iteratee is an error type, then we can't really say anything else about iteration over it. @@ -1095,7 +1095,7 @@ void TypeChecker2::visit(AstStatForIn* forInStatement) * * There must be 1 to 3 iterator arguments. Name them (nextTy, * arrayTy, startIndexTy) * * The return type of nextTy() must correspond to the variables' - * types and counts. HOWEVER the first iterator will never be nil. + * typeArguments and counts. HOWEVER the first iterator will never be nil. * * The first return value of nextTy must be compatible with * startIndexTy. * * The first argument to nextTy() must be compatible with arrayTy if @@ -1262,7 +1262,7 @@ void TypeChecker2::visit(AstStatAssign* assign) } // FIXME CLI-142462: Due to the fact that we do not type state - // tables properly, table types "time travel." We can take + // tables properly, table typeArguments "time travel." We can take // advantage of this for the specific code pattern of: // // local t = {} @@ -1413,6 +1413,11 @@ void TypeChecker2::visit(AstExpr* expr, ValueContext context) return visit(e); else if (auto e = expr->as()) return visit(e); + else if (auto e = expr->as()) + { + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + return visit(e); + } else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) @@ -1440,7 +1445,7 @@ void TypeChecker2::visit(AstExprConstantNil* expr) void TypeChecker2::visit(AstExprConstantBool* expr) { - // booleans use specialized inference logic for singleton types, which can lead to real type errors here. + // booleans use specialized inference logic for singleton typeArguments, which can lead to real type errors here. const TypeId bestType = expr->value ? builtinTypes->trueType : builtinTypes->falseType; const TypeId inferredType = lookupType(expr); @@ -1471,7 +1476,7 @@ void TypeChecker2::visit(AstExprConstantNumber* expr) void TypeChecker2::visit(AstExprConstantString* expr) { - // strings use specialized inference logic for singleton types, which can lead to real type errors here. + // strings use specialized inference logic for singleton typeArguments, which can lead to real type errors here. const TypeId bestType = module->internalTypes.addType(SingletonType{StringSingleton{std::string{expr->value.data, expr->value.size}}}); const TypeId inferredType = lookupType(expr); @@ -1569,6 +1574,14 @@ void TypeChecker2::visitCall(AstExprCall* call) return; } + if (FFlag::LuauExplicitTypeExpressionInstantiation) + { + if (call->typeArguments.size) + { + checkTypeInstantiation(call, fnTy, call->location, call->typeArguments); + } + } + if (selectedOverloadTy) { SubtypingResult result = subtyping->isSubtype(*originalCallTy, *selectedOverloadTy, scope); @@ -1607,7 +1620,7 @@ void TypeChecker2::visitCall(AstExprCall* call) } // FIXME: Similar to bidirectional inference prior, this does not support - // overloaded functions nor generic types (yet). + // overloaded functions nor generic typeArguments (yet). if (auto fty = get(fnTy); fty && fty->generics.empty() && fty->genericPacks.empty() && call->args.size > 0) { size_t selfOffset = call->self ? 1 : 0; @@ -2064,7 +2077,7 @@ void TypeChecker2::visit(AstExprIndexExpr* indexExpr, ValueContext context) } else if (auto ut = get(exprType)) { - // if all of the types are a table type, the union must be a table, and so we shouldn't error. + // if all of the typeArguments are a table type, the union must be a table, and so we shouldn't error. if (!std::all_of(begin(ut), end(ut), getTableType)) { if (FFlag::LuauSuppressIndexingIntoError) @@ -2086,7 +2099,7 @@ void TypeChecker2::visit(AstExprIndexExpr* indexExpr, ValueContext context) } else if (auto it = get(exprType)) { - // if any of the types are a table type, the intersection must be a table, and so we shouldn't error. + // if any of the typeArguments are a table type, the intersection must be a table, and so we shouldn't error. if (!std::any_of(begin(it), end(it), getTableType)) reportError(NotATable{exprType}, indexExpr->location); } @@ -2339,7 +2352,7 @@ void TypeChecker2::visit(AstExprUnary* expr) } } -// Comparisons between disjoint types is usually something we warn on, but there are some special exceptions. +// Comparisons between disjoint typeArguments is usually something we warn on, but there are some special exceptions. static bool isOkToCompare( Normalizer& normalizer, NormalizationResult typesHaveIntersection, @@ -2347,7 +2360,7 @@ static bool isOkToCompare( const std::shared_ptr& normRight ) { - // We only consider warning if we know that the types are disjoint. If + // We only consider warning if we know that the typeArguments are disjoint. If // normalization fails here, it should have also failed elsewhere and will // already have been reported. if (NormalizationResult::False != typesHaveIntersection) @@ -2362,7 +2375,7 @@ static bool isOkToCompare( NormalizationResult::True != normalizer.isInhabited(normRight.get())) return true; - // Comparisons between different string singleton types is allowed even + // Comparisons between different string singleton typeArguments is allowed even // if their intersection is technically uninhabited. else if (!normLeft->strings.isNever() && !normRight->strings.isNever()) return true; @@ -2476,8 +2489,8 @@ TypeId TypeChecker2::visit(AstExprBinary* expr, AstNode* overrideKey) // If we're working with things that are not tables, the metatable comparisons above are a little excessive // It's ok for one type to have a meta table and the other to not. In that case, we should fall back on - // checking if the intersection of the types is inhabited. If `typesHaveIntersection` failed due to limits, - // TODO: Maybe add more checks here (e.g. for functions, extern types, etc) + // checking if the intersection of the typeArguments is inhabited. If `typesHaveIntersection` failed due to limits, + // TODO: Maybe add more checks here (e.g. for functions, extern typeArguments, etc) if (!(get(leftType) || get(rightType))) if (!leftMt.has_value() || !rightMt.has_value()) matches = matches || typesHaveIntersection != NormalizationResult::False; @@ -2703,7 +2716,7 @@ TypeId TypeChecker2::visit(AstExprBinary* expr, AstNode* overrideKey) if (FFlag::LuauNoOrderingTypeFunctions) { // This could be a little wasteful, as we already have normalized - // types, but correctly handles cases like `_: (T & number) <= _: (T & number)`. + // typeArguments, but correctly handles cases like `_: (T & number) <= _: (T & number)`. if (subtyping->isSubtype(leftType, builtinTypes->numberType, scope).isSubtype) { @@ -2813,6 +2826,18 @@ void TypeChecker2::visit(AstExprIfElse* expr) visit(expr->falseExpr, ValueContext::RValue); } +void TypeChecker2::visit(AstExprInstantiate* explicitTypeInstantiation) +{ + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + visit(explicitTypeInstantiation->expr, ValueContext::RValue); + checkTypeInstantiation( + explicitTypeInstantiation->expr, + lookupType(explicitTypeInstantiation->expr), + explicitTypeInstantiation->location, + explicitTypeInstantiation->typeArguments + ); +} + void TypeChecker2::visit(AstExprInterpString* interpString) { InConditionalContext inContext(&typeContext, TypeContext::Default); @@ -2972,7 +2997,7 @@ void TypeChecker2::visit(AstTypeReference* ty) } } - // If we require type parameters, but no types are provided and only packs are provided, we report an error. + // If we require type parameters, but no typeArguments are provided and only packs are provided, we report an error. if (typesRequired != 0 && typesProvided == 0 && packsProvided != 0) { reportError(GenericError{"Type parameters must come before type pack parameters"}, ty->location); @@ -2980,7 +3005,7 @@ void TypeChecker2::visit(AstTypeReference* ty) if (extraTypes != 0 && packsProvided == 0) { - // Extra types are only collected into a pack if a pack is expected + // Extra typeArguments are only collected into a pack if a pack is expected if (packsRequired != 0) packsProvided += 1; else @@ -3569,9 +3594,9 @@ void TypeChecker2::reportErrors(ErrorVec errors) /* A helper for checkIndexTypeFromType. * * Returns a pair: - * * A boolean indicating that at least one of the constituent types + * * A boolean indicating that at least one of the constituent typeArguments * contains the prop, and - * * A vector of types that do not contain the prop. + * * A vector of typeArguments that do not contain the prop. */ PropertyTypes TypeChecker2::lookupProp( const NormalizedType* norm, @@ -3714,7 +3739,7 @@ void TypeChecker2::checkIndexTypeFromType( if (propTypes.foundOneProp()) reportError(MissingUnionProperty{tableTy, propTypes.missingProp, prop}, location); // For class LValues, we don't want to report an extension error, - // because extern types come into being with full knowledge of their + // because extern typeArguments come into being with full knowledge of their // shape. We instead want to report the unknown property error of // the `else` branch. else if (context == ValueContext::LValue && !get(tableTy)) @@ -3885,6 +3910,68 @@ void TypeChecker2::suggestAnnotations(AstExprFunction* expr, TypeId ty) } } +void TypeChecker2::checkTypeInstantiation( + AstExpr* baseFunctionExpr, + TypeId fnType, + const Location& location, + const AstArray& typeArguments +) +{ + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + + const FunctionType* ftv = get(follow(fnType)); + if (!ftv) + { + InstantiateGenericsOnNonFunction::InterestingEdgeCase interestingEdgeCase = + InstantiateGenericsOnNonFunction::InterestingEdgeCase::None; + + if (findMetatableEntry(builtinTypes, module->errors, fnType, "__call", location).has_value()) + { + interestingEdgeCase = InstantiateGenericsOnNonFunction::InterestingEdgeCase::MetatableCall; + } + else if (get(follow(fnType))) + { + interestingEdgeCase = InstantiateGenericsOnNonFunction::InterestingEdgeCase::Intersection; + } + + reportError( + InstantiateGenericsOnNonFunction{ + interestingEdgeCase, + }, + location + ); + + return; + } + + size_t typeCount = 0; + size_t typePackCount = 0; + + for (const AstTypeOrPack& typeOrPack : typeArguments) + { + if (typeOrPack.type) + { + ++typeCount; + } + else + { + LUAU_ASSERT(typeOrPack.typePack); + ++typePackCount; + } + } + + if (ftv->generics.size() < typeCount || ftv->genericPacks.size() < typePackCount) + { + reportError( + TypeInstantiationCountMismatch{ + getIdentifierOfBaseVar(baseFunctionExpr), fnType, typeCount, ftv->generics.size(), typePackCount, ftv->genericPacks.size() + }, + location + ); + } +} + + void TypeChecker2::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const { std::string_view sv(utk->key); diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index fc54fc7b..6cd1e97c 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -35,6 +35,7 @@ LUAU_FASTFLAG(DebugLuauEqSatSimplification) LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies) LUAU_FASTFLAGVARIABLE(LuauEnqueueUnionsOfDistributedTypeFunctions) +LUAU_FASTFLAGVARIABLE(LuauMarkUnscopedGenericsAsSolved) namespace Luau { @@ -586,6 +587,12 @@ struct TypeFunctionReducer // Let the caller know this type will not become reducible result.irreducibleTypes.insert(subject); + if (FFlag::LuauMarkUnscopedGenericsAsSolved) + { + if (getState(subject) == TypeFunctionInstanceState::Unsolved) + setState(subject, TypeFunctionInstanceState::Solved); + } + if (FFlag::DebugLuauLogTypeFamilies) printf("Irreducible due to an unscoped generic type\n"); diff --git a/Analysis/src/TypeFunctionRuntimeBuilder.cpp b/Analysis/src/TypeFunctionRuntimeBuilder.cpp index a5d5b26c..94899cb3 100644 --- a/Analysis/src/TypeFunctionRuntimeBuilder.cpp +++ b/Analysis/src/TypeFunctionRuntimeBuilder.cpp @@ -1030,9 +1030,19 @@ TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state) return TypeFunctionSerializer(state).serialize(ty); } +TypeFunctionTypePackId serialize(TypePackId tp, TypeFunctionRuntimeBuilderState* state) +{ + return TypeFunctionSerializer(state).serialize(tp); +} + TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state) { return TypeFunctionDeserializer(state).deserialize(ty); } +TypePackId deserialize(TypeFunctionTypePackId tp, TypeFunctionRuntimeBuilderState* state) +{ + return TypeFunctionDeserializer(state).deserialize(tp); +} + } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index fb99401c..a2dd5d7a 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -29,6 +29,7 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauUseWorkspacePropToChooseSolver) @@ -297,7 +298,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo prepareErrorsForDisplay(currentModule->errors); - // Clear the normalizer caches, since they contain types from the internal type surface + // Clear the normalizer caches, since they contain typeArguments from the internal type surface normalizer.clearCaches(); normalizer.arena = nullptr; @@ -309,7 +310,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo freeze(currentModule->internalTypes); freeze(currentModule->interfaceTypes); - // Clear unifier cache since it's keyed off internal types that get deallocated + // Clear unifier cache since it's keyed off internal typeArguments that get deallocated // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. unifierState.cachedUnify.clear(); unifierState.cachedUnifyError.clear(); @@ -539,7 +540,7 @@ ControlFlow TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, // ``` // These both call each other, so `f` will be ordered before `g`, so the call to `g` // is typechecked before `g` has had its body checked. For this reason, there's three - // types for each function: before its body is checked, during checking its body, + // typeArguments for each function: before its body is checked, during checking its body, // and after its body is checked. // // We currently treat the before-type and the during-type as the same, @@ -551,7 +552,7 @@ ControlFlow TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, // function g(x) return f(x) end // ``` // The before-type of g is `(X)->Y...` but during type-checking of `f` we will - // unify that with `(number)->number`. The types end up being + // unify that with `(number)->number`. The typeArguments end up being // ``` // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end @@ -1133,7 +1134,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) scope->importedTypeBindings[name] = module->exportedTypeBindings; scope->importedModules[name] = moduleInfo->name; - // Imported types of requires that transitively refer to current module have to be replaced with 'any' + // Imported typeArguments of requires that transitively refer to current module have to be replaced with 'any' for (const auto& [location, path] : requireCycles) { if (!path.empty() && path.front() == moduleInfo->name) @@ -1213,7 +1214,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) // next is a function that takes Table and an optional index of type K // next(t: Table, index: K | nil) -> (K?, V) - // however, pairs and ipairs are quite messy, but they both share the same types + // however, pairs and ipairs are quite messy, but they both share the same typeArguments // pairs returns 'next, t, nil', thus the type would be // pairs(t: Table) -> ((Table, K | nil) -> (K?, V), Table, K | nil) // ipairs returns 'next, t, 0', thus ipairs will also share the same type as pairs, except K = number @@ -1264,7 +1265,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) { // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments - // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types + // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only typeArguments for (TypeId var : varTypes) unify(anyType, var, scope, forin.location); @@ -1315,7 +1316,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) if (firstValue->is()) { // Extract the remaining return values of the call - // and check them against the parameter types of the iterator function. + // and check them against the parameter typeArguments of the iterator function. auto [types, tail] = flatten(callRetPack); if (!types.empty()) @@ -1350,7 +1351,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) Position start = firstValue->location.begin; Position end = values[forin.values.size - 1]->location.end; - AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, AstArray{}, Location()}; retPack = checkExprPack(scope, exprCall).type; } @@ -1530,7 +1531,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty if (auto ttv = getMutable(follow(ty))) { // If the table is already named and we want to rename the type function, we have to bind new alias to a copy - // Additionally, we can't modify types that come from other modules + // Additionally, we can't modify typeArguments that come from other modules if (ttv->name || follow(ty)->owningArena != ¤tModule->internalTypes) { bool sameTys = std::equal( @@ -1587,7 +1588,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty } else if (auto mtv = getMutable(follow(ty))) { - // We can't modify types that come from other modules + // We can't modify typeArguments that come from other modules if (follow(ty)->owningArena == ¤tModule->internalTypes) mtv->syntheticName = name; } @@ -1687,7 +1688,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareExternTyp return; } - // We don't have generic extern types, so this assertion _should_ never be hit. + // We don't have generic extern typeArguments, so this assertion _should_ never be hit. LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); superTy = lookupType->type; @@ -1746,7 +1747,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareExtern bool assignToMetatable = isMetamethod(propName); Luau::ExternType::Props& assignTo = assignToMetatable ? metatable->props : etv->props; - // Function types always take 'self', but this isn't reflected in the + // Function typeArguments always take 'self', but this isn't reflected in the // parsed annotation. Add it here. if (prop.isMethod) { @@ -1777,7 +1778,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareExtern TypeId currentTy = prop.type_DEPRECATED(); // We special-case this logic to keep the intersection flat; otherwise we - // would create a ton of nested intersection types. + // would create a ton of nested intersection typeArguments. if (const IntersectionType* itv = get(currentTy)) { std::vector options = itv->parts; @@ -1925,6 +1926,11 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp result = checkExpr(scope, *a, expectedType); else if (auto a = expr.as()) result = checkExpr(scope, *a); + else if (auto a = expr.as()) + { + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + result = checkExpr(scope, *a); + } else ice("Unhandled AstExpr?"); @@ -2156,7 +2162,7 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( { RecursionLimiter _rl("TypeInfer::UnionType", &recursionCount, FInt::LuauTypeInferRecursionLimit); - // Not needed when we normalize types. + // Not needed when we normalize typeArguments. if (get(follow(t))) return t; @@ -2642,12 +2648,12 @@ static std::optional getIdentifierOfBaseVar(AstExpr* node) return std::nullopt; } -/** Return true if comparison between the types a and b should be permitted with +/** Return true if comparison between the typeArguments a and b should be permitted with * the == or ~= operators. * - * Two types are considered eligible for equality testing if it is possible for + * Two typeArguments are considered eligible for equality testing if it is possible for * the test to ever succeed. In other words, we test to see whether the two - * types have any overlap at all. + * typeArguments have any overlap at all. * * In order to make things work smoothly with the greedy solver, this function * exempts any and FreeTypes from this requirement. @@ -2735,7 +2741,7 @@ TypeId TypeChecker::checkRelationalOperation( const bool lhsIsAny = get(lhsType) || get(lhsType) || get(lhsType); // Peephole check for `cond and a or b -> type(a)|type(b)` - // TODO: Kill this when singleton types arrive. :( + // TODO: Kill this when singleton typeArguments arrive. :( if (AstExprBinary* subexp = expr.left->as()) { if (expr.op == AstExprBinary::Or && subexp->op == AstExprBinary::And) @@ -2775,7 +2781,7 @@ TypeId TypeChecker::checkRelationalOperation( // Unless either type is free or any, an equality comparison is only // valid when the intersection of the two operands is non-empty. // - // eg it is okay to compare string? == number? because the two types + // eg it is okay to compare string? == number? because the two typeArguments // have nil in common, but string == number is not allowed. std::optional eqTestResult = areEqComparable(NotNull{¤tModule->internalTypes}, NotNull{&normalizer}, lhsType, rhsType); if (!eqTestResult) @@ -2959,7 +2965,7 @@ TypeId TypeChecker::checkRelationalOperation( if (get(lhsType)) return unionOfTypes(addType(UnionType{{nilType, singletonType(false)}}), rhsType, scope, expr.location, false); - auto [oty, notNever] = pickTypesFromSense(lhsType, false, neverType); // Filter out falsy types + auto [oty, notNever] = pickTypesFromSense(lhsType, false, neverType); // Filter out falsy typeArguments if (notNever) { @@ -2985,7 +2991,7 @@ TypeId TypeChecker::checkRelationalOperation( } else { - auto [oty, notNever] = pickTypesFromSense(lhsType, true, neverType); // Filter out truthy types + auto [oty, notNever] = pickTypesFromSense(lhsType, true, neverType); // Filter out truthy typeArguments if (notNever) { @@ -3083,7 +3089,7 @@ TypeId TypeChecker::checkBinaryOperation( if (hasErrors) { // If there are unification errors, the return type may still be unknown - // so we loosen the argument types to see if that helps. + // so we loosen the argument typeArguments to see if that helps. TypePackId fallbackArguments = freshTypePack(scope); TypeId fallbackFunctionType = addType(FunctionType(scope->level, fallbackArguments, retTypePack)); state.errors.clear(); @@ -3207,7 +3213,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } else { - // Expected types are not useful for other binary operators. + // Expected typeArguments are not useful for other binary operators. WithPredicate lhs = checkExpr(scope, *expr.left); WithPredicate rhs = checkExpr(scope, *expr.right); @@ -3275,6 +3281,140 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return WithPredicate{stringType}; } +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprInstantiate& explicitTypeInstantiation) +{ + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + + WithPredicate baseType = checkExpr(scope, *explicitTypeInstantiation.expr); + + return WithPredicate{instantiateTypeParameters( + scope, baseType.type, explicitTypeInstantiation.typeArguments, explicitTypeInstantiation.expr, explicitTypeInstantiation.expr->location + )}; +} + +TypeId TypeChecker::instantiateTypeParameters( + const ScopePtr& scope, + TypeId baseType, + const AstArray& explicitTypes, + const AstExpr* functionExpr, + const Location& location +) +{ + baseType = follow(baseType); + const FunctionType* functionType = get(baseType); + + if (!functionType) + { + InstantiateGenericsOnNonFunction::InterestingEdgeCase interestingEdgeCase = + InstantiateGenericsOnNonFunction::InterestingEdgeCase::None; + + if (get(baseType)) + { + interestingEdgeCase = InstantiateGenericsOnNonFunction::InterestingEdgeCase::Intersection; + } + else if (const MetatableType* mttv = get(baseType)) + { + if (getIndexTypeFromType(scope, mttv->metatable, "__call", location, /* addErrors= */ false).has_value()) + { + interestingEdgeCase = InstantiateGenericsOnNonFunction::InterestingEdgeCase::MetatableCall; + } + } + + reportError( + location, + InstantiateGenericsOnNonFunction{ + interestingEdgeCase, + } + ); + + return baseType; + } + + ScopePtr aliasScope = childScope(scope, location); + aliasScope->level = scope->level.incr(); + + std::vector typeParams; + typeParams.reserve(functionType->generics.size()); + for (size_t i = 0; i < functionType->generics.size(); ++i) + { + typeParams.push_back(freshType(scope)); + } + + auto typeParamsIter = typeParams.begin(); + + std::vector typePackParams; + typePackParams.reserve(functionType->genericPacks.size()); + for (size_t i = 0; i < functionType->genericPacks.size(); ++i) + { + typePackParams.push_back(freshTypePack(scope)); + } + + auto typePackParamsIter = typePackParams.begin(); + + size_t typeParamCount = 0; + size_t typePackParamCount = 0; + + for (const AstTypeOrPack& typeOrPack : explicitTypes) + { + if (typeOrPack.type) + { + ++typeParamCount; + + if (typeParamsIter == typeParams.end()) + { + continue; + } + + *typeParamsIter++ = resolveType(scope, *typeOrPack.type); + } + else + { + LUAU_ASSERT(typeOrPack.typePack); + ++typePackParamCount; + + if (typePackParamsIter == typePackParams.end()) + { + continue; + } + + *typePackParamsIter++ = resolveTypePack(scope, *typeOrPack.typePack); + } + } + + if (typeParamCount > functionType->generics.size() || typePackParamCount > functionType->genericPacks.size()) + { + reportError( + location, + TypeInstantiationCountMismatch{ + getFunctionNameAsString(*functionExpr), + baseType, + typeParamCount, + functionType->generics.size(), + typePackParamCount, + functionType->genericPacks.size() + } + ); + } + + TypeFun baseFun; + baseFun.type = baseType; + + baseFun.typeParams.reserve(functionType->generics.size()); + for (TypeId genericId : functionType->generics) + { + baseFun.typeParams.push_back({genericId, std::nullopt}); + } + + baseFun.typePackParams.reserve(functionType->genericPacks.size()); + for (TypePackId genericPackId : functionType->genericPacks) + { + baseFun.typePackParams.push_back({genericPackId, std::nullopt}); + } + + return instantiateTypeFun(scope, baseFun, typeParams, typePackParams, location); +} + + TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx) { return checkLValueBinding(scope, expr, ctx); @@ -3727,13 +3867,13 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T // This returns a pair `[funType, funScope]` where // - funType is the prototype type of the function // - funScope is the scope for the function, which is a child scope with bindings added for -// parameters (and generic types if there were explicit generic annotations). +// parameters (and generic typeArguments if there were explicit generic annotations). // -// The function type is a prototype, in that it may be missing some generic types which +// The function type is a prototype, in that it may be missing some generic typeArguments which // can only be inferred from type inference after typechecking the function body. // For example the function `function id(x) return x end` has prototype // `(X) -> Y...`, but after typechecking the body, we cam unify `Y...` with `X` -// to get type `(X) -> X`, then we quantify the free types to get the final +// to get type `(X) -> X`, then we quantify the free typeArguments to get the final // generic type `(a) -> a`. std::pair TypeChecker::checkFunctionSignature( const ScopePtr& scope, @@ -3758,7 +3898,7 @@ std::pair TypeChecker::checkFunctionSignature( } else if (auto utv = get(follow(*expectedType))) { - // Look for function type in a union. Other types can be ignored since current expression is a function + // Look for function type in a union. Other typeArguments can be ignored since current expression is a function for (auto option : utv) { if (auto ftv = get(follow(option))) @@ -3769,7 +3909,7 @@ std::pair TypeChecker::checkFunctionSignature( } else { - // Do not infer argument types when multiple overloads are expected + // Do not infer argument typeArguments when multiple overloads are expected expectedFunctionType = nullptr; break; } @@ -4136,7 +4276,7 @@ void TypeChecker::checkArgumentList( TypePackId tail = *argIter.tail(); if (state.log.getMutable(tail)) { - // Unify remaining parameters so we don't leave any free-types hanging around. + // Unify remaining parameters so we don't leave any free-typeArguments hanging around. while (paramIter != endIter) { state.tryUnify(errorRecoveryType(anyType), *paramIter); @@ -4273,7 +4413,7 @@ void TypeChecker::checkArgumentList( { loopCount = 0; - // Create a type pack out of the remaining argument types + // Create a type pack out of the remaining argument typeArguments // and unify it with the tail. std::vector rest; rest.reserve(std::distance(argIter, endIter)); @@ -4320,9 +4460,9 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope { // evaluate type of function // decompose an intersection into its component overloads - // Compute types of parameters + // Compute typeArguments of parameters // For each overload - // Compare parameter and argument types + // Compare parameter and argument typeArguments // Report any errors (also speculate dot vs colon warnings!) // Return the resulting return type (even if there are errors) // If there are no matching overloads, unify with (a...) -> (b...) and return b... @@ -4343,7 +4483,13 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, /* addErrors= */ true)) { functionType = *propTy; - actualFunctionType = instantiate(scope, functionType, expr.func->location); + actualFunctionType = instantiate( + scope, + FFlag::LuauExplicitTypeExpressionInstantiation && expr.typeArguments.size + ? instantiateTypeParameters(scope, functionType, expr.typeArguments, expr.func, expr.location) + : functionType, + expr.func->location + ); } else { @@ -4597,7 +4743,7 @@ std::unique_ptr> TypeChecker::checkCallOverload( Unifier state = mkUnifier(scope, expr.location); - // Unify return types + // Unify return typeArguments checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, /*argLocations*/ {}); if (!state.errors.empty()) { @@ -4784,7 +4930,7 @@ void TypeChecker::reportOverloadResolutionError( TypeId overload = follow(overloadTypes[i]); Unifier state = mkUnifier(scope, expr.location); - // Unify return types + // Unify return typeArguments if (const FunctionType* ftv = get(overload)) { checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, {}); @@ -5361,7 +5507,7 @@ TypeId TypeChecker::singletonType(bool value) TypeId TypeChecker::singletonType(std::string value) { - // TODO: cache singleton types + // TODO: cache singleton typeArguments return currentModule->internalTypes.addType(Type(SingletonType(StringSingleton{std::move(value)}))); } @@ -5581,7 +5727,7 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno if (typePackParams.empty() && !extraTypes.empty()) typePackParams.push_back(addTypePack(extraTypes)); - // If we need more regular types, we can use single element type packs to fill those in + // If we need more regular typeArguments, we can use single element type packs to fill those in if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) typeParams.push_back(*first(tp)); else @@ -5606,7 +5752,7 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno // Add default type and type pack parameters if that's required and it's possible if (notEnoughParameters && hasDefaultParameters) { - // 'applyTypeFunction' is used to substitute default types that reference previous generic types + // 'applyTypeFunction' is used to substitute default typeArguments that reference previous generic typeArguments ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes}; for (size_t i = 0; i < typesProvided; ++i) @@ -5660,7 +5806,7 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno } } - // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack + // If we didn't combine regular typeArguments into a type pack and we're still one type pack short, provide an empty type pack if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) typePackParams.push_back(addTypePack({})); @@ -5671,7 +5817,7 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}} ); - // Pad the types out with error recovery types + // Pad the typeArguments out with error recovery typeArguments while (typeParams.size() < tf->typeParams.size()) typeParams.push_back(errorRecoveryType(scope)); while (typePackParams.size() < tf->typePackParams.size()) @@ -6008,7 +6154,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes( Name n = generic->name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that - // a collision can only occur when two generic types have the same name. + // a collision can only occur when two generic typeArguments have the same name. if (scope->privateTypeBindings.count(n) || scope->privateTypePackBindings.count(n)) { // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. @@ -6044,7 +6190,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes( Name n = genericPack->name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that - // a collision can only occur when two generic types have the same name. + // a collision can only occur when two generic typeArguments have the same name. if (scope->privateTypePackBindings.count(n) || scope->privateTypeBindings.count(n)) { // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. @@ -6327,7 +6473,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const // If both are subtypes, then we're in one of the two situations: // 1. Instance₁ <: Instance₂ ∧ Instance₂ <: Instance₁ // 2. any <: Instance ∧ Instance <: any - // Right now, we have to look at the types to see if they were undecidables. + // Right now, we have to look at the typeArguments to see if they were undecidables. // By this point, we also know free tables are also subtypes and supertypes. if (optionIsSubtype && targetIsSubtype) { @@ -6468,7 +6614,7 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc std::vector rhs = options(eqP.type); if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) - return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + return; // Optimization: the other side has unknown typeArguments, so there's probably an overlap. Refining is no-op here. auto predicate = [&](TypeId option) -> std::optional { @@ -6542,7 +6688,7 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp unify(tp, expectedTypePack, scope, location); // HACK: tryUnify would undo the changes to the expectedTypePack if the length mismatches, but - // we want to tie up free types to be error types, so we do this instead. + // we want to tie up free typeArguments to be error typeArguments, so we do this instead. currentModule->errors.resize(oldErrorsSize); for (TypeId& tp : expectedPack->head) diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 3f932f5e..a93f5154 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -118,6 +118,13 @@ struct AstTypeList AstTypePack* tailType = nullptr; }; +// Don't have Luau::Variant available, it's a bit of an overhead, but a plain struct is nice to use +struct AstTypeOrPack +{ + AstType* type = nullptr; + AstTypePack* typePack = nullptr; +}; + using AstArgumentName = std::pair; // TODO: remove and replace when we get a common struct for this pair instead of AstName extern int gAstRttiIndex; @@ -415,11 +422,22 @@ class AstExprCall : public AstExpr public: LUAU_RTTI(AstExprCall) - AstExprCall(const Location& location, AstExpr* func, const AstArray& args, bool self, const Location& argLocation); + AstExprCall( + const Location& location, + AstExpr* func, + const AstArray& args, + bool self, + const AstArray& explicitTypes, + const Location& argLocation + ); void visit(AstVisitor* visitor) override; AstExpr* func; + // These will only be filled in specifically `t:f<>()`. + // In `f<>()`, this is parsed as `f<>` as an expression, + // which is then called. + AstArray typeArguments; AstArray args; bool self; Location argLocation; @@ -642,6 +660,20 @@ class AstExprInterpString : public AstExpr AstArray expressions; }; +// f<> +class AstExprInstantiate : public AstExpr +{ +public: + LUAU_RTTI(AstExprInstantiate) + + AstExprInstantiate(const Location& location, AstExpr* expr, AstArray typePack); + + void visit(AstVisitor* visitor) override; + + AstExpr* expr; + AstArray typeArguments; +}; + class AstStatBlock : public AstStat { public: @@ -1071,13 +1103,6 @@ class AstType : public AstNode } }; -// Don't have Luau::Variant available, it's a bit of an overhead, but a plain struct is nice to use -struct AstTypeOrPack -{ - AstType* type = nullptr; - AstTypePack* typePack = nullptr; -}; - class AstTypeReference : public AstType { public: diff --git a/Ast/include/Luau/Cst.h b/Ast/include/Luau/Cst.h index 223c3ddd..153e67a6 100644 --- a/Ast/include/Luau/Cst.h +++ b/Ast/include/Luau/Cst.h @@ -83,6 +83,18 @@ class CstExprConstantString : public CstNode unsigned int blockDepth; }; +// Shared between the expression and call nodes +struct CstTypeInstantiation +{ + Position leftArrow1Position = {0,0}; + Position leftArrow2Position = {0,0}; + + AstArray commaPositions = {}; + + Position rightArrow1Position = {0,0}; + Position rightArrow2Position = {0,0}; +}; + class CstExprCall : public CstNode { public: @@ -93,6 +105,7 @@ class CstExprCall : public CstNode std::optional openParens; std::optional closeParens; AstArray commaPositions; + CstTypeInstantiation* explicitTypes = nullptr; }; class CstExprIndexExpr : public CstNode @@ -192,6 +205,16 @@ class CstExprInterpString : public CstNode AstArray stringPositions; }; +class CstExprExplicitTypeInstantiation : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprExplicitTypeInstantiation) + + explicit CstExprExplicitTypeInstantiation(CstTypeInstantiation instantiation); + + CstTypeInstantiation instantiation; +}; + class CstStatDo : public CstNode { public: diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 0a31f732..ea405de6 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -283,7 +283,7 @@ class Parser // prefixexp -> NAME | '(' expr ')' AstExpr* parsePrefixExpr(); - // primaryexp -> prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs } + // primaryexp -> prefixexp { `.' NAME | `[' exp `]' | TypeInstantiation | `:' NAME [TypeInstantiation] funcargs | funcargs } AstExpr* parsePrimaryExpr(bool asStatement); // asexp -> simpleexp [`::' Type] @@ -310,6 +310,14 @@ class Parser // stringinterp ::= exp { exp} AstExpr* parseInterpString(); + // TypeInstantiation ::= `<' `<' [TypeList] `>' `>' + AstArray parseTypeInstantiationExpr( + CstTypeInstantiation* cstNodeOut = nullptr, + Location* endLocationOut = nullptr + ); + + AstExpr* parseExplicitTypeInstantiationExpr(Position start, AstExpr& basedOnExpr); + // Name std::optional parseNameOpt(const char* context = nullptr); Name parseName(const char* context = nullptr); diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index fc5c4a33..fadd549e 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -4,6 +4,8 @@ #include "Luau/Common.h" #include "Luau/StringUtils.h" +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) + namespace Luau { @@ -32,6 +34,19 @@ static void visitTypeList(AstVisitor* visitor, const AstTypeList& list) list.tailType->visit(visitor); } +static void visitTypeOrPackArray(AstVisitor* visitor, const AstArray& arrayOfTypeOrPack) +{ + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + + for (const AstTypeOrPack& param : arrayOfTypeOrPack) + { + if (param.type) + param.type->visit(visitor); + else + param.typePack->visit(visitor); + } +} + AstAttr::AstAttr(const Location& location, Type type, AstArray args) : AstNode(ClassIndex(), location) , type(type) @@ -210,13 +225,22 @@ void AstExprVarargs::visit(AstVisitor* visitor) visitor->visit(this); } -AstExprCall::AstExprCall(const Location& location, AstExpr* func, const AstArray& args, bool self, const Location& argLocation) +AstExprCall::AstExprCall( + const Location& location, + AstExpr* func, + const AstArray& args, + bool self, + const AstArray& explicitTypes, + const Location& argLocation +) : AstExpr(ClassIndex(), location) , func(func) + , typeArguments(explicitTypes) , args(args) , self(self) , argLocation(argLocation) { + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation || explicitTypes.size == 0); } void AstExprCall::visit(AstVisitor* visitor) @@ -522,6 +546,23 @@ void AstExprInterpString::visit(AstVisitor* visitor) } } +AstExprInstantiate::AstExprInstantiate(const Location& location, AstExpr* expr, AstArray types) + : AstExpr(ClassIndex(), location) + , expr(expr) + , typeArguments(types) +{ + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); +} + +void AstExprInstantiate::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + visitTypeOrPackArray(visitor, typeArguments); + } +} + + void AstExprError::visit(AstVisitor* visitor) { if (visitor->visit(this)) @@ -1056,12 +1097,19 @@ void AstTypeReference::visit(AstVisitor* visitor) { if (visitor->visit(this)) { - for (const AstTypeOrPack& param : parameters) + if (FFlag::LuauExplicitTypeExpressionInstantiation) + { + visitTypeOrPackArray(visitor, parameters); + } + else { - if (param.type) - param.type->visit(visitor); - else - param.typePack->visit(visitor); + for (const AstTypeOrPack& param : parameters) + { + if (param.type) + param.type->visit(visitor); + else + param.typePack->visit(visitor); + } } } } diff --git a/Ast/src/Cst.cpp b/Ast/src/Cst.cpp index a4f359a8..e8a06c58 100644 --- a/Ast/src/Cst.cpp +++ b/Ast/src/Cst.cpp @@ -76,6 +76,12 @@ CstExprInterpString::CstExprInterpString(AstArray> sourceStrings, { } +CstExprExplicitTypeInstantiation::CstExprExplicitTypeInstantiation(CstTypeInstantiation instantiation) + : CstNode(CstClassIndex()) + , instantiation(instantiation) +{ +} + CstStatDo::CstStatDo(Position endPosition) : CstNode(CstClassIndex()) , endPosition(endPosition) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index cb92bc2a..6fa59b83 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -20,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauSolverV2) LUAU_DYNAMIC_FASTFLAGVARIABLE(DebugLuauReportReturnTypeVariadicWithTypeSuffix, false) LUAU_FASTFLAGVARIABLE(DebugLuauStringSingletonBasedOnQuotes) +LUAU_FASTFLAGVARIABLE(LuauExplicitTypeExpressionInstantiation) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAttributes) // Clip with DebugLuauReportReturnTypeVariadicWithTypeSuffix @@ -3036,7 +3037,40 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) Name index = parseIndexName("method name", opPosition); AstExpr* func = allocator.alloc(Location(start, index.location.end), expr, index.name, index.location, opPosition, ':'); - expr = parseFunctionArgs(func, true); + if (FFlag::LuauExplicitTypeExpressionInstantiation) + { + AstArray typeArguments; + CstTypeInstantiation* cstTypeArguments = options.storeCstData ? allocator.alloc() : nullptr; + + if (lexer.current().type == '<' && lexer.lookahead().type == '<') + { + typeArguments = parseTypeInstantiationExpr(cstTypeArguments); + } + + expr = parseFunctionArgs(func, true); + + if (options.storeCstData) + { + CstNode** cstNode = cstNodeMap.find(expr); + if (cstNode) + { + CstExprCall* exprCall = (*cstNode)->as(); + LUAU_ASSERT(exprCall); + exprCall->explicitTypes = cstTypeArguments; + } + } + + if (typeArguments.size > 0) + { + AstExprCall* call = expr->as(); + LUAU_ASSERT(call); + call->typeArguments = typeArguments; + } + } + else + { + expr = parseFunctionArgs(func, true); + } } else if (lexer.current().type == '(') { @@ -3053,6 +3087,10 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) { expr = parseFunctionArgs(expr, false); } + else if (FFlag::LuauExplicitTypeExpressionInstantiation && lexer.current().type == '<' && lexer.lookahead().type == '<') + { + expr = parseExplicitTypeInstantiationExpr(start, *expr); + } else { break; @@ -3314,7 +3352,8 @@ AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self) expectMatchAndConsume(')', matchParen); - AstExprCall* node = allocator.alloc(Location(func->location, end), func, copy(args), self, Location(argStart, argEnd)); + AstExprCall* node = + allocator.alloc(Location(func->location, end), func, copy(args), self, AstArray{}, Location(argStart, argEnd)); if (options.storeCstData) cstNodeMap[node] = allocator.alloc(matchParen.position, lexer.previousLocation().begin, copy(commaPositions)); return node; @@ -3325,8 +3364,9 @@ AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self) AstExpr* expr = parseTableConstructor(); Position argEnd = lexer.previousLocation().end; - AstExprCall* node = - allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, Location(argStart, argEnd)); + AstExprCall* node = allocator.alloc( + Location(func->location, expr->location), func, copy(&expr, 1), self, AstArray{}, Location(argStart, argEnd) + ); if (options.storeCstData) cstNodeMap[node] = allocator.alloc(std::nullopt, std::nullopt, AstArray{nullptr, 0}); return node; @@ -3336,7 +3376,9 @@ AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self) Location argLocation = lexer.current().location; AstExpr* expr = parseString(); - AstExprCall* node = allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, argLocation); + AstExprCall* node = allocator.alloc( + Location(func->location, expr->location), func, copy(&expr, 1), self, AstArray{}, argLocation + ); if (options.storeCstData) cstNodeMap[node] = allocator.alloc(std::nullopt, std::nullopt, AstArray{nullptr, 0}); return node; @@ -4005,6 +4047,72 @@ AstExpr* Parser::parseInterpString() return node; } +LUAU_NOINLINE AstExpr* Parser::parseExplicitTypeInstantiationExpr(Position start, AstExpr& basedOnExpr) +{ + CstExprExplicitTypeInstantiation* cstNode = nullptr; + if (options.storeCstData) + { + cstNode = allocator.alloc(CstTypeInstantiation{}); + } + + Location endLocation; + AstArray typesOrPacks = parseTypeInstantiationExpr(cstNode ? &cstNode->instantiation : nullptr, &endLocation); + + AstExpr* expr = allocator.alloc(Location(start, endLocation.end), &basedOnExpr, typesOrPacks); + + if (options.storeCstData) + { + cstNodeMap[expr] = cstNode; + } + + return expr; +} + +AstArray Parser::parseTypeInstantiationExpr( + CstTypeInstantiation* cstNodeOut, + Location* endLocationOut +) +{ + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + + LUAU_ASSERT(lexer.current().type == '<' && lexer.lookahead().type == '<'); + + if (cstNodeOut) + { + cstNodeOut->leftArrow1Position = lexer.current().location.begin; + } + + Lexeme begin = lexer.current(); + lexer.next(); + + TempVector commaPositions = TempVector{scratchPosition}; + + AstArray typeOrPacks = parseTypeParams( + cstNodeOut ? &cstNodeOut->leftArrow2Position : nullptr, + cstNodeOut ? &commaPositions : nullptr, + cstNodeOut ? &cstNodeOut->rightArrow1Position : nullptr + ); + + if (cstNodeOut) + { + cstNodeOut->commaPositions = copy(commaPositions); + + if (lexer.current().type == '>') + { + cstNodeOut->rightArrow2Position = lexer.current().location.begin; + } + } + + if (endLocationOut) + { + *endLocationOut = lexer.current().location; + } + + expectMatchAndConsume('>', begin); + return typeOrPacks; +} + + AstExpr* Parser::parseNumber() { Location start = lexer.current().location; diff --git a/Ast/src/PrettyPrinter.cpp b/Ast/src/PrettyPrinter.cpp index 7e89f201..9111317b 100644 --- a/Ast/src/PrettyPrinter.cpp +++ b/Ast/src/PrettyPrinter.cpp @@ -10,6 +10,8 @@ #include #include +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) + namespace { bool isIdentifierStartChar(char c) @@ -539,6 +541,14 @@ struct Printer const auto cstNode = lookupCstNode(a); + if (FFlag::LuauExplicitTypeExpressionInstantiation) + { + if (writeTypes && (a->typeArguments.size > 0 || (cstNode && cstNode->explicitTypes))) + { + visualizeExplicitTypeInstantiation(a->typeArguments, cstNode && cstNode->explicitTypes ? cstNode->explicitTypes : nullptr); + } + } + if (cstNode) { if (cstNode->openParens) @@ -819,6 +829,22 @@ struct Printer writer.symbol(")"); } + else if (const auto& a = expr.as()) + { + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + + visualize(*a->expr); + + if (writeTypes) + { + const CstExprExplicitTypeInstantiation* cstExprNode = lookupCstNode(a); + + visualizeExplicitTypeInstantiation( + a->typeArguments, + cstExprNode ? &cstExprNode->instantiation : nullptr + ); + } + } else { LUAU_ASSERT(!"Unknown AstExpr"); @@ -1840,6 +1866,54 @@ struct Printer LUAU_ASSERT(!"Unknown AstType"); } } + + void visualizeExplicitTypeInstantiation( + const AstArray& typeArguments, + const CstTypeInstantiation* cstNode + ) + { + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + + if (cstNode) + { + advance(cstNode->leftArrow1Position); + } + writer.symbol("<"); + + if (cstNode) + { + advance(cstNode->leftArrow2Position); + } + writer.symbol("<"); + + CommaSeparatorInserter comma(writer, cstNode ? cstNode->commaPositions.begin() : nullptr); + for (const auto& typeOrPack : typeArguments) + { + if (typeOrPack.type) + { + comma(); + visualizeTypeAnnotation(*typeOrPack.type); + } + else + { + LUAU_ASSERT(typeOrPack.typePack); + comma(); + visualizeTypePackAnnotation(*typeOrPack.typePack, /* forVarArg = */ false); + } + } + + if (cstNode) + { + advance(cstNode->rightArrow1Position); + } + writer.symbol(">"); + + if (cstNode) + { + advance(cstNode->rightArrow2Position); + } + writer.symbol(">"); + } }; std::string toString(AstNode* node) diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 529c4f09..4f8c9181 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -172,7 +172,6 @@ class AssemblyBuilderX64 void vblendvps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, RegisterX64 mask); void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, RegisterX64 mask); - void vblendvpd_DEPRECATED(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3); void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle); void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset); diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 38316ceb..e5d1f3f2 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -25,7 +25,8 @@ void updateLastUseLocations(IrFunction& function, const std::vector& s uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t startInstIdx); // Returns how many values are coming into the block (live in) and how many are coming out of the block (live out) -std::pair getLiveInOutValueCount(IrFunction& function, IrBlock& block); +std::pair getLiveInOutValueCount_NEW(IrFunction& function, IrBlock& start, bool visitChain); +std::pair getLiveInOutValueCount_DEPRECATED(IrFunction& function, IrBlock& block); uint32_t getLiveInValueCount(IrFunction& function, IrBlock& block); uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block); diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 3f48dc0c..0bad9cef 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -32,10 +32,12 @@ struct IrBuilder void beginBlock(IrOp block); void loadAndCheckTag(IrOp loc, uint8_t tag, IrOp fallback); + void checkSafeEnv(int pcpos); // Clones all instructions into the current block // Source block that is cloned cannot use values coming in from a predecessor - void clone(const IrBlock& source, bool removeCurrentTerminator); + void clone_NEW(std::vector sourceIdxs, bool removeCurrentTerminator); + void clone_DEPRECATED(const IrBlock& source, bool removeCurrentTerminator); IrOp undef(); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 137cf902..ba4bc609 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -1000,10 +1000,15 @@ enum class IrBlockKind : uint8_t Dead, }; +inline constexpr uint32_t kBlockNoStartPc = ~0u; + +inline constexpr uint8_t kBlockFlagSafeEnvCheck = 1 << 0; +inline constexpr uint8_t kBlockFlagSafeEnvClear = 1 << 1; + struct IrBlock { IrBlockKind kind; - + uint8_t flags = 0; uint16_t useCount = 0; // 'start' and 'finish' define an inclusive range of instructions which belong to this block inside the function @@ -1015,6 +1020,8 @@ struct IrBlock uint32_t chainkey = 0; uint32_t expectedNextBlock = ~0u; + uint32_t startpc = kBlockNoStartPc; + Label label; }; diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index b66fd6f5..ae98a070 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -5,6 +5,8 @@ #include "Luau/Common.h" #include "Luau/IrData.h" +LUAU_FASTFLAG(LuauCodegenFloatLoadStoreProp) + namespace Luau { namespace CodeGen @@ -162,14 +164,30 @@ inline bool hasResult(IrCmd cmd) return false; } -inline bool hasSideEffects(IrCmd cmd) +inline bool canInvalidateSafeEnv(IrCmd cmd) { - if (cmd == IrCmd::INVOKE_FASTCALL) + switch (cmd) + { + case IrCmd::CMP_ANY: + case IrCmd::DO_ARITH: + case IrCmd::DO_LEN: + case IrCmd::GET_TABLE: + case IrCmd::SET_TABLE: + case IrCmd::CONCAT: // TODO: if only strings and numbers are concatenated, there will be no user calls + case IrCmd::CALL: + case IrCmd::FORGLOOP_FALLBACK: + case IrCmd::FALLBACK_GETGLOBAL: + case IrCmd::FALLBACK_SETGLOBAL: + case IrCmd::FALLBACK_GETTABLEKS: + case IrCmd::FALLBACK_SETTABLEKS: + case IrCmd::FALLBACK_NAMECALL: + case IrCmd::FALLBACK_FORGPREP: return true; + default: + break; + } - // Instructions that don't produce a result most likely have other side-effects to make them useful - // Right now, a full switch would mirror the 'hasResult' function, so we use this simple condition - return !hasResult(cmd); + return false; } inline bool isPseudo(IrCmd cmd) @@ -178,6 +196,19 @@ inline bool isPseudo(IrCmd cmd) return cmd == IrCmd::NOP || cmd == IrCmd::SUBSTITUTE; } +inline bool hasSideEffects(IrCmd cmd) +{ + if (cmd == IrCmd::INVOKE_FASTCALL) + return true; + + if (FFlag::LuauCodegenFloatLoadStoreProp && isPseudo(cmd)) + return false; + + // Instructions that don't produce a result most likely have other side-effects to make them useful + // Right now, a full switch would mirror the 'hasResult' function, so we use this simple condition + return !hasResult(cmd); +} + IrValueKind getCmdValueKind(IrCmd cmd); bool isGCO(uint8_t tag); @@ -237,5 +268,8 @@ std::vector getSortedBlockOrder(IrFunction& function); // 'dummy' block is returned if the end of array was reached IrBlock& getNextBlock(IrFunction& function, const std::vector& sortedBlocks, IrBlock& dummy, size_t i); +// Returns next block in a chain, marked by 'constPropInBlockChains' optimization pass +IrBlock* tryGetNextBlockInChain(IrFunction& function, IrBlock& block); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 5542fe05..84c2a40f 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -948,12 +948,6 @@ void AssemblyBuilderX64::vblendvps(RegisterX64 dst, RegisterX64 src1, OperandX64 placeAvx("vblendvps", dst, src1, src2, mask.index << 4, 0x4a, false, AVX_0F3A, AVX_66); } -void AssemblyBuilderX64::vblendvpd_DEPRECATED(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3) -{ - // bits [7:4] of imm8 are used to select register for operand 4 - placeAvx("vblendvpd", dst, src1, mask, src3.index << 4, 0x4b, false, AVX_0F3A, AVX_66); -} - void AssemblyBuilderX64::vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, RegisterX64 mask) { // bits [7:4] of imm8 are used to select register for operand 4 diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index 3ea22f3c..627f9676 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -560,6 +560,18 @@ static void applyBuiltinCall(LuauBuiltinFunction bfid, BytecodeTypes& types) types.b = LBC_TYPE_NUMBER; types.c = LBC_TYPE_NUMBER; break; + case LBF_MATH_ISNAN: + types.result = LBC_TYPE_BOOLEAN; + types.a = LBC_TYPE_NUMBER; + break; + case LBF_MATH_ISINF: + types.result = LBC_TYPE_BOOLEAN; + types.a = LBC_TYPE_NUMBER; + break; + case LBF_MATH_ISFINITE: + types.result = LBC_TYPE_BOOLEAN; + types.a = LBC_TYPE_NUMBER; + break; } } diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index ee183c1b..cbddfd45 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -26,6 +26,7 @@ LUAU_FASTFLAG(DebugCodegenOptSize) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) +LUAU_FASTFLAG(LuauCodegenBlockSafeEnv) namespace Luau { @@ -159,6 +160,21 @@ inline bool lowerImpl( if (block.expectedNextBlock != ~0u) CODEGEN_ASSERT(function.getBlockIndex(nextBlock) == block.expectedNextBlock); + // Block might establish a safe environment right at the start + if (FFlag::LuauCodegenBlockSafeEnv && (block.flags & kBlockFlagSafeEnvCheck) != 0) + { + if (options.includeIr) + { + if (options.includeIrPrefix == IncludeIrPrefix::Yes) + build.logAppend("# "); + + build.logAppend(" implicit CHECK_SAFE_ENV exit(%u)\n", block.startpc); + } + + CODEGEN_ASSERT(block.startpc != kBlockNoStartPc); + lowering.checkSafeEnv(IrOp{IrOpKind::VmExit, block.startpc}, nextBlock); + } + for (uint32_t index = block.start; index <= block.finish; index++) { CODEGEN_ASSERT(index < function.instructions.size()); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index 0d4b0a1f..f5edb216 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAG(LuauCodegenChainLink) + namespace Luau { namespace CodeGen @@ -142,8 +144,78 @@ uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t s return targetInst.lastUse; } -std::pair getLiveInOutValueCount(IrFunction& function, IrBlock& block) +std::pair getLiveInOutValueCount_NEW(IrFunction& function, IrBlock& start, bool visitChain) { + CODEGEN_ASSERT(FFlag::LuauCodegenChainLink); + + // TODO: the function is not called often, but having a small vector would help here + std::vector blocks; + + if (visitChain) + { + for (IrBlock* block = &start; block; block = tryGetNextBlockInChain(function, *block)) + blocks.push_back(function.getBlockIndex(*block)); + } + else + { + blocks.push_back(function.getBlockIndex(start)); + } + + uint32_t liveIns = 0; + uint32_t liveOuts = 0; + + for (uint32_t blockIdx : blocks) + { + const IrBlock& block = function.blocks[blockIdx]; + + // If an operand refers to something inside the current block chain, it completes the instruction we marked as 'live out' + // If it refers to something outside, it has to be a 'live in' + auto checkOp = [function, &blocks, &liveIns, &liveOuts](IrOp op) + { + if (op.kind == IrOpKind::Inst) + { + for (uint32_t blockIdx : blocks) + { + const IrBlock& block = function.blocks[blockIdx]; + + if (op.index >= block.start && op.index <= block.finish) + { + CODEGEN_ASSERT(liveOuts != 0); + liveOuts--; + return; + } + } + + liveIns++; + } + }; + + for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++) + { + IrInst& inst = function.instructions[instIdx]; + + if (isPseudo(inst.cmd)) + continue; + + liveOuts += inst.useCount; + + checkOp(inst.a); + checkOp(inst.b); + checkOp(inst.c); + checkOp(inst.d); + checkOp(inst.e); + checkOp(inst.f); + checkOp(inst.g); + } + } + + return std::make_pair(liveIns, liveOuts); +} + +std::pair getLiveInOutValueCount_DEPRECATED(IrFunction& function, IrBlock& block) +{ + CODEGEN_ASSERT(!FFlag::LuauCodegenChainLink); + uint32_t liveIns = 0; uint32_t liveOuts = 0; @@ -181,12 +253,18 @@ std::pair getLiveInOutValueCount(IrFunction& function, IrBlo uint32_t getLiveInValueCount(IrFunction& function, IrBlock& block) { - return getLiveInOutValueCount(function, block).first; + if (FFlag::LuauCodegenChainLink) + return getLiveInOutValueCount_NEW(function, block, false).first; + else + return getLiveInOutValueCount_DEPRECATED(function, block).first; } uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block) { - return getLiveInOutValueCount(function, block).second; + if (FFlag::LuauCodegenChainLink) + return getLiveInOutValueCount_NEW(function, block, false).second; + else + return getLiveInOutValueCount_DEPRECATED(function, block).second; } void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart) diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 1cbe87b8..13c4a2ab 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -12,6 +12,10 @@ #include +LUAU_FASTFLAG(LuauCodegenBlockSafeEnv) + +LUAU_FASTFLAG(LuauCodegenChainLink) + namespace Luau { namespace CodeGen @@ -172,7 +176,20 @@ void IrBuilder::buildFunctionIr(Proto* proto) // Begin new block at this instruction if it was in the bytecode or requested during translation if (instIndexToBlock[i] != kNoAssociatedBlockIndex) - beginBlock(blockAtInst(i)); + { + if (FFlag::LuauCodegenBlockSafeEnv) + { + IrOp block = blockAtInst(i); + + beginBlock(block); + + function.blockOp(block).startpc = uint32_t(i); + } + else + { + beginBlock(blockAtInst(i)); + } + } // Numeric for loops require additional processing to maintain loop stack // Notably, this must be performed even when the block is dead so that we maintain the pairing FORNPREP-FORNLOOP @@ -660,8 +677,92 @@ void IrBuilder::loadAndCheckTag(IrOp loc, uint8_t tag, IrOp fallback) inst(IrCmd::CHECK_TAG, inst(IrCmd::LOAD_TAG, loc), constTag(tag), fallback); } -void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) +void IrBuilder::checkSafeEnv(int pcpos) +{ + IrBlock& active = function.blocks[activeBlockIdx]; + + // If the block start is associated with a bytecode position, we can perform an early safeenv check + if (active.startpc != kBlockNoStartPc) + { + // If the block hasn't cleared the safeenv flag yet, we can still set it at block entry + if ((active.flags & kBlockFlagSafeEnvClear) == 0) + active.flags |= kBlockFlagSafeEnvCheck; + } + + inst(IrCmd::CHECK_SAFE_ENV, vmExit(pcpos)); +} + +void IrBuilder::clone_NEW(std::vector sourceIdxs, bool removeCurrentTerminator) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenChainLink); + + DenseHashMap instRedir{~0u}; + + auto redirect = [&instRedir](IrOp& op) + { + if (op.kind == IrOpKind::Inst) + { + if (const uint32_t* newIndex = instRedir.find(op.index)) + op.index = *newIndex; + else + CODEGEN_ASSERT(!"Values can only be used if they are defined in the same block"); + } + }; + + for (uint32_t sourceIdx : sourceIdxs) + { + const IrBlock& source = function.blocks[sourceIdx]; + + if (removeCurrentTerminator && inTerminatedBlock) + { + IrBlock& active = function.blocks[activeBlockIdx]; + IrInst& term = function.instructions[active.finish]; + + kill(function, term); + inTerminatedBlock = false; + } + + for (uint32_t index = source.start; index <= source.finish; index++) + { + CODEGEN_ASSERT(index < function.instructions.size()); + IrInst clone = function.instructions[index]; + + // Skip pseudo instructions to make clone more compact, but validate that they have no users + if (isPseudo(clone.cmd)) + { + CODEGEN_ASSERT(clone.useCount == 0); + continue; + } + + redirect(clone.a); + redirect(clone.b); + redirect(clone.c); + redirect(clone.d); + redirect(clone.e); + redirect(clone.f); + redirect(clone.g); + + addUse(function, clone.a); + addUse(function, clone.b); + addUse(function, clone.c); + addUse(function, clone.d); + addUse(function, clone.e); + addUse(function, clone.f); + addUse(function, clone.g); + + // Instructions that referenced the original will have to be adjusted to use the clone + instRedir[index] = uint32_t(function.instructions.size()); + + // Reconstruct the fresh clone + inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f, clone.g); + } + } +} + +void IrBuilder::clone_DEPRECATED(const IrBlock& source, bool removeCurrentTerminator) { + CODEGEN_ASSERT(!FFlag::LuauCodegenChainLink); + DenseHashMap instRedir{~0u}; auto redirect = [&instRedir](IrOp& op) @@ -838,6 +939,12 @@ IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f, inTerminatedBlock = true; } + if (FFlag::LuauCodegenBlockSafeEnv && canInvalidateSafeEnv(cmd)) + { + // Mark that block has instruction with this flag + function.blocks[activeBlockIdx].flags |= kBlockFlagSafeEnvClear; + } + return {IrOpKind::Inst, index}; } diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 5217ad10..e82ebacc 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -897,6 +897,9 @@ std::string toString(const IrFunction& function, IncludeUseInfo includeUseInfo) continue; } + if ((block.flags & kBlockFlagSafeEnvCheck) != 0) + append(ctx.result, " implicit CHECK_SAFE_ENV exit(%u)\n", block.startpc); + // To allow dumping blocks that are still being constructed, we can't rely on terminator and need a bounds check for (uint32_t index = block.start; index <= block.finish && index < uint32_t(function.instructions.size()); index++) { @@ -910,6 +913,13 @@ std::string toString(const IrFunction& function, IncludeUseInfo includeUseInfo) toStringDetailed(ctx, block, uint32_t(i), inst, index, includeUseInfo); } + if (block.expectedNextBlock != ~0u) + { + append(ctx.result, "; glued to: "); + toString(ctx, ctx.blocks[block.expectedNextBlock], block.expectedNextBlock); + append(ctx.result, "\n"); + } + append(ctx.result, "\n"); } diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 3c2454dd..ae874f75 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -12,6 +12,8 @@ #include "lstate.h" #include "lgc.h" +LUAU_FASTFLAG(LuauCodegenBlockSafeEnv) + namespace Luau { namespace CodeGen @@ -1786,13 +1788,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) } case IrCmd::CHECK_SAFE_ENV: { - Label fresh; // used when guard aborts execution or jumps to a VM exit - RegisterA64 temp = regs.allocTemp(KindA64::x); - RegisterA64 tempw = castReg(KindA64::w, temp); - build.ldr(temp, mem(rClosure, offsetof(Closure, env))); - build.ldrb(tempw, mem(temp, offsetof(LuaTable, safeenv))); - build.cbz(tempw, getTargetLabel(inst.a, fresh)); - finalizeTargetLabel(inst.a, fresh); + if (FFlag::LuauCodegenBlockSafeEnv) + { + checkSafeEnv(inst.a, next); + } + else + { + Label fresh; // used when guard aborts execution or jumps to a VM exit + RegisterA64 temp = regs.allocTemp(KindA64::x); + RegisterA64 tempw = castReg(KindA64::w, temp); + build.ldr(temp, mem(rClosure, offsetof(Closure, env))); + build.ldrb(tempw, mem(temp, offsetof(LuaTable, safeenv))); + build.cbz(tempw, getTargetLabel(inst.a, fresh)); + finalizeTargetLabel(inst.a, fresh); + } break; } case IrCmd::CHECK_ARRAY_SIZE: @@ -2808,6 +2817,17 @@ void IrLoweringA64::finalizeTargetLabel(IrOp op, Label& fresh) } } +void IrLoweringA64::checkSafeEnv(IrOp target, const IrBlock& next) +{ + Label fresh; // used when guard aborts execution or jumps to a VM exit + RegisterA64 temp = regs.allocTemp(KindA64::x); + RegisterA64 tempw = castReg(KindA64::w, temp); + build.ldr(temp, mem(rClosure, offsetof(Closure, env))); + build.ldrb(tempw, mem(temp, offsetof(LuaTable, safeenv))); + build.cbz(tempw, getTargetLabel(target, fresh)); + finalizeTargetLabel(target, fresh); +} + RegisterA64 IrLoweringA64::tempDouble(IrOp op) { if (op.kind == IrOpKind::Inst) diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index 717b5d6a..ce013949 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -38,6 +38,8 @@ struct IrLoweringA64 Label& getTargetLabel(IrOp op, Label& fresh); void finalizeTargetLabel(IrOp op, Label& fresh); + void checkSafeEnv(IrOp target, const IrBlock& next); + // Operand data build helpers // May emit data/address synthesis instructions RegisterA64 tempDouble(IrOp op); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 3648279f..747c7be5 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -16,7 +16,7 @@ #include "lstate.h" #include "lgc.h" -LUAU_FASTFLAGVARIABLE(LuauCodeGenVBlendpdReorder) +LUAU_FASTFLAG(LuauCodegenBlockSafeEnv) namespace Luau { @@ -45,25 +45,6 @@ IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, build.align(kFunctionAlignment, X64::AlignmentDataX64::Ud2); } -void IrLoweringX64::storeDoubleAsFloat(OperandX64 dst, IrOp src) -{ - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - - if (src.kind == IrOpKind::Constant) - { - build.vmovss(tmp.reg, build.f32(float(doubleOp(src)))); - } - else if (src.kind == IrOpKind::Inst) - { - build.vcvtsd2ss(tmp.reg, regOp(src), regOp(src)); - } - else - { - CODEGEN_ASSERT(!"Unsupported instruction form"); - } - build.vmovss(dst, tmp.reg); -} - void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { regs.currInstIdx = index; @@ -676,10 +657,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) // If arg < 0 then tmp1 is -1 and mask-bit is 0, result is -1 // If arg == 0 then tmp1 is 0 and mask-bit is 0, result is 0 // If arg > 0 then tmp1 is 0 and mask-bit is 1, result is 1 - if (FFlag::LuauCodeGenVBlendpdReorder) - build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64); - else - build.vblendvpd_DEPRECATED(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64); + build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64); break; } case IrCmd::SELECT_NUM: @@ -697,17 +675,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) } if (inst.a.kind == IrOpKind::Inst) - if (FFlag::LuauCodeGenVBlendpdReorder) - build.vblendvpd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b), tmp.reg); - else - build.vblendvpd_DEPRECATED(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b), tmp.reg); + build.vblendvpd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b), tmp.reg); else { build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); - if (FFlag::LuauCodeGenVBlendpdReorder) - build.vblendvpd(inst.regX64, inst.regX64, memRegDoubleOp(inst.b), tmp.reg); - else - build.vblendvpd_DEPRECATED(inst.regX64, inst.regX64, memRegDoubleOp(inst.b), tmp.reg); + build.vblendvpd(inst.regX64, inst.regX64, memRegDoubleOp(inst.b), tmp.reg); } break; } @@ -1577,13 +1549,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) break; case IrCmd::CHECK_SAFE_ENV: { - ScopedRegX64 tmp{regs, SizeX64::qword}; + if (FFlag::LuauCodegenBlockSafeEnv) + { + checkSafeEnv(inst.a, next); + } + else + { + ScopedRegX64 tmp{regs, SizeX64::qword}; - build.mov(tmp.reg, sClosure); - build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]); - build.cmp(byte[tmp.reg + offsetof(LuaTable, safeenv)], 0); + build.mov(tmp.reg, sClosure); + build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]); + build.cmp(byte[tmp.reg + offsetof(LuaTable, safeenv)], 0); - jumpOrAbortOnUndef(ConditionX64::Equal, inst.a, next); + jumpOrAbortOnUndef(ConditionX64::Equal, inst.a, next); + } break; } case IrCmd::CHECK_ARRAY_SIZE: @@ -2496,6 +2475,36 @@ void IrLoweringX64::jumpOrAbortOnUndef(IrOp target, const IrBlock& next) jumpOrAbortOnUndef(ConditionX64::Count, target, next); } +void IrLoweringX64::storeDoubleAsFloat(OperandX64 dst, IrOp src) +{ + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + + if (src.kind == IrOpKind::Constant) + { + build.vmovss(tmp.reg, build.f32(float(doubleOp(src)))); + } + else if (src.kind == IrOpKind::Inst) + { + build.vcvtsd2ss(tmp.reg, regOp(src), regOp(src)); + } + else + { + CODEGEN_ASSERT(!"Unsupported instruction form"); + } + build.vmovss(dst, tmp.reg); +} + +void IrLoweringX64::checkSafeEnv(IrOp target, const IrBlock& next) +{ + ScopedRegX64 tmp{regs, SizeX64::qword}; + + build.mov(tmp.reg, sClosure); + build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]); + build.cmp(byte[tmp.reg + offsetof(LuaTable, safeenv)], 0); + + jumpOrAbortOnUndef(ConditionX64::Equal, target, next); +} + OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) { switch (op.kind) diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index b4b9918a..212b61d2 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -44,6 +44,7 @@ struct IrLoweringX64 void jumpOrAbortOnUndef(IrOp target, const IrBlock& next); void storeDoubleAsFloat(OperandX64 dst, IrOp src); + void checkSafeEnv(IrOp target, const IrBlock& next); // Operand data lookup helpers OperandX64 memRegDoubleOp(IrOp op); diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index a7de2e99..6b623d8b 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -9,7 +9,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauCodeGenVectorLerp2) -LUAU_FASTFLAGVARIABLE(LuauCodeGenFMA) // TODO: when nresults is less than our actual result count, we can skip computing/writing unused results @@ -302,19 +301,9 @@ static BuiltinImplResult translateBuiltinVectorLerp(IrBuilder& build, int nparam IrOp one = build.inst(IrCmd::NUM_TO_VEC, build.constDouble(1.0)); IrOp diff = build.inst(IrCmd::SUB_VEC, b, a); - if (FFlag::LuauCodeGenFMA) - { - IrOp res = build.inst(IrCmd::MULADD_VEC, diff, tvec, a); - IrOp ret = build.inst(IrCmd::SELECT_VEC, res, b, tvec, one); - build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), build.inst(IrCmd::TAG_VECTOR, ret)); - } - else - { - IrOp incr = build.inst(IrCmd::MUL_VEC, diff, tvec); - IrOp res = build.inst(IrCmd::ADD_VEC, a, incr); - IrOp ret = build.inst(IrCmd::SELECT_VEC, res, b, tvec, one); - build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), build.inst(IrCmd::TAG_VECTOR, ret)); - } + IrOp res = build.inst(IrCmd::MULADD_VEC, diff, tvec, a); + IrOp ret = build.inst(IrCmd::SELECT_VEC, res, b, tvec, one); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), build.inst(IrCmd::TAG_VECTOR, ret)); return {BuiltinImplType::Full, 1}; } @@ -342,20 +331,10 @@ static BuiltinImplResult translateBuiltinMathLerp( IrOp b = builtinLoadDouble(build, args); IrOp t = builtinLoadDouble(build, arg3); - if (FFlag::LuauCodeGenFMA) - { - IrOp l = build.inst(IrCmd::MULADD_NUM, build.inst(IrCmd::SUB_NUM, b, a), t, a); - IrOp r = build.inst(IrCmd::SELECT_NUM, l, b, t, build.constDouble(1.0)); // select on t==1.0 + IrOp l = build.inst(IrCmd::MULADD_NUM, build.inst(IrCmd::SUB_NUM, b, a), t, a); + IrOp r = build.inst(IrCmd::SELECT_NUM, l, b, t, build.constDouble(1.0)); // select on t==1.0 - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), r); - } - else - { - IrOp l = build.inst(IrCmd::ADD_NUM, a, build.inst(IrCmd::MUL_NUM, build.inst(IrCmd::SUB_NUM, b, a), t)); - IrOp r = build.inst(IrCmd::SELECT_NUM, l, b, t, build.constDouble(1.0)); // select on t==1.0 - - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), r); - } + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), r); if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 9fac593a..30d1de0b 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -12,6 +12,8 @@ #include "lstate.h" #include "ltm.h" +LUAU_FASTFLAG(LuauCodegenBlockSafeEnv) + namespace Luau { namespace CodeGen @@ -875,7 +877,10 @@ IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool IrOp fallback = build.block(IrBlockKind::Fallback); // In unsafe environment, instead of retrying fastcall at 'pcpos' we side-exit directly to fallback sequence - build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos + getOpLength(opcode))); + if (FFlag::LuauCodegenBlockSafeEnv) + build.checkSafeEnv(pcpos + getOpLength(opcode)); + else + build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos + getOpLength(opcode))); BuiltinImplResult br = translateBuiltin( build, LuauBuiltinFunction(bfid), ra, arg, builtinArgs, builtinArg3, nparams, nresults, fallback, pcpos + getOpLength(opcode) @@ -1065,7 +1070,11 @@ void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpo IrOp fallback = build.block(IrBlockKind::Fallback); // fast-path: pairs/next - build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos)); + if (FFlag::LuauCodegenBlockSafeEnv) + build.checkSafeEnv(pcpos); + else + build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos)); + IrOp tagB = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); build.inst(IrCmd::CHECK_TAG, tagB, build.constTag(LUA_TTABLE), fallback); IrOp tagC = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); @@ -1093,7 +1102,11 @@ void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcp IrOp finish = build.block(IrBlockKind::Internal); // fast-path: ipairs/inext - build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos)); + if (FFlag::LuauCodegenBlockSafeEnv) + build.checkSafeEnv(pcpos); + else + build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos)); + IrOp tagB = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); build.inst(IrCmd::CHECK_TAG, tagB, build.constTag(LUA_TTABLE), fallback); IrOp tagC = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); @@ -1321,7 +1334,11 @@ void translateInstGetImport(IrBuilder& build, const Instruction* pc, int pcpos) int k = LUAU_INSN_D(*pc); uint32_t aux = pc[1]; - build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos)); + if (FFlag::LuauCodegenBlockSafeEnv) + build.checkSafeEnv(pcpos); + else + build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos)); + build.inst(IrCmd::GET_CACHED_IMPORT, build.vmReg(ra), build.vmConst(k), build.constImport(aux), build.constUint(pcpos + 1)); } diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 53d5c8f4..8ca48027 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -1214,5 +1214,22 @@ IrBlock& getNextBlock(IrFunction& function, const std::vector& sortedB return dummy; } +IrBlock* tryGetNextBlockInChain(IrFunction& function, IrBlock& block) +{ + IrInst& termInst = function.instructions[block.finish]; + + // Follow the strict block chain + if (termInst.cmd == IrCmd::JUMP && termInst.a.kind == IrOpKind::Block) + { + IrBlock& target = function.blockOp(termInst.a); + + // Has to have the same sorting key and a consecutive chain key + if (target.sortkey == block.sortkey && target.chainkey == block.chainkey + 1) + return ⌖ + } + + return nullptr; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 023897f9..8c5c4cbf 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -24,6 +24,10 @@ LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) LUAU_FASTINTVARIABLE(LuauCodeGenLiveSlotReuseLimit, 8) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks) LUAU_FASTFLAGVARIABLE(LuauCodegenStorePriority) +LUAU_FASTFLAGVARIABLE(LuauCodegenInterruptIsNotForWrites) +LUAU_FASTFLAGVARIABLE(LuauCodegenFloatLoadStoreProp) +LUAU_FASTFLAGVARIABLE(LuauCodegenBlockSafeEnv) +LUAU_FASTFLAGVARIABLE(LuauCodegenChainLink) namespace Luau { @@ -340,13 +344,32 @@ struct ConstPropState return IrInst{loadCmd, op}; } + // For instructions like LOAD_FLOAT which have an extra offset, we need to record that in the versioned instruction + IrInst versionedVmRegLoad(IrCmd loadCmd, IrOp opA, IrOp opB) + { + IrInst inst = versionedVmRegLoad(loadCmd, opA); + inst.b = opB; + return inst; + } + uint32_t* getPreviousInstIndex(const IrInst& inst) { if (uint32_t* prevIdx = valueMap.find(inst)) { - // Previous load might have been removed as unused - if (function.instructions[*prevIdx].useCount != 0) - return prevIdx; + if (FFlag::LuauCodegenFloatLoadStoreProp) + { + IrInst& inst = function.instructions[*prevIdx]; + + // Previous load might have been removed as unused + if (inst.useCount != 0 || hasSideEffects(inst.cmd)) + return prevIdx; + } + else + { + // Previous load might have been removed as unused + if (function.instructions[*prevIdx].useCount != 0) + return prevIdx; + } } return nullptr; @@ -372,6 +395,11 @@ struct ConstPropState if (uint32_t* prevIdx = getPreviousVersionedLoadIndex(IrCmd::LOAD_DOUBLE, vmReg)) return std::make_pair(IrCmd::LOAD_DOUBLE, *prevIdx); } + else if (FFlag::LuauCodegenFloatLoadStoreProp && tag == LUA_TVECTOR) + { + if (uint32_t* prevIdx = getPreviousVersionedLoadIndex(IrCmd::LOAD_FLOAT, vmReg)) + return std::make_pair(IrCmd::LOAD_FLOAT, *prevIdx); + } else if (isGCO(tag)) { if (uint32_t* prevIdx = getPreviousVersionedLoadIndex(IrCmd::LOAD_POINTER, vmReg)) @@ -396,16 +424,17 @@ struct ConstPropState // VM register load can be replaced by a previous load of the same version of the register // If there is no previous load, we record the current one for future lookups - void substituteOrRecordVmRegLoad(IrInst& loadInst) + bool substituteOrRecordVmRegLoad(IrInst& loadInst) { CODEGEN_ASSERT(loadInst.a.kind == IrOpKind::VmReg); // To avoid captured register invalidation tracking in lowering later, values from loads from captured registers are not propagated // This prevents the case where load value location is linked to memory in case of a spill and is then clobbered in a user call if (function.cfg.captured.regs.test(vmRegOp(loadInst.a))) - return; + return false; - IrInst versionedLoad = versionedVmRegLoad(loadInst.cmd, loadInst.a); + IrInst versionedLoad = FFlag::LuauCodegenFloatLoadStoreProp ? versionedVmRegLoad(loadInst.cmd, loadInst.a, loadInst.b) + : versionedVmRegLoad(loadInst.cmd, loadInst.a); // Check if there is a value that already has this version of the register if (uint32_t* prevIdx = getPreviousInstIndex(versionedLoad)) @@ -417,7 +446,7 @@ struct ConstPropState // Substitute load instruction with the previous value substitute(function, loadInst, IrOp{IrOpKind::Inst, *prevIdx}); - return; + return true; } uint32_t instIdx = function.getInstIndex(loadInst); @@ -426,6 +455,7 @@ struct ConstPropState valueMap[versionedLoad] = instIdx; createRegLink(instIdx, loadInst.a); + return false; } // VM register loads can use the value that was stored in the same Vm register earlier @@ -646,6 +676,9 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid case LBF_VECTOR_MAX: case LBF_VECTOR_LERP: case LBF_MATH_LERP: + case LBF_MATH_ISNAN: + case LBF_MATH_ISINF: + case LBF_MATH_ISFINITE: break; case LBF_TABLE_INSERT: state.invalidateHeap(); @@ -704,6 +737,41 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& break; } case IrCmd::LOAD_FLOAT: + if (FFlag::LuauCodegenFloatLoadStoreProp && inst.a.kind == IrOpKind::VmReg) + { + if (state.substituteOrRecordVmRegLoad(inst)) + break; + + IrInst versionedLoad = state.versionedVmRegLoad(IrCmd::LOAD_FLOAT, inst.a); + + // Check if there is a value that already has this version of the register + if (uint32_t* prevIdx = state.getPreviousInstIndex(versionedLoad)) + { + IrInst& store = function.instructions[*prevIdx]; + CODEGEN_ASSERT(store.cmd == IrCmd::STORE_VECTOR); + + IrOp argOp; + + if (std::optional intOp = function.asIntOp(inst.b)) + { + if (*intOp == 0) + argOp = store.b; + else if (*intOp == 4) + argOp = store.c; + else if (*intOp == 8) + argOp = store.d; + } + + if (IrInst* arg = function.asInstOp(argOp)) + { + // Argument can only be re-used if it contains the value of the same precision + if (arg->cmd == IrCmd::LOAD_FLOAT || arg->cmd == IrCmd::BUFFER_READF32) + substitute(function, inst, argOp); + } + + break; + } + } break; case IrCmd::LOAD_TVALUE: if (inst.a.kind == IrOpKind::VmReg) @@ -810,6 +878,17 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& break; case IrCmd::STORE_VECTOR: state.invalidateValue(inst.a); + + // To avoid captured register invalidation tracking in lowering later, values from loads from captured registers are not propagated + if (FFlag::LuauCodegenFloatLoadStoreProp && !function.cfg.captured.regs.test(vmRegOp(inst.a))) + { + // This is different from how other stores use 'forwardVmRegStoreToLoad' + // Instead of mapping a store to a load directly, we map a LOAD_FLOAT without a specific offset to the the store instruction itself + // LOAD_FLOAT will have special path to look up this store and apply additional checks to make sure the argument reuse is valid + // One of the restrictions is that STORE_VECTOR converts double to float, so reusing the source is only possible if it comes from a float + // The register versioning rules will stay the same and follow the correct invalidation + state.valueMap[state.versionedVmRegLoad(IrCmd::LOAD_FLOAT, inst.a)] = index; + } break; case IrCmd::STORE_TVALUE: if (inst.a.kind == IrOpKind::VmReg || inst.a.kind == IrOpKind::Inst) @@ -1658,7 +1737,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidateUserCall(); // TODO: if only strings and numbers are concatenated, there will be no user calls break; case IrCmd::INTERRUPT: - state.invalidateUserCall(); + // While interrupt can observe state and yield/error, interrupt handlers must never change state + if (!FFlag::LuauCodegenInterruptIsNotForWrites) + state.invalidateUserCall(); break; case IrCmd::SETLIST: if (RegisterInfo* info = state.tryGetRegisterInfo(inst.b); info && info->knownTableArraySize >= 0) @@ -1737,6 +1818,13 @@ static void constPropInBlock(IrBuilder& build, IrBlock& block, ConstPropState& s { IrFunction& function = build.function; + if (FFlag::LuauCodegenBlockSafeEnv) + { + // Block might establish a safe environment right at the start + if ((block.flags & kBlockFlagSafeEnvCheck) != 0) + state.inSafeEnv = true; + } + for (uint32_t index = block.start; index <= block.finish; index++) { CODEGEN_ASSERT(index < function.instructions.size()); @@ -1765,15 +1853,25 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite CODEGEN_ASSERT(!visited[blockIdx]); visited[blockIdx] = true; + if (FFlag::LuauCodegenBlockSafeEnv) + { + // If we are still in safe env, block doesn't need to re-establish it + if (state.inSafeEnv && (block->flags & kBlockFlagSafeEnvCheck) != 0) + block->flags &= ~kBlockFlagSafeEnvCheck; + } + constPropInBlock(build, *block, state); - // Value numbering and load/store propagation is not performed between blocks - state.invalidateValuePropagation(); + if (!FFlag::LuauCodegenChainLink) + { + // Value numbering and load/store propagation is not performed between blocks + state.invalidateValuePropagation(); - // Same for table and buffer data propagation - state.invalidateHeapTableData(); - state.invalidateHeapBufferData(); - state.invalidateUserdataData(); + // Same for table and buffer data propagation + state.invalidateHeapTableData(); + state.invalidateHeapBufferData(); + state.invalidateUserdataData(); + } // Blocks in a chain are guaranteed to follow each other // We force that by giving all blocks the same sorting key, but consecutive chain keys @@ -1837,7 +1935,8 @@ static std::vector collectDirectBlockJumpPath(IrFunction& function, st { // Additional restriction is that to join a block, it cannot produce values that are used in other blocks // And it also can't use values produced in other blocks - auto [liveIns, liveOuts] = getLiveInOutValueCount(function, target); + auto [liveIns, liveOuts] = FFlag::LuauCodegenChainLink ? getLiveInOutValueCount_NEW(function, target, true) + : getLiveInOutValueCount_DEPRECATED(function, target); if (liveIns == 0 && liveOuts == 0) { @@ -1845,6 +1944,26 @@ static std::vector collectDirectBlockJumpPath(IrFunction& function, st path.push_back(targetIdx); nextBlock = ⌖ + + if (FFlag::LuauCodegenChainLink) + { + for (;;) + { + if (IrBlock* nextInChain = tryGetNextBlockInChain(function, *nextBlock)) + { + uint32_t nextInChainIdx = function.getBlockIndex(*nextInChain); + + visited[nextInChainIdx] = true; + path.push_back(nextInChainIdx); + + nextBlock = nextInChain; + } + else + { + break; + } + } + } } } } @@ -1920,9 +2039,15 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited replace(function, termInst.a, newBlock); // Clone the collected path into our fresh block - for (uint32_t pathBlockIdx : path) - build.clone(function.blocks[pathBlockIdx], /* removeCurrentTerminator */ true); - + if (FFlag::LuauCodegenChainLink) + { + build.clone_NEW(path, /* removeCurrentTerminator */ true); + } + else + { + for (uint32_t pathBlockIdx : path) + build.clone_DEPRECATED(function.blocks[pathBlockIdx], /* removeCurrentTerminator */ true); + } // If all live in/out data is defined aside from the new block, generate it // Note that liveness information is not strictly correct after optimization passes and may need to be recomputed before next passes // The information generated here is consistent with current state that could be outdated, but still useful in IR inspection diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index bc81a8ca..4b0a5691 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -674,7 +674,13 @@ enum LuauBuiltinFunction // math.lerp LBF_MATH_LERP, + // vector.lerp LBF_VECTOR_LERP, + + // math. + LBF_MATH_ISNAN, + LBF_MATH_ISINF, + LBF_MATH_ISFINITE }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index 2379e8dd..dee583f6 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -37,7 +37,7 @@ struct CompileOptions int debugLevel = 1; // type information is used to guide native code generation decisions - // information includes testable types for function arguments, locals, upvalues and some temporaries + // information includes testable typeArguments for function arguments, locals, upvalues and some temporaries // 0 - generate for native modules // 1 - generate for all modules int typeInfoLevel = 0; @@ -57,7 +57,7 @@ struct CompileOptions // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these const char* const* mutableGlobals = nullptr; - // null-terminated array of userdata types that will be included in the type information + // null-terminated array of userdata typeArguments that will be included in the type information const char* const* userdataTypes = nullptr; // null-terminated array of globals which act as libraries and have members with known type and/or constant value diff --git a/Compiler/src/BuiltinFolding.cpp b/Compiler/src/BuiltinFolding.cpp index 69a30404..ebc78710 100644 --- a/Compiler/src/BuiltinFolding.cpp +++ b/Compiler/src/BuiltinFolding.cpp @@ -522,6 +522,33 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count) return cnum(v); } break; + + case LBF_MATH_ISNAN: + if (count == 1 && args[0].type == Constant::Type_Number) + { + double x = args[0].valueNumber; + + return cbool(isnan(x)); + } + break; + + case LBF_MATH_ISINF: + if (count == 1 && args[0].type == Constant::Type_Number) + { + double x = args[0].valueNumber; + + return cbool(isinf(x)); + } + break; + + case LBF_MATH_ISFINITE: + if (count == 1 && args[0].type == Constant::Type_Number) + { + double x = args[0].valueNumber; + + return cbool(isfinite(x)); + } + break; } return cvar(); diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index b6cdf8c6..66d87d68 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauCompileVectorLerp) +LUAU_FASTFLAGVARIABLE(LuauCompileMathIsNanInfFinite) namespace Luau { @@ -141,6 +142,16 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op return LBF_MATH_ROUND; if (builtin.method == "lerp") return LBF_MATH_LERP; + + if (FFlag::LuauCompileMathIsNanInfFinite) + { + if (builtin.method == "isnan") + return LBF_MATH_ISNAN; + if (builtin.method == "isinf") + return LBF_MATH_ISINF; + if (builtin.method == "isfinite") + return LBF_MATH_ISFINITE; + } } if (builtin.object == "bit32") @@ -565,6 +576,12 @@ BuiltinInfo getBuiltinInfo(int bfid) case LBF_MATH_LERP: return {3, 1, BuiltinInfo::Flag_NoneSafe}; + case LBF_MATH_ISNAN: + return {1, 1, BuiltinInfo::Flag_NoneSafe}; + case LBF_MATH_ISINF: + return {1, 1, BuiltinInfo::Flag_NoneSafe}; + case LBF_MATH_ISFINITE: + return {1, 1, BuiltinInfo::Flag_NoneSafe}; } LUAU_UNREACHABLE(); diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index e70b4b18..cb5639c1 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -7,6 +7,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauCompileUnusedUdataFix) + namespace Luau { @@ -711,8 +713,11 @@ void BytecodeBuilder::finalize() // Write the mapping between used type name indices and their name for (uint32_t i = 0; i < uint32_t(userdataTypes.size()); i++) { - writeByte(bytecode, i + 1); - writeVarInt(bytecode, userdataTypes[i].nameRef); + if (!FFlag::LuauCompileUnusedUdataFix || userdataTypes[i].used) + { + writeByte(bytecode, i + 1); + writeVarInt(bytecode, userdataTypes[i].nameRef); + } } // 0 marks the end of the mapping diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 85e008aa..d4fd1cec 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -27,6 +27,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAG(LuauInterpStringConstFolding) +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) + namespace Luau { @@ -2430,6 +2432,11 @@ struct Compiler { compileExprInterpString(interpString, target, targetTemp); } + else if (AstExprInstantiate* expr = node->as()) + { + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + compileExpr(expr->expr, target, targetTemp); + } else { LUAU_ASSERT(!"Unknown expression type"); diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 0a01f24a..a019eecb 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -7,6 +7,7 @@ #include #include +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) LUAU_FASTFLAGVARIABLE(LuauStringConstFolding2) LUAU_FASTFLAGVARIABLE(LuauInterpStringConstFolding) @@ -623,6 +624,11 @@ struct ConstantVisitor : AstVisitor analyze(expression); } } + else if (AstExprInstantiate* expr = node->as()) + { + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + result = analyze(expr->expr); + } else { LUAU_ASSERT(!"Unknown expression type"); diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp index 1b99bf52..a1bbe380 100644 --- a/Compiler/src/CostModel.cpp +++ b/Compiler/src/CostModel.cpp @@ -8,6 +8,8 @@ #include +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) + namespace Luau { namespace Compile @@ -213,6 +215,11 @@ struct CostVisitor : AstVisitor return cost; } + else if (AstExprInstantiate* expr = node->as()) + { + LUAU_ASSERT(FFlag::LuauExplicitTypeExpressionInstantiation); + return model(expr->expr); + } else { LUAU_ASSERT(!"Unknown expression type"); diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index 4fcd984b..b6559d55 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -757,6 +757,9 @@ struct TypeMapVisitor : AstVisitor recordResolvedType(node, &builtinTypes.stringType); break; + case LBF_MATH_ISNAN: + case LBF_MATH_ISINF: + case LBF_MATH_ISFINITE: case LBF_RAWEQUAL: recordResolvedType(node, &builtinTypes.booleanType); break; diff --git a/Makefile b/Makefile index 72cfbf22..1b24796b 100644 --- a/Makefile +++ b/Makefile @@ -258,11 +258,11 @@ $(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET) $(COMPILE_CLI_TARGET) $ $(CXX) $^ $(LDFLAGS) -o $@ # executable targets for fuzzing -fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) +fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(COMMON_TARGET) $(CXX) $^ $(LDFLAGS) -o $@ -fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator -fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator +fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) $(COMMON_TARGET) | build/libprotobuf-mutator +fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) $(COMMON_TARGET) | build/libprotobuf-mutator # static library targets $(COMMON_TARGET): $(COMMON_OBJECTS) diff --git a/Require/include/Luau/Require.h b/Require/include/Luau/Require.h index a4918fff..584d201f 100644 --- a/Require/include/Luau/Require.h +++ b/Require/include/Luau/Require.h @@ -86,6 +86,11 @@ typedef struct luarequire_Configuration // alias's path cannot be resolved relative to its configuration file. luarequire_NavigateResult (*jump_to_alias)(lua_State* L, void* ctx, const char* path); + // Provides a final override opportunity if an alias cannot be found in + // configuration files. If NAVIGATE_SUCCESS is returned, this must update + // the internal state to point at the aliased module. Can be left undefined. + luarequire_NavigateResult (*to_alias_fallback)(lua_State* L, void* ctx, const char* alias_unprefixed); + // Navigates through the context by making mutations to the internal state. luarequire_NavigateResult (*to_parent)(lua_State* L, void* ctx); luarequire_NavigateResult (*to_child)(lua_State* L, void* ctx, const char* name); diff --git a/Require/include/Luau/RequireNavigator.h b/Require/include/Luau/RequireNavigator.h index d3574326..bf1acad6 100644 --- a/Require/include/Luau/RequireNavigator.h +++ b/Require/include/Luau/RequireNavigator.h @@ -61,6 +61,10 @@ class NavigationContext virtual NavigateResult reset(const std::string& identifier) = 0; virtual NavigateResult jumpToAlias(const std::string& path) = 0; + virtual NavigateResult toAliasFallback(const std::string& aliasUnprefixed) + { + return NavigateResult::NotFound; + }; virtual NavigateResult toParent() = 0; virtual NavigateResult toChild(const std::string& component) = 0; @@ -119,6 +123,7 @@ class Navigator [[nodiscard]] Error jumpToAlias(const std::string& aliasPath); [[nodiscard]] Error navigateToParent(std::optional previousComponent); [[nodiscard]] Error navigateToChild(const std::string& component); + [[nodiscard]] Error toAliasFallback(const std::string& aliasUnprefixed); NavigationContext& navigationContext; ErrorHandler& errorHandler; diff --git a/Require/src/Navigation.cpp b/Require/src/Navigation.cpp index 9438dca9..6bd9498c 100644 --- a/Require/src/Navigation.cpp +++ b/Require/src/Navigation.cpp @@ -74,6 +74,13 @@ NavigationContext::NavigateResult RuntimeNavigationContext::jumpToAlias(const st return convertNavigateResult(config->jump_to_alias(L, ctx, path.c_str())); } +NavigationContext::NavigateResult RuntimeNavigationContext::toAliasFallback(const std::string& aliasUnprefixed) +{ + if (!config->to_alias_fallback) + return NavigationContext::NavigateResult::NotFound; + return convertNavigateResult(config->to_alias_fallback(L, ctx, aliasUnprefixed.c_str())); +} + NavigationContext::NavigateResult RuntimeNavigationContext::toParent() { return convertNavigateResult(config->to_parent(L, ctx)); diff --git a/Require/src/Navigation.h b/Require/src/Navigation.h index ecb90668..6136f29b 100644 --- a/Require/src/Navigation.h +++ b/Require/src/Navigation.h @@ -34,6 +34,7 @@ class RuntimeNavigationContext : public NavigationContext // Navigation interface NavigateResult reset(const std::string& requirerChunkname) override; NavigateResult jumpToAlias(const std::string& path) override; + NavigateResult toAliasFallback(const std::string& aliasUnprefixed) override; NavigateResult toParent() override; NavigateResult toChild(const std::string& component) override; diff --git a/Require/src/RequireNavigator.cpp b/Require/src/RequireNavigator.cpp index 6ec5699a..a8c912a7 100644 --- a/Require/src/RequireNavigator.cpp +++ b/Require/src/RequireNavigator.cpp @@ -81,25 +81,36 @@ Error Navigator::navigateImpl(std::string_view path) if (Error error = navigateToAndPopulateConfig(alias, config)) return error; - if (!config.aliases.contains(alias)) + if (config.aliases.contains(alias)) { - if (alias != "self") - return "@" + alias + " is not a valid alias"; - - // If the alias is "@self", we reset to the requirer's context and - // navigate directly from there. - if (Error error = resetToRequirer()) + if (Error error = navigateToAlias(alias, config, {})) return error; if (Error error = navigateThroughPath(path)) return error; return std::nullopt; } + else + { + if (alias == "self") + { + // If the alias is "@self", we reset to the requirer's context and + // navigate directly from there. + if (Error error = resetToRequirer()) + return error; + if (Error error = navigateThroughPath(path)) + return error; - if (Error error = navigateToAlias(alias, config, {})) - return error; - if (Error error = navigateThroughPath(path)) - return error; + return std::nullopt; + } + + if (Error error = toAliasFallback(alias)) + return error; + if (Error error = navigateThroughPath(path)) + return error; + + return std::nullopt; + } } if (pathType == PathType::RelativeToCurrent || pathType == PathType::RelativeToParent) @@ -150,6 +161,7 @@ Error Navigator::navigateThroughPath(std::string_view path) Error Navigator::navigateToAlias(const std::string& alias, const Config& config, AliasCycleTracker cycleTracker) { + LUAU_ASSERT(config.aliases.contains(alias)); std::string value = config.aliases.find(alias)->value; PathType pathType = getPathType(value); @@ -174,6 +186,7 @@ Error Navigator::navigateToAlias(const std::string& alias, const Config& config, Config parentConfig; if (Error error = navigateToAndPopulateConfig(nextAlias, parentConfig)) return error; + if (parentConfig.aliases.contains(nextAlias)) { if (Error error = navigateToAlias(nextAlias, parentConfig, {})) @@ -181,7 +194,8 @@ Error Navigator::navigateToAlias(const std::string& alias, const Config& config, } else { - return "@" + nextAlias + " is not a valid alias"; + if (Error error = toAliasFallback(nextAlias)) + return error; } } @@ -201,6 +215,8 @@ Error Navigator::navigateToAndPopulateConfig(const std::string& desiredAlias, Co { while (!config.aliases.contains(desiredAlias)) { + config = {}; // Clear existing config data. + NavigationContext::NavigateResult result = navigationContext.toParent(); if (result == NavigationContext::NavigateResult::Ambiguous) return "could not navigate up the ancestry chain during search for alias \"" + desiredAlias + "\" (ambiguous)"; @@ -306,4 +322,16 @@ Error Navigator::navigateToChild(const std::string& component) return errorMessage; } +Error Navigator::toAliasFallback(const std::string& aliasUnprefixed) +{ + NavigationContext::NavigateResult result = navigationContext.toAliasFallback(aliasUnprefixed); + if (result == NavigationContext::NavigateResult::Success) + return std::nullopt; + + std::string errorMessage = "@" + aliasUnprefixed + " is not a valid alias"; + if (result == NavigationContext::NavigateResult::Ambiguous) + errorMessage += " (ambiguous)"; + return errorMessage; +} + } // namespace Luau::Require diff --git a/Sources.cmake b/Sources.cmake index 58f3e205..c52c3604 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -567,6 +567,7 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.cfa.test.cpp tests/TypeInfer.classes.test.cpp tests/TypeInfer.definitions.test.cpp + tests/TypeInfer.typeInstantiations.test.cpp tests/TypeInfer.functions.test.cpp tests/TypeInfer.generics.test.cpp tests/TypeInfer.intersectionTypes.test.cpp diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index dc00ea27..562945e4 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1746,6 +1746,45 @@ static int luauF_lerp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } +static int luauF_isnan(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double x = nvalue(arg0); + + setbvalue(res, isnan(x)); + return 1; + } + + return -1; +} + +static int luauF_isinf(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double x = nvalue(arg0); + + setbvalue(res, isinf(x)); + return 1; + } + + return -1; +} + +static int luauF_isfinite(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double x = nvalue(arg0); + + setbvalue(res, isfinite(x)); + return 1; + } + + return -1; +} + static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { return -1; @@ -1945,6 +1984,10 @@ const luau_FastFunction luauF_table[256] = { luauF_vectorlerp, + luauF_isnan, + luauF_isinf, + luauF_isfinite, + // When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback. // This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing. // Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest. diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index e4769c57..8dd489e8 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -13,6 +13,8 @@ #define PCG32_INC 105 +LUAU_FASTFLAGVARIABLE(LuauMathIsNanInfFinite) + uint32_t pcg32_random(uint64_t* state) { uint64_t oldstate = *state; @@ -442,6 +444,30 @@ static int math_lerp(lua_State* L) return 1; } +static int math_isnan(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + + lua_pushboolean(L, isnan(x)); + return 1; +} + +static int math_isinf(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + + lua_pushboolean(L, isinf(x)); + return 1; +} + +static int math_isfinite(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + + lua_pushboolean(L, isfinite(x)); + return 1; +} + static const luaL_Reg mathlib[] = { {"abs", math_abs}, {"acos", math_acos}, @@ -492,6 +518,17 @@ int luaopen_math(lua_State* L) pcg32_seed(&L->global->rngstate, seed); luaL_register(L, LUA_MATHLIBNAME, mathlib); + + if (FFlag::LuauMathIsNanInfFinite) + { + lua_pushcfunction(L, math_isnan, "isnan"); + lua_setfield(L, -2, "isnan"); + lua_pushcfunction(L, math_isinf, "isinf"); + lua_setfield(L, -2, "isinf"); + lua_pushcfunction(L, math_isfinite, "isfinite"); + lua_setfield(L, -2, "isfinite"); + } + lua_pushnumber(L, PI); lua_setfield(L, -2, "pi"); lua_pushnumber(L, HUGE_VAL); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index af1dce13..a9a1486d 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -26,6 +26,7 @@ LUAU_FASTINT(LuauRecursionLimit) LUAU_FASTFLAG(LuauStringConstFolding2) LUAU_FASTFLAG(LuauCompileTypeofFold) LUAU_FASTFLAG(LuauInterpStringConstFolding) +LUAU_FASTFLAG(LuauCompileMathIsNanInfFinite) using namespace Luau; @@ -2670,6 +2671,9 @@ TEST_CASE("RecursionParse") ScopedFastInt flag(FInt::LuauRecursionLimit, 200); #elif defined(_NOOPT) || defined(_DEBUG) ScopedFastInt flag(FInt::LuauRecursionLimit, 300); + // ServerLua: We do RelWithDebinfo MSVC builds and it does not like large stacks. +#elif defined(_MSC_VER) + ScopedFastInt flag(FInt::LuauRecursionLimit, 300); #endif Luau::BytecodeBuilder bcb; @@ -7904,7 +7908,7 @@ RETURN R1 -1 TEST_CASE("BuiltinFolding") { - ScopedFastFlag luauCompileTypeofFold{FFlag::LuauCompileTypeofFold, true}; + ScopedFastFlag _[]{{FFlag::LuauCompileTypeofFold, true}, {FFlag::LuauCompileMathIsNanInfFinite, true}}; CHECK_EQ( "\n" + compileFunction( @@ -7962,7 +7966,13 @@ return math.log(100, 10), typeof(nil), type(vector.create(1, 0, 0)), - (type("fin")) + (type("fin")), + math.isnan(0/0), + math.isnan(0), + math.isinf(math.huge), + math.isinf(-4), + math.isfinite(42), + math.isfinite(-math.huge) )", 0, 2 @@ -8021,7 +8031,13 @@ LOADN R49 2 LOADK R50 K3 ['nil'] LOADK R51 K4 ['vector'] LOADK R52 K5 ['string'] -RETURN R0 53 +LOADB R53 1 +LOADB R54 0 +LOADB R55 1 +LOADB R56 0 +LOADB R57 1 +LOADB R58 0 +RETURN R0 59 )" ); } diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 279a614e..6a297f89 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -35,12 +35,16 @@ void luaC_validate(lua_State* L); void luau_callhook(lua_State* L, lua_Hook hook, void* userdata); LUAU_FASTFLAG(DebugLuauAbortingChecks) +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTFLAG(LuauVectorLerp) LUAU_FASTFLAG(LuauCompileVectorLerp) LUAU_FASTFLAG(LuauTypeCheckerVectorLerp2) LUAU_FASTFLAG(LuauCodeGenVectorLerp2) LUAU_FASTFLAG(LuauStacklessPcall) +LUAU_FASTFLAG(LuauMathIsNanInfFinite) +LUAU_FASTFLAG(LuauCompileMathIsNanInfFinite) +LUAU_FASTFLAG(LuauTypeCheckerMathIsNanInfFinite) static lua_CompileOptions defaultOptions() { @@ -1131,6 +1135,8 @@ TEST_CASE("Buffers") TEST_CASE("Math") { + ScopedFastFlag _[] = {{FFlag::LuauMathIsNanInfFinite, true}, {FFlag::LuauCompileMathIsNanInfFinite, true}}; + runConformance("math.luau"); } @@ -1345,6 +1351,12 @@ TEST_CASE("Pack") runConformance("tpack.luau"); } +TEST_CASE("ExplicitTypeInstantiations") +{ + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + runConformance("explicit_type_instantiations.luau"); +} + int singleYield(lua_State* L) { lua_pushnumber(L, 2); @@ -1740,6 +1752,10 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { + ScopedFastFlag _[] = { + {FFlag::LuauMathIsNanInfFinite, true}, {FFlag::LuauCompileMathIsNanInfFinite, true}, {FFlag::LuauTypeCheckerMathIsNanInfFinite, true} + }; + runConformance( "types.luau", [](lua_State* L) diff --git a/tests/ConformanceIrHooks.h b/tests/ConformanceIrHooks.h index 07c721ba..f3bfca8c 100644 --- a/tests/ConformanceIrHooks.h +++ b/tests/ConformanceIrHooks.h @@ -3,15 +3,17 @@ #include "Luau/IrBuilder.h" -static const char* kUserdataRunTypes[] = {"extra", "color", "vec2", "mat3", nullptr}; +static const char* kUserdataRunTypes[] = {"extra", "color", "vec2", "mat3", "vertex", nullptr}; constexpr uint8_t kUserdataExtra = 0; constexpr uint8_t kUserdataColor = 1; constexpr uint8_t kUserdataVec2 = 2; constexpr uint8_t kUserdataMat3 = 3; +constexpr uint8_t kUserdataVertex = 4; // Userdata tags can be different from userdata bytecode type indices constexpr uint8_t kTagVec2 = 12; +constexpr uint8_t kTagVertex = 13; struct Vec2 { @@ -19,6 +21,13 @@ struct Vec2 float y; }; +struct Vertex +{ + float pos[3]; + float normal[3]; + float uv[2]; +}; + inline bool compareMemberName(const char* member, size_t memberLength, const char* str) { return memberLength == strlen(str) && strcmp(member, str) == 0; @@ -227,6 +236,16 @@ inline uint8_t userdataAccessBytecodeType(uint8_t type, const char* member, size if (compareMemberName(member, memberLength, "Row3")) return LBC_TYPE_VECTOR; break; + case kUserdataVertex: + if (compareMemberName(member, memberLength, "pos")) + return LBC_TYPE_VECTOR; + + if (compareMemberName(member, memberLength, "normal")) + return LBC_TYPE_VECTOR; + + if (compareMemberName(member, memberLength, "uv")) + return userdataIndexToType(kUserdataVec2); + break; } return LBC_TYPE_ANY; @@ -325,6 +344,54 @@ inline bool userdataAccess( break; case kUserdataMat3: break; + case kUserdataVertex: + if (compareMemberName(member, memberLength, "pos")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVertex), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vertex, pos[0])), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vertex, pos[1])), build.constTag(LUA_TUSERDATA)); + IrOp z = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vertex, pos[2])), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(resultReg), x, y, z); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TVECTOR)); + return true; + } + + if (compareMemberName(member, memberLength, "normal")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVertex), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vertex, normal[0])), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vertex, normal[1])), build.constTag(LUA_TUSERDATA)); + IrOp z = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vertex, normal[2])), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(resultReg), x, y, z); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TVECTOR)); + return true; + } + + if (compareMemberName(member, memberLength, "uv")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVertex), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vertex, uv[0])), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vertex, uv[1])), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::CHECK_GC); + IrOp result = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, result, build.constInt(offsetof(Vec2, x)), x, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, result, build.constInt(offsetof(Vec2, y)), y, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), result); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + return true; + } + break; } return false; diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp index 67f54feb..8fdaccb7 100644 --- a/tests/ConstraintSolver.test.cpp +++ b/tests/ConstraintSolver.test.cpp @@ -5,7 +5,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTFLAG(LuauScopedSeenSetInLookupTableProp); using namespace Luau; @@ -64,8 +63,6 @@ TEST_CASE_FIXTURE(ConstraintGeneratorFixture, "proper_let_generalization") TEST_CASE_FIXTURE(ConstraintGeneratorFixture, "table_prop_access_diamond") { - ScopedFastFlag sff(FFlag::LuauScopedSeenSetInLookupTableProp, true); - CheckResult result = check(R"( export type ItemDetails = { Id: number } diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index a88193df..854fcac9 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -1325,6 +1325,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TruthyTestRemoval") %0 = LOAD_TAG R1 CHECK_TAG %0, tnumber, bb_fallback_3 JUMP bb_1 +; glued to: bb_1 bb_1: RETURN 1u @@ -1364,6 +1365,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FalsyTestRemoval") %0 = LOAD_TAG R1 CHECK_TAG %0, tnumber, bb_fallback_3 JUMP bb_2 +; glued to: bb_2 bb_2: RETURN 2u @@ -1399,6 +1401,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagEqRemoval") %0 = LOAD_TAG R1 CHECK_TAG %0, tboolean JUMP bb_2 +; glued to: bb_2 bb_2: RETURN 2u @@ -1430,6 +1433,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval") bb_0: STORE_INT R1, 5i JUMP bb_1 +; glued to: bb_1 bb_1: RETURN 1u @@ -1461,6 +1465,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumCmpRemoval") bb_0: STORE_DOUBLE R1, 4 JUMP bb_2 +; glued to: bb_2 bb_2: RETURN 2u @@ -1489,6 +1494,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataFlowsThroughDirectJumpToUniqueSuccessor bb_0: STORE_TAG R0, tnumber JUMP bb_1 +; glued to: bb_1 bb_1: STORE_TAG R1, tnumber @@ -1558,6 +1564,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "EntryBlockUseRemoval") bb_0: STORE_TAG R0, tnumber JUMP bb_1 +; glued to: bb_1 bb_1: RETURN R0, 0i @@ -1596,6 +1603,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval1") bb_1: STORE_TAG R0, tnumber JUMP bb_2 +; glued to: bb_2 bb_2: RETURN R0, 0i @@ -1634,6 +1642,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: JUMP bb_1 +; glued to: bb_1 bb_1: RETURN R0, 0i @@ -1641,6 +1650,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") bb_2: STORE_TAG R0, tnumber JUMP bb_3 +; glued to: bb_3 bb_3: RETURN R0, 0i @@ -1932,6 +1942,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") %0 = LOAD_TAG R2 CHECK_TAG %0, tnumber, bb_fallback_1 JUMP bb_linear_6 +; glued to: bb_linear_6 bb_fallback_1: DO_LEN R1, R2 @@ -1948,6 +1959,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") bb_4: JUMP bb_5 +; glued to: bb_5 bb_5: RETURN R0, 0i @@ -3769,6 +3781,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagSelfEqualityCheckRemoval") CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: JUMP bb_1 +; glued to: bb_1 bb_1: RETURN 1u @@ -4136,6 +4149,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "JumpImplicitLiveOut") STORE_TAG R1, tnumber STORE_DOUBLE R1, 1 JUMP bb_1 +; glued to: bb_1 bb_1: ; predecessors: bb_0 diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 256ce359..8cddf95a 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -20,7 +20,10 @@ LUAU_FASTFLAG(LuauVectorLerp) LUAU_FASTFLAG(LuauCompileVectorLerp) LUAU_FASTFLAG(LuauTypeCheckerVectorLerp2) LUAU_FASTFLAG(LuauCodeGenVectorLerp2) -LUAU_FASTFLAG(LuauCodeGenFMA) +LUAU_FASTFLAG(LuauCompileUnusedUdataFix) +LUAU_FASTFLAG(LuauCodegenFloatLoadStoreProp) +LUAU_FASTFLAG(LuauCodegenBlockSafeEnv) +LUAU_FASTFLAG(LuauCodegenChainLink) static void luauLibraryConstantLookup(const char* library, const char* member, Luau::CompileConstant* constant) { @@ -165,7 +168,7 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = copts.vectorCtor = "vector"; copts.vectorType = "vector"; - static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", "vertex", nullptr}; copts.userdataTypes = kUserdataCompileTypes; static const char* kLibrariesWithConstants[] = {"vector", "Vector3", nullptr}; @@ -478,17 +481,16 @@ TEST_CASE("VectorLerp") {FFlag::LuauCompileVectorLerp, true}, {FFlag::LuauTypeCheckerVectorLerp2, true}, {FFlag::LuauVectorLerp, true}, - {FFlag::LuauCodeGenVectorLerp2, true} + {FFlag::LuauCodeGenVectorLerp2, true}, + {FFlag::LuauCodegenBlockSafeEnv, true} }; - if (FFlag::LuauCodeGenFMA) - { - CHECK_EQ( - "\n" + getCodegenAssembly(R"( + CHECK_EQ( + "\n" + getCodegenAssembly(R"( local function vec3lerp(a: vector, b: vector, t: number) return vector.lerp(a, b, t) end )"), - R"( + R"( ; function vec3lerp($arg0, $arg1, $arg2) line 2 bb_0: CHECK_TAG R0, tvector, exit(entry) @@ -498,7 +500,7 @@ end bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - CHECK_SAFE_ENV exit(2) + implicit CHECK_SAFE_ENV exit(0) %15 = LOAD_TVALUE R0 %16 = LOAD_TVALUE R1 %17 = LOAD_DOUBLE R2 @@ -512,47 +514,13 @@ end INTERRUPT 8u RETURN R3, 1i )" - ); - } - else - { - CHECK_EQ( - "\n" + getCodegenAssembly(R"( -local function vec3lerp(a: vector, b: vector, t: number) - return vector.lerp(a, b, t) -end -)"), - R"( -; function vec3lerp($arg0, $arg1, $arg2) line 2 -bb_0: - CHECK_TAG R0, tvector, exit(entry) - CHECK_TAG R1, tvector, exit(entry) - CHECK_TAG R2, tnumber, exit(entry) - JUMP bb_2 -bb_2: - JUMP bb_bytecode_1 -bb_bytecode_1: - CHECK_SAFE_ENV exit(2) - %15 = LOAD_TVALUE R0 - %16 = LOAD_TVALUE R1 - %17 = LOAD_DOUBLE R2 - %18 = NUM_TO_VEC %17 - %19 = NUM_TO_VEC 1 - %20 = SUB_VEC %16, %15 - %21 = MUL_VEC %20, %18 - %22 = ADD_VEC %15, %21 - SELECT_VEC %22, %16, %18, %19 - %24 = TAG_VECTOR %23 - STORE_TVALUE R3, %24 - INTERRUPT 8u - RETURN R3, 1i -)" - ); - } + ); } TEST_CASE("ExtraMathMemoryOperands") { + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + CHECK_EQ( "\n" + getCodegenAssembly(R"( local function foo(a: number, b: number, c: number, d: number, e: number) @@ -571,7 +539,7 @@ end bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - CHECK_SAFE_ENV exit(1) + implicit CHECK_SAFE_ENV exit(0) %16 = FLOOR_NUM R0 %23 = CEIL_NUM R1 %32 = ADD_NUM %16, %23 @@ -625,6 +593,8 @@ end TEST_CASE("DseInitialStackState2") { + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + CHECK_EQ( "\n" + getCodegenAssembly(R"( local function foo(a) @@ -635,7 +605,7 @@ end R"( ; function foo($arg0) line 2 bb_bytecode_0: - CHECK_SAFE_ENV exit(1) + implicit CHECK_SAFE_ENV exit(0) CHECK_TAG R0, tnumber, exit(1) FASTCALL 14u, R1, R0, 2i INTERRUPT 5u @@ -729,6 +699,8 @@ end TEST_CASE("BooleanCompare") { + ScopedFastFlag luauCodegenChainLink{FFlag::LuauCodegenChainLink, true}; + CHECK_EQ( "\n" + getCodegenAssembly( R"( @@ -752,23 +724,17 @@ end STORE_INT R2, %7 JUMP bb_bytecode_2 bb_bytecode_2: - %14 = LOAD_TAG R0 - %15 = LOAD_INT R0 - %16 = CMP_SPLIT_TVALUE %14, tboolean, %15, 0i, eq + %16 = CMP_SPLIT_TVALUE %5, tboolean, %6, 0i, eq STORE_TAG R3, tboolean STORE_INT R3, %16 JUMP bb_bytecode_4 bb_bytecode_4: - %23 = LOAD_TAG R0 - %24 = LOAD_INT R0 - %25 = CMP_SPLIT_TVALUE %23, tboolean, %24, 1i, not_eq + %25 = CMP_SPLIT_TVALUE %5, tboolean, %6, 1i, not_eq STORE_TAG R4, tboolean STORE_INT R4, %25 JUMP bb_bytecode_6 bb_bytecode_6: - %32 = LOAD_TAG R0 - %33 = LOAD_INT R0 - %34 = CMP_SPLIT_TVALUE %32, tboolean, %33, 0i, not_eq + %34 = CMP_SPLIT_TVALUE %5, tboolean, %6, 0i, not_eq STORE_TAG R5, tboolean STORE_INT R5, %34 JUMP bb_bytecode_8 @@ -782,6 +748,8 @@ end TEST_CASE("NumberCompare") { + ScopedFastFlag luauCodegenChainLink{FFlag::LuauCodegenChainLink, true}; + CHECK_EQ( "\n" + getCodegenAssembly( R"( @@ -805,9 +773,7 @@ end STORE_INT R2, %7 JUMP bb_bytecode_2 bb_bytecode_2: - %14 = LOAD_TAG R0 - %15 = LOAD_DOUBLE R0 - %16 = CMP_SPLIT_TVALUE %14, tnumber, %15, 3, not_eq + %16 = CMP_SPLIT_TVALUE %5, tnumber, %6, 3, not_eq STORE_TAG R3, tboolean STORE_INT R3, %16 JUMP bb_bytecode_4 @@ -821,6 +787,8 @@ end TEST_CASE("TypeCompare") { + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + CHECK_EQ( "\n" + getCodegenAssembly( R"( @@ -832,7 +800,7 @@ end R"( ; function foo($arg0) line 2 bb_bytecode_0: - CHECK_SAFE_ENV exit(1) + implicit CHECK_SAFE_ENV exit(0) %1 = LOAD_TAG R0 %2 = GET_TYPE %1 STORE_POINTER R2, %2 @@ -850,6 +818,8 @@ end TEST_CASE("TypeofCompare") { + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + CHECK_EQ( "\n" + getCodegenAssembly( R"( @@ -861,7 +831,7 @@ end R"( ; function foo($arg0) line 2 bb_bytecode_0: - CHECK_SAFE_ENV exit(1) + implicit CHECK_SAFE_ENV exit(0) %1 = GET_TYPEOF R0 STORE_POINTER R2, %1 STORE_TAG R2, tstring @@ -878,6 +848,8 @@ end TEST_CASE("TypeofCompareCustom") { + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + CHECK_EQ( "\n" + getCodegenAssembly( R"( @@ -889,7 +861,7 @@ end R"( ; function foo($arg0) line 2 bb_bytecode_0: - CHECK_SAFE_ENV exit(1) + implicit CHECK_SAFE_ENV exit(0) %1 = GET_TYPEOF R0 STORE_POINTER R2, %1 STORE_TAG R2, tstring @@ -907,6 +879,11 @@ end TEST_CASE("TypeCondition") { + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + ScopedFastFlag luauCodegenChainLink{FFlag::LuauCodegenChainLink, true}; + + // TODO: opportunity 1 - first store to R2 is dead, but dead store op doesn't go through glued chains yet + // TODO: opportunity 2 - bb_4 already made sure %1 == R0.tag is a number, check in bb_3 can be removed CHECK_EQ( "\n" + getCodegenAssembly( R"( @@ -921,16 +898,14 @@ end R"( ; function foo($arg0, $arg1) line 2 bb_bytecode_0: - CHECK_SAFE_ENV exit(1) + implicit CHECK_SAFE_ENV exit(0) %1 = LOAD_TAG R0 %2 = GET_TYPE %1 STORE_POINTER R2, %2 STORE_TAG R2, tstring JUMP bb_4 bb_4: - %7 = LOAD_POINTER R2 - %8 = LOAD_POINTER K2 ('number') - JUMP_EQ_POINTER %7, %8, bb_3, bb_bytecode_1 + JUMP_EQ_TAG %1, tnumber, bb_3, bb_bytecode_1 bb_3: CHECK_TAG R0, tnumber, bb_fallback_5 CHECK_TAG R1, tnumber, bb_fallback_5 @@ -950,8 +925,67 @@ end ); } +TEST_CASE("TypeCondition2") +{ + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + ScopedFastFlag luauCodegenChainLink{FFlag::LuauCodegenChainLink, true}; + + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a, b) + if type(a) == "number" and type(b) == "number" then + return a + b + end + return nil +end +)" + ), + R"( +; function foo($arg0, $arg1) line 2 +bb_bytecode_0: + implicit CHECK_SAFE_ENV exit(0) + %1 = LOAD_TAG R0 + %2 = GET_TYPE %1 + STORE_POINTER R2, %2 + STORE_TAG R2, tstring + JUMP bb_4 +bb_4: + JUMP_EQ_TAG %1, tnumber, bb_3, bb_bytecode_1 +bb_3: + CHECK_SAFE_ENV exit(8) + %11 = LOAD_TAG R1 + %12 = GET_TYPE %11 + STORE_POINTER R2, %12 + STORE_TAG R2, tstring + JUMP bb_7 +bb_7: + JUMP_EQ_TAG %11, tnumber, bb_6, bb_bytecode_1 +bb_6: + CHECK_TAG R0, tnumber, bb_fallback_8 + CHECK_TAG R1, tnumber, bb_fallback_8 + %24 = LOAD_DOUBLE R0 + %26 = ADD_NUM %24, R1 + STORE_DOUBLE R2, %26 + STORE_TAG R2, tnumber + JUMP bb_9 +bb_9: + INTERRUPT 15u + RETURN R2, 1i +bb_bytecode_1: + STORE_TAG R2, tnil + INTERRUPT 17u + RETURN R2, 1i +)" + ); +} + TEST_CASE("AssertTypeGuard") { + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + ScopedFastFlag luauCodegenChainLink{FFlag::LuauCodegenChainLink, true}; + + // TODO: opportunity - CHECK_TRUTHY indirectly establishes that %1 is a number for CHECK_TAG in bb_5 CHECK_EQ( "\n" + getCodegenAssembly( R"( @@ -964,7 +998,7 @@ end R"( ; function foo($arg0) line 2 bb_bytecode_0: - CHECK_SAFE_ENV exit(1) + implicit CHECK_SAFE_ENV exit(0) %1 = LOAD_TAG R0 %2 = GET_TYPE %1 STORE_POINTER R3, %2 @@ -974,10 +1008,10 @@ end STORE_INT R2, %8 JUMP bb_bytecode_2 bb_bytecode_2: - CHECK_TRUTHY tboolean, R2, exit(10) + CHECK_TRUTHY tboolean, %8, exit(10) JUMP bb_5 bb_5: - CHECK_TAG R0, tnumber, bb_fallback_6 + CHECK_TAG %1, tnumber, bb_fallback_6 %28 = LOAD_DOUBLE R0 %29 = ADD_NUM %28, %28 STORE_DOUBLE R1, %29 @@ -1490,6 +1524,8 @@ end #if LUA_VECTOR_SIZE == 3 TEST_CASE("FastcallTypeInferThroughLocal") { + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + CHECK_EQ( "\n" + getCodegenAssembly( R"( @@ -1508,11 +1544,11 @@ end ; function getsum($arg0, $arg1) line 2 ; R2: vector from 0 to 18 bb_bytecode_0: + implicit CHECK_SAFE_ENV exit(0) STORE_DOUBLE R4, 2 STORE_TAG R4, tnumber STORE_DOUBLE R5, 3 STORE_TAG R5, tnumber - CHECK_SAFE_ENV exit(4) CHECK_TAG R0, tnumber, exit(4) %11 = LOAD_DOUBLE R0 STORE_VECTOR R2, %11, 2, 3 @@ -1540,6 +1576,8 @@ end TEST_CASE("FastcallTypeInferThroughUpvalue") { + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + CHECK_EQ( "\n" + getCodegenAssembly( R"( @@ -1560,11 +1598,11 @@ end ; function getsum($arg0, $arg1) line 4 ; U0: vector bb_bytecode_0: + implicit CHECK_SAFE_ENV exit(0) STORE_DOUBLE R4, 2 STORE_TAG R4, tnumber STORE_DOUBLE R5, 3 STORE_TAG R5, tnumber - CHECK_SAFE_ENV exit(4) CHECK_TAG R0, tnumber, exit(4) %11 = LOAD_DOUBLE R0 STORE_VECTOR R2, %11, 2, 3 @@ -1670,6 +1708,8 @@ end #if LUA_VECTOR_SIZE == 3 TEST_CASE("ArgumentTypeRefinement") { + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + CHECK_EQ( "\n" + getCodegenAssembly( R"( @@ -1684,11 +1724,11 @@ end ; function getsum($arg0, $arg1) line 2 ; R0: vector [argument] bb_bytecode_0: + implicit CHECK_SAFE_ENV exit(0) STORE_DOUBLE R3, 1 STORE_TAG R3, tnumber STORE_DOUBLE R5, 3 STORE_TAG R5, tnumber - CHECK_SAFE_ENV exit(4) CHECK_TAG R1, tnumber, exit(4) %12 = LOAD_DOUBLE R1 STORE_VECTOR R2, 1, %12, 3 @@ -1709,6 +1749,8 @@ end TEST_CASE("InlineFunctionType") { + ScopedFastFlag luauCodegenFloatLoadStoreProp{FFlag::LuauCodegenFloatLoadStoreProp, true}; + CHECK_EQ( "\n" + getCodegenAssembly( R"( @@ -1746,8 +1788,7 @@ end CHECK_TAG R0, tvector, exit(0) %2 = LOAD_FLOAT R0, 4i %8 = MUL_NUM %2, 3 - %13 = LOAD_FLOAT R0, 4i - %19 = MUL_NUM %13, 5 + %19 = MUL_NUM %2, 5 %28 = ADD_NUM %8, %19 STORE_DOUBLE R1, %28 STORE_TAG R1, tnumber @@ -2002,6 +2043,8 @@ end TEST_CASE("ForInManualAnnotation") { + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + CHECK_EQ( "\n" + getCodegenAssembly( R"( @@ -2031,9 +2074,9 @@ end bb_4: JUMP bb_bytecode_1 bb_bytecode_1: + implicit CHECK_SAFE_ENV exit(0) STORE_DOUBLE R1, 0 STORE_TAG R1, tnumber - CHECK_SAFE_ENV exit(1) GET_CACHED_IMPORT R2, K1 (nil), 1073741824u ('ipairs'), 2u %8 = LOAD_TVALUE R0 STORE_TVALUE R3, %8 @@ -2539,6 +2582,54 @@ end ); } +TEST_CASE("CustomUserdataMapping") +{ + ScopedFastFlag luauCompileUnusedUdataFix{FFlag::LuauCompileUnusedUdataFix, true}; + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: mat3) + print(a, vec2.create(0, 0)) +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0) line 2 +; R0: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + implicit CHECK_SAFE_ENV exit(0) + GET_CACHED_IMPORT R1, K1 (nil), 1073741824u ('print'), 1u + %6 = LOAD_TVALUE R0 + STORE_TVALUE R2, %6 + GET_CACHED_IMPORT R3, K4 (nil), 2149583872u ('vec2'.'create'), 4u + STORE_DOUBLE R4, 0 + STORE_TAG R4, tnumber + STORE_DOUBLE R5, 0 + STORE_TAG R5, tnumber + INTERRUPT 7u + SET_SAVEDPC 8u + CALL R3, 2i, -1i + INTERRUPT 8u + SET_SAVEDPC 9u + CALL R1, -1i, 0i + INTERRUPT 9u + RETURN R0, 0i +)" + ); +} + TEST_CASE("LibraryFieldTypesAndConstants") { CHECK_EQ( @@ -2640,6 +2731,8 @@ end TEST_CASE("Bit32BtestDirect") { + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + CHECK_EQ( "\n" + getCodegenAssembly(R"( local function foo(a: number) @@ -2654,7 +2747,7 @@ end bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - CHECK_SAFE_ENV exit(2) + implicit CHECK_SAFE_ENV exit(0) %7 = LOAD_DOUBLE R0 %8 = NUM_TO_UINT %7 %10 = BITAND_UINT %8, 31i @@ -2667,4 +2760,227 @@ end ); } +TEST_CASE("VectorLoadReuse") +{ + ScopedFastFlag luauCodegenFloatLoadStoreProp{FFlag::LuauCodegenFloatLoadStoreProp, true}; + + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function shuffle(v: vector) + return v.x * v.x + v.y * v.y +end +)"), + R"( +; function shuffle($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_FLOAT R0, 0i + %20 = MUL_NUM %6, %6 + %25 = LOAD_FLOAT R0, 4i + %39 = MUL_NUM %25, %25 + %48 = ADD_NUM %20, %39 + STORE_DOUBLE R1, %48 + STORE_TAG R1, tnumber + INTERRUPT 11u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("VectorShuffle1") +{ + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + + // TODO: opportunity - if we introduce a separate vector shuffle instruction, this can be done in a single shuffle (+/- load and store) + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function shuffle(v: vector) + return vector.create(v.z, v.x, v.y) +end +)"), + R"( +; function shuffle($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + implicit CHECK_SAFE_ENV exit(0) + %6 = LOAD_FLOAT R0, 8i + %11 = LOAD_FLOAT R0, 0i + %16 = LOAD_FLOAT R0, 4i + STORE_VECTOR R1, %6, %11, %16 + STORE_TAG R1, tvector + INTERRUPT 10u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("VectorShuffle2") +{ + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + ScopedFastFlag luauCodegenFloatLoadStoreProp{FFlag::LuauCodegenFloatLoadStoreProp, true}; + + // TODO: opportunity - LOAD_FLOAT performs float->double conversion and STORE_VECTOR immediately performs double->float which should be skipped + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function crossshuffle(v: vector, t: vector) + local tmp1 = vector.create(v.x, v.x, v.z) + local tmp2 = vector.create(t.y, t.z, t.x) + return vector.create(tmp1.z, tmp2.x, tmp1.y) +end +)"), + R"( +; function crossshuffle($arg0, $arg1) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + implicit CHECK_SAFE_ENV exit(0) + %8 = LOAD_FLOAT R0, 0i + %18 = LOAD_FLOAT R0, 8i + %35 = LOAD_FLOAT R1, 4i + STORE_VECTOR R4, %18, %35, %8, tvector + INTERRUPT 30u + RETURN R4, 1i +)" + ); +} + +TEST_CASE("VectorShuffleFromComposite1") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag luauCodegenFloatLoadStoreProp{FFlag::LuauCodegenFloatLoadStoreProp, true}; + + // TODO: opportunity - buffer memory load-store propagation can remove duplicate loads + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function test(v: vertex) + return v.normal.X * v.normal.X + v.normal.Y * v.normal.Y +end +)"), + R"( +; function test($arg0) line 2 +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %6, 13i, exit(0) + %8 = BUFFER_READF32 %6, 12i, tuserdata + %22 = BUFFER_READF32 %6, 12i, tuserdata + %38 = MUL_NUM %8, %22 + %46 = BUFFER_READF32 %6, 16i, tuserdata + %60 = BUFFER_READF32 %6, 16i, tuserdata + %75 = MUL_NUM %46, %60 + %84 = ADD_NUM %38, %75 + STORE_DOUBLE R1, %84 + STORE_TAG R1, tnumber + INTERRUPT 19u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("VectorShuffleFromComposite2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + // TODO: opportunity - userdata memory load-store propagation can remove loads from values that have just been stored + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function test(v: vertex) + return v.uv.X * v.uv.Y +end +)"), + R"( +; function test($arg0) line 2 +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %6, 13i, exit(0) + %8 = BUFFER_READF32 %6, 24i, tuserdata + %9 = BUFFER_READF32 %6, 28i, tuserdata + CHECK_GC + %11 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %11, 0i, %8, tuserdata + BUFFER_WRITEF32 %11, 4i, %9, tuserdata + %20 = BUFFER_READF32 %11, 0i, tuserdata + %27 = BUFFER_READF32 %6, 24i, tuserdata + %28 = BUFFER_READF32 %6, 28i, tuserdata + %30 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %30, 0i, %27, tuserdata + BUFFER_WRITEF32 %30, 4i, %28, tuserdata + STORE_POINTER R4, %30 + STORE_TAG R4, tuserdata + %39 = BUFFER_READF32 %30, 4i, tuserdata + %48 = MUL_NUM %20, %39 + STORE_DOUBLE R1, %48 + STORE_TAG R1, tnumber + INTERRUPT 9u + RETURN R1, 1i +)" + ); +} + +// Vectors use float storage, so in some cases it is impossible to forward value that was passed to the constructor as it was double->float truncated +TEST_CASE("VectorLoadStoreOnlySamePrecision") +{ + ScopedFastFlag luauCodegenBlockSafeEnv{FFlag::LuauCodegenBlockSafeEnv, true}; + + // TODO: opportunity - LOAD_FLOAT can be replaced with a new NUM_TO_FLOAT instruction which will only handle the truncation + CHECK_EQ( + "\n" + getCodegenAssembly(R"( +local function test(x: number, y: number) + local vec = vector.create(x, y, 0) + return vec.X + vec.Y + vec.Z +end +)"), + R"( +; function test($arg0, $arg1) line 2 +bb_0: + CHECK_TAG R0, tnumber, exit(entry) + CHECK_TAG R1, tnumber, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + implicit CHECK_SAFE_ENV exit(0) + %15 = LOAD_DOUBLE R0 + %16 = LOAD_DOUBLE R1 + STORE_VECTOR R2, %15, %16, 0 + STORE_TAG R2, tvector + %22 = LOAD_FLOAT R2, 0i + %27 = LOAD_FLOAT R2, 4i + %36 = ADD_NUM %22, %27 + %41 = LOAD_FLOAT R2, 8i + %50 = ADD_NUM %36, %41 + STORE_DOUBLE R3, %50 + STORE_TAG R3, tnumber + INTERRUPT 16u + RETURN R3, 1i +)" + ); +} + TEST_SUITE_END(); diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp index 5e45c801..9e050597 100644 --- a/tests/NonStrictTypeChecker.test.cpp +++ b/tests/NonStrictTypeChecker.test.cpp @@ -15,8 +15,13 @@ #include "doctest.h" #include +LUAU_DYNAMIC_FASTINT(LuauConstraintGeneratorRecursionLimit) + +LUAU_FASTINT(LuauNonStrictTypeCheckerRecursionLimit) +LUAU_FASTINT(LuauCheckRecursionLimit) LUAU_FASTFLAG(LuauUnreducedTypeFunctionsDontTriggerWarnings) LUAU_FASTFLAG(LuauNewNonStrictBetterCheckedFunctionErrorMessage) +LUAU_FASTFLAG(LuauAddRecursionCounterToNonStrictTypeChecker) using namespace Luau; @@ -853,4 +858,38 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_check_block_recursion_limit") +{ + int limit = 250; + + ScopedFastFlag sff{FFlag::LuauAddRecursionCounterToNonStrictTypeChecker, true}; + + ScopedFastInt luauNonStrictTypeCheckerRecursionLimit{FInt::LuauNonStrictTypeCheckerRecursionLimit, limit - 100}; + ScopedFastInt luauConstraintGeneratorRecursionLimit{DFInt::LuauConstraintGeneratorRecursionLimit, limit + 500}; + ScopedFastInt luauCheckRecursionLimit{FInt::LuauCheckRecursionLimit, limit + 500}; + + CheckResult result = checkNonStrict(rep("do ", limit) + "local a = 1" + rep(" end", limit)); + + // Nonstrict recursion limit just exits early and doesn't produce an error + LUAU_REQUIRE_NO_ERRORS(result); +} + +#if 0 // CLI-181303 requires a ConstraintGenerator::checkPack fix to succeed in debug on Windows +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_check_expr_recursion_limit") +{ + int limit = 250; + + ScopedFastFlag sff{FFlag::LuauAddRecursionCounterToNonStrictTypeChecker, true}; + + ScopedFastInt luauNonStrictTypeCheckerRecursionLimit{FInt::LuauNonStrictTypeCheckerRecursionLimit, limit - 100}; + ScopedFastInt luauConstraintGeneratorRecursionLimit{DFInt::LuauConstraintGeneratorRecursionLimit, limit + 500}; + ScopedFastInt luauCheckRecursionLimit{FInt::LuauCheckRecursionLimit, limit + 500}; + + CheckResult result = checkNonStrict(R"(("foo"))" + rep(":lower()", limit)); + + // Nonstrict recursion limit just exits early and doesn't produce an error + LUAU_REQUIRE_NO_ERRORS(result); +} +#endif + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index e275a5ab..3b69548d 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -18,6 +18,7 @@ LUAU_FASTINT(LuauTypeLengthLimit) LUAU_FASTINT(LuauParseErrorLimit) LUAU_FASTFLAG(LuauSolverV2) LUAU_DYNAMIC_FASTFLAG(DebugLuauReportReturnTypeVariadicWithTypeSuffix) +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) // Clip with DebugLuauReportReturnTypeVariadicWithTypeSuffix extern bool luau_telemetry_parsed_return_type_variadic_with_type_suffix; @@ -40,6 +41,39 @@ struct Counter int Counter::instanceCount = 0; +std::string_view stringAtLocation(std::string_view source, const Location& location) +{ + std::vector lines = Luau::split(source, '\n'); + LUAU_ASSERT(lines.size() > location.begin.line && lines.size() > location.end.line); + + int byteStart = -1; + int byteEnd = -1; + int bytesSum = 0; + + for (size_t lineNo = 0; lineNo < lines.size(); ++lineNo) + { + std::string_view line = lines.at(lineNo); + + if (lineNo == location.begin.line) + { + byteStart = bytesSum + location.begin.column; + } + + if (lineNo == location.end.line) + { + byteEnd = bytesSum + location.end.column; + break; + } + + bytesSum += static_cast(line.size()) + 1; + } + + LUAU_ASSERT(byteStart != -1); + LUAU_ASSERT(byteEnd != -1); + + return source.substr(byteStart, byteEnd - byteStart); +} + } // namespace TEST_SUITE_BEGIN("AllocatorTests"); @@ -108,9 +142,11 @@ TEST_CASE_FIXTURE(Fixture, "can_haz_annotations") TEST_CASE_FIXTURE(Fixture, "local_with_annotation") { - AstStatBlock* block = parse(R"( + std::string code = R"( local foo: string = "Hello Types!" - )"); + )"; + + AstStatBlock* block = parse(code); REQUIRE(block != nullptr); @@ -125,6 +161,8 @@ TEST_CASE_FIXTURE(Fixture, "local_with_annotation") REQUIRE(l->annotation != nullptr); REQUIRE_EQ(1, local->values.size); + + REQUIRE_EQ(stringAtLocation(code, l->location), "foo"); } TEST_CASE_FIXTURE(Fixture, "type_names_can_contain_dots") @@ -2807,6 +2845,86 @@ TEST_CASE_FIXTURE(Fixture, "for_loop_with_single_var_has_comma_positions_of_size CHECK_EQ(cstNode->varsCommaPositions.size, 0); } +TEST_CASE_FIXTURE(Fixture, "explicit_type_instantiation_expression_call") +{ + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + std::string source = "local x = f<>()"; + + ParseResult result = parseEx(source); + REQUIRE(result.root); + + AstStatLocal* local = result.root->body.data[0]->as(); + REQUIRE(local != nullptr); + + REQUIRE_EQ(1, local->vars.size); + + AstExpr* expr = local->values.data[0]; + REQUIRE(expr != nullptr); + + AstExprCall* call = expr->as(); + REQUIRE(call != nullptr); + + AstExprInstantiate* explicitTypeInstantiation = call->func->as(); + REQUIRE(explicitTypeInstantiation != nullptr); + + REQUIRE_EQ(stringAtLocation(source, explicitTypeInstantiation->location), "f<>"); +} + +TEST_CASE_FIXTURE(Fixture, "explicit_type_instantiation_expression") +{ + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + AstStat* stat = parse("local x = f<>"); + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "explicit_type_instantiation_statement") +{ + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + AstStat* stat = parse("f<>()"); + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "explicit_type_instantiation_indexing") +{ + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + AstStat* stat = parse(R"( + t.f<>() + t:f<>() + t["f"]<>() + )"); + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "explicit_type_instantiation_empty_list") +{ + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + AstStat* stat = parse(R"( + f<<>>() + )"); + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "two_left_and_right_arrows_but_no_explicit_type_instantiation") +{ + AstStat* stat = parse(R"( + type A = C() -> T>> + )"); + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "basic_less_than_check_no_explicit_type_instantiaton") +{ + AstStat* stat = parse(R"( + local a = b.c < d + )"); + REQUIRE(stat != nullptr); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); diff --git a/tests/PrettyPrinter.test.cpp b/tests/PrettyPrinter.test.cpp index ba842386..59d2c7dd 100644 --- a/tests/PrettyPrinter.test.cpp +++ b/tests/PrettyPrinter.test.cpp @@ -10,6 +10,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) + using namespace Luau; TEST_SUITE_BEGIN("PrettyPrinterTests"); @@ -2127,4 +2129,27 @@ TEST_CASE("prettyPrint_function_attributes") CHECK_EQ(code, prettyPrint(code, {}, true).code); } +TEST_CASE("transpile_explicit_type_instantiations") +{ + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + std::string code = "f<>() t.f<>() t:f<>()"; + CHECK_EQ(code, prettyPrint(code, {}, true).code); + + Allocator allocator; + AstNameTable names = {allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + REQUIRE(parseResult.errors.empty()); + CHECK_EQ(code, prettyPrintWithTypes(*parseResult.root)); + + // No types + CHECK_EQ( + "f () t.f () t:f ()", + prettyPrint(code).code + ); + + code = "f < < A , B , C... > >( ) t.f < < A, B, C... > > ( ) t:f< < A, B, C > > ( )"; + CHECK_EQ(code, prettyPrint(code, {}, true).code); +} + TEST_SUITE_END(); diff --git a/tests/RequireByString.test.cpp b/tests/RequireByString.test.cpp index be1bade7..891a8d43 100644 --- a/tests/RequireByString.test.cpp +++ b/tests/RequireByString.test.cpp @@ -912,4 +912,20 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireChainedAliasesFailureMissing") } } +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireChainedAliasesFailureDependOnInnerAlias") +{ + { + std::string path = getLuauDirectory(PathType::Relative) + + "/tests/require/config_tests/with_config/chained_aliases/subdirectory/failing_requirer_inner_dependency"; + runProtectedRequire(path); + assertOutputContainsAll({"false", "error requiring module \"@dependoninner\": @passthroughinner is not a valid alias"}); + } + { + std::string path = getLuauDirectory(PathType::Relative) + + "/tests/require/config_tests/with_config_luau/chained_aliases/subdirectory/failing_requirer_inner_dependency"; + runProtectedRequire(path); + assertOutputContainsAll({"false", "error requiring module \"@dependoninner\": @passthroughinner is not a valid alias"}); + } +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index f925afef..626ad846 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -13,7 +13,6 @@ LUAU_FASTFLAG(LuauIntersectNotNil) LUAU_FASTFLAG(DebugLuauAssertOnForcedConstraint) LUAU_FASTFLAG(LuauSubtypingReportGenericBoundMismatches2) LUAU_FASTFLAG(DebugLuauStringSingletonBasedOnQuotes) -LUAU_FASTFLAG(LuauSubtypingUnionsAndIntersectionsInGenericBounds) LUAU_FASTFLAG(LuauUseTopTableForTableClearAndIsFrozen) LUAU_FASTFLAG(LuauEGFixGenericsList) LUAU_FASTFLAG(LuauIncludeExplicitGenericPacks) @@ -2047,11 +2046,8 @@ xpcall(v, print, x) TEST_CASE_FIXTURE(Fixture, "array_of_singletons_should_subtype_against_generic_array") { ScopedFastFlag _[] = { - // These flags expose the issue {FFlag::LuauSubtypingReportGenericBoundMismatches2, true}, {FFlag::DebugLuauStringSingletonBasedOnQuotes, true}, - // And this flag fixes it - {FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds, true} }; CheckResult res = check(R"( local function a(arr: { T }) end @@ -2064,10 +2060,7 @@ TEST_CASE_FIXTURE(Fixture, "array_of_singletons_should_subtype_against_generic_a TEST_CASE_FIXTURE(BuiltinsFixture, "gh1985_array_of_union_for_generic") { - ScopedFastFlag _[] = { - {FFlag::LuauSubtypingReportGenericBoundMismatches2, true}, - {FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds, true} - }; + ScopedFastFlag sff{FFlag::LuauSubtypingReportGenericBoundMismatches2, true}; CheckResult res = check(R"( local function clear(arr: { T }) table.clear(arr) end @@ -2082,10 +2075,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gh1985_array_of_union_for_generic") TEST_CASE_FIXTURE(BuiltinsFixture, "gh1985_array_of_union_for_generic_2") { - ScopedFastFlag _[] = { - {FFlag::LuauSubtypingReportGenericBoundMismatches2, true}, - {FFlag::LuauSubtypingUnionsAndIntersectionsInGenericBounds, true} - }; + ScopedFastFlag sff{FFlag::LuauSubtypingReportGenericBoundMismatches2, true}; CheckResult res = check(R"( local function id(arr: { T }): { T } return arr end diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index f22812ab..a9185293 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -11,7 +11,7 @@ LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(DebugLuauEqSatSimplification) LUAU_FASTFLAG(LuauFunctionCallsAreNotNilable) -LUAU_FASTFLAG(LuauRefineNoRefineAlways) +LUAU_FASTFLAG(LuauRefineNoRefineAlways2) LUAU_FASTFLAG(LuauRefineDistributesOverUnions) LUAU_FASTFLAG(LuauSubtypingReportGenericBoundMismatches2) LUAU_FASTFLAG(LuauNoMoreComparisonTypeFunctions) @@ -22,6 +22,7 @@ LUAU_FASTFLAG(LuauAddRefinementToAssertions) LUAU_FASTFLAG(LuauEnqueueUnionsOfDistributedTypeFunctions) LUAU_FASTFLAG(DebugLuauAssertOnForcedConstraint) LUAU_FASTFLAG(LuauNormalizationPreservesAny) +LUAU_FASTFLAG(LuauRefineNoRefineAlways2) using namespace Luau; @@ -2853,7 +2854,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_by_no_refine_should_always_reduce") // generalization. ScopedFastFlag sffs[] = { {FFlag::LuauSolverV2, true}, - {FFlag::LuauRefineNoRefineAlways, true}, + {FFlag::LuauRefineNoRefineAlways2, true}, }; CheckResult result = check(R"( @@ -3205,4 +3206,32 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_any_and_unknown_should_still_be_any") )")); } +TEST_CASE_FIXTURE(BuiltinsFixture, "cli_181100_fast_track_refinement_against_unknown") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauRefineNoRefineAlways2, true}, + {FFlag::DebugLuauAssertOnForcedConstraint, true}, + }; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + --!strict + + local Class = {} + Class.__index = Class + + type Class = setmetatable<{ A: number }, typeof(Class)> + + function Class.Foo(x: Class, y: Class, z: Class) + if y == z then + return + end + local bar = y.A + print(bar) + end + )")); + + CHECK_EQ("number", toString(requireTypeAtPosition({13, 19}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index e6dd05e7..6e665c48 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -35,6 +35,7 @@ LUAU_FASTFLAG(LuauNewOverloadResolver) LUAU_FASTFLAG(LuauGetmetatableError) LUAU_FASTFLAG(LuauSuppressIndexingIntoError) LUAU_FASTFLAG(LuauPushTypeConstriantAlwaysCompletes) +LUAU_FASTFLAG(LuauMarkUnscopedGenericsAsSolved) TEST_SUITE_BEGIN("TableTests"); @@ -6458,12 +6459,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "oss_1684") CHECK(get(e)); } -TEST_CASE_FIXTURE(BuiltinsFixture, "push_type_constraint_should_always_complete") +TEST_CASE_FIXTURE(BuiltinsFixture, "oss_2094_push_type_constraint_should_always_complete") { ScopedFastFlag sffs[] = { {FFlag::LuauSolverV2, true}, {FFlag::LuauPushTypeConstraint2, true}, {FFlag::LuauPushTypeConstriantAlwaysCompletes, true}, + {FFlag::LuauMarkUnscopedGenericsAsSolved, true}, + {FFlag::DebugLuauAssertOnForcedConstraint, true}, }; LUAU_REQUIRE_NO_ERRORS(check(R"( @@ -6481,6 +6484,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "push_type_constraint_should_always_complete" } end )")); + } TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 6b06132f..7242e35e 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -29,7 +29,6 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauDfgAllowUpdatesInLoops) LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(LuauMissingFollowMappedGenericPacks) -LUAU_FASTFLAG(LuauOccursCheckInCommit) LUAU_FASTFLAG(LuauEGFixGenericsList) LUAU_FASTFLAG(LuauTryToOptimizeSetTypeUnification) LUAU_FASTFLAG(LuauDontReferenceScopePtrFromHashTable) @@ -2571,8 +2570,7 @@ do end #if 0 // CLI-166473: re-enable after flakiness is resolved TEST_CASE_FIXTURE(Fixture, "txnlog_checks_for_occurrence_before_self_binding_a_type") { - ScopedFastFlag sff[] = {{FFlag::LuauSolverV2, false}, {FFlag::LuauOccursCheckInCommit, true}}; - + ScopedFastFlag sff[] = {{FFlag::LuauSolverV2, false}}; CheckResult result = check(R"( local any = nil :: any diff --git a/tests/TypeInfer.typeInstantiations.test.cpp b/tests/TypeInfer.typeInstantiations.test.cpp new file mode 100644 index 00000000..96b062e4 --- /dev/null +++ b/tests/TypeInfer.typeInstantiations.test.cpp @@ -0,0 +1,519 @@ +#include "Fixture.h" + +#include "ScopedFlags.h" +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauExplicitTypeExpressionInstantiation) + +TEST_SUITE_BEGIN("TypeInferExplicitTypeInstantiations"); + +#define SUBCASE_BOTH_SOLVERS() \ + for (bool enabled : {true, false}) \ + if (ScopedFastFlag sffSolver{FFlag::LuauSolverV2, enabled}; true) \ + SUBCASE(enabled ? "New solver" : "Old solver") + +TEST_CASE_FIXTURE(Fixture, "as_expression_correct") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local function f(): T + return nil :: any + end + + local correct = f<>() + 5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + } +} + +TEST_CASE_FIXTURE(Fixture, "as_expression_incorrect") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local function f(): T + return nil :: any + end + + local incorrect = f<>() + 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauSolverV2) + { + REQUIRE_EQ(toString(result.errors[0]), "Operator '+' could not be applied to operands of types string and number; there is no corresponding overload for __add"); + } + else + { + REQUIRE_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); + } + } +} + +TEST_CASE_FIXTURE(Fixture, "as_stmt_correct") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local function f(a: T, b: T) + return nil :: any + end + + f<>(1, "a") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + } +} + +TEST_CASE_FIXTURE(Fixture, "as_stmt_incorrect") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local function f(a: T, b: T) + return nil :: any + end + + f<>(1, "a") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::LuauSolverV2) + { + REQUIRE_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'boolean | number'"); + } + else + { + REQUIRE_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'boolean | number'; none of the union options are compatible"); + } + } +} + +TEST_CASE_FIXTURE(Fixture, "multiple_calls") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local function f(): T + return nil :: any + end + + local a: number = f<>() + local b: string = f<>() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + } +} + +TEST_CASE_FIXTURE(Fixture, "anonymous_type_inferred") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local function f(): { a: T, b: U } + return nil :: any + end + + local correct: { a: number, b: string } = f<>() + local incorrect: { a: number, b: string } = f<>() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE_EQ(result.errors[0].location.begin.line, 7); + LUAU_REQUIRE_ERROR(result, TypeMismatch); + } +} + +TEST_CASE_FIXTURE(Fixture, "type_packs") +{ + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + // FIXME: This triggers a GenericTypePackCountMismatch error, and it's not obvious if the + // code for explicit types is broken, or if subtyping is broken. + ScopedFastFlag oldSolver{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + --!strict + local function f(...: T...): U... end + + local a: number, b: string = f<<(boolean, {}), (number, string)>>(true, {}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_packs_method") +{ + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + // FIXME: This triggers a GenericTypePackCountMismatch error, and it's not obvious if the + // code for explicit types is broken, or if subtyping is broken. + ScopedFastFlag oldSolver{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + --!strict + local t: { + f: (self: any, T...) -> U..., + } = nil :: any + + local a: number, b: string = t:f<<(boolean, {}), (number, string)>>(true, {}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_packs_incorrect") +{ + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + // FIXME: This triggers a GenericTypePackCountMismatch error, and it's not obvious if the + // code for explicit types is broken, or if subtyping is broken. + ScopedFastFlag oldSolver{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + --!strict + local function f(...: T...): U... end + + local a: number, b: string = f<<(boolean, {}), (number, string)>>(true, "uh oh") + )"); + + LUAU_REQUIRE_ERROR(result, TypeMismatch); +} + +TEST_CASE_FIXTURE(Fixture, "type_packs_incorrect_method") +{ + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + // FIXME: This triggers a GenericTypePackCountMismatch error, and it's not obvious if the + // code for explicit types is broken, or if subtyping is broken. + ScopedFastFlag oldSolver{FFlag::LuauSolverV2, false}; + + CheckResult result = check(R"( + --!strict + local t: { + f: (self: any, T...) -> U..., + } = nil :: any + + local a: number, b: string = t:f<<(boolean, {}), (number, string)>>(true, "uh oh") + )"); + + LUAU_REQUIRE_ERROR(result, TypeMismatch); +} + +TEST_CASE_FIXTURE(Fixture, "dot_index_call") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local t = { + f = function(): T + return nil :: any + end, + } + + local correct: number = t.f<>() + local incorrect: number = t.f<>() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE_EQ(result.errors[0].location.begin.line, 9); + } +} + +TEST_CASE_FIXTURE(Fixture, "method_index_call") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local t = { + f = function(self: any): T + return nil :: any + end, + } + + local correct: number = t:f<>() + local incorrect: number = t:f<>() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR(result, TypeMismatch); + REQUIRE_EQ(result.errors[0].location.begin.line, 9); + } +} + +TEST_CASE_FIXTURE(Fixture, "stored_as_variable") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local function f(): T + return nil :: any + end + + local fNumber = f<> + + local correct: number = fNumber() + local incorrect: string = fNumber() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR(result, TypeMismatch); + REQUIRE_EQ(result.errors[0].location.begin.line, 9); + } +} + +TEST_CASE_FIXTURE(Fixture, "not_a_function") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local oops = 3 + local stub = oops<> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR(result, InstantiateGenericsOnNonFunction); + + REQUIRE_EQ(toString(result.errors[0]), "Cannot instantiate type parameters on something without type parameters."); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_call") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local t = setmetatable({}, { + __call = function(self): T + return nil :: any + end, + }) + + t<>() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR(result, InstantiateGenericsOnNonFunction); + REQUIRE_EQ(toString(result.errors[0]), "Luau does not currently support explicitly instantiating a table with a `__call` metamethod. \ + You may be able to work around this by creating a function that calls the table, and using that instead."); + } +} + +TEST_CASE_FIXTURE(Fixture, "method_call_incomplete") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local t = { + f = function(self: any): T | U + return nil :: any + end, + } + + local correct: number | string = t:f<>() + local incorrect: number | string = t:f<>() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR(result, TypeMismatch); + REQUIRE_EQ(result.errors[0].location.begin.line, 9); + } +} + +TEST_CASE_FIXTURE(Fixture, "too_many_provided") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local function f() end + + f<>() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR(result, TypeInstantiationCountMismatch); + + if (FFlag::LuauSolverV2) + { + REQUIRE_EQ(toString(result.errors[0]), "Too many type parameters passed to 'f', which is typed as (...any) -> (). Expected at most 1 type parameter, but 2 provided."); + } + else + { + REQUIRE_EQ(toString(result.errors[0]), "Too many type parameters passed to 'f', which is typed as () -> (). Expected at most 1 type parameter, but 2 provided."); + } + } +} + +TEST_CASE_FIXTURE(Fixture, "too_many_provided_type_packs") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local function f(): (T...) end + + f<<(string, number), (true, false)>>() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR(result, TypeInstantiationCountMismatch); + + if (FFlag::LuauSolverV2) + { + REQUIRE_EQ(toString(result.errors[0]), "Too many type parameters passed to 'f', which is typed as (...any) -> (T...). Expected at most 1 type pack, but 2 provided."); + } + else + { + REQUIRE_EQ(toString(result.errors[0]), "Too many type parameters passed to 'f', which is typed as () -> (T...). Expected at most 1 type pack, but 2 provided."); + } + } +} + +TEST_CASE_FIXTURE(Fixture, "too_many_provided_method") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local t = { + f = function(self: any) end, + } + + t:f<>() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR(result, TypeInstantiationCountMismatch); + REQUIRE_EQ(result.errors[0].location.begin.line, 6); + + if (FFlag::LuauSolverV2) + { + REQUIRE_EQ(toString(result.errors[0]), "Too many type parameters passed to function typed as (any) -> (). Expected at most 1 type parameter, but 2 provided."); + } + else + { + REQUIRE_EQ(toString(result.errors[0]), "Too many type parameters passed to 't.f', which is typed as (any) -> (). Expected at most 1 type parameter, but 2 provided."); + } + } +} + +TEST_CASE_FIXTURE(Fixture, "too_many_type_packs_provided_method") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local t = { + f = function(self: any): (T...) end, + } + + t:f<<(number, string), (true, false)>>() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR(result, TypeInstantiationCountMismatch); + REQUIRE_EQ(result.errors[0].location.begin.line, 6); + + if (FFlag::LuauSolverV2) + { + REQUIRE_EQ(toString(result.errors[0]), "Too many type parameters passed to function typed as (any) -> (T...). Expected at most 1 type pack, but 2 provided."); + } + else + { + REQUIRE_EQ(toString(result.errors[0]), "Too many type parameters passed to 't.f', which is typed as (any) -> (T...). Expected at most 1 type pack, but 2 provided."); + } + } +} + +TEST_CASE_FIXTURE(Fixture, "function_intersections") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + --!strict + local f: ((T) -> T) & ((T?) -> T) = nil :: any + f<>() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR(result, InstantiateGenericsOnNonFunction); + REQUIRE_EQ(result.errors[0].location.begin.line, 3); + REQUIRE_EQ(toString(result.errors[0]), "Luau does not currently support explicitly instantiating an overloaded function type."); + } +} + +TEST_CASE_FIXTURE(Fixture, "incomplete_type_packs") +{ + SUBCASE_BOTH_SOLVERS() + { + ScopedFastFlag sff{FFlag::LuauExplicitTypeExpressionInstantiation, true}; + + CheckResult result = check(R"( + local f: () -> (A, T...) = nil :: any + local correct: string, b: number, c: boolean = f<>() + local incorrect: number, b: number, c: boolean = f<>() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR(result, TypeMismatch); + REQUIRE_EQ(result.errors[0].location.begin.line, 3); + } +} + +TEST_SUITE_END(); diff --git a/tests/conformance/explicit_type_instantiations.luau b/tests/conformance/explicit_type_instantiations.luau new file mode 100644 index 00000000..1e91118a --- /dev/null +++ b/tests/conformance/explicit_type_instantiations.luau @@ -0,0 +1,38 @@ +-- Tests to ensure explicit type instantiations don't change runtime behavior +local function identity(x: T): T + return x +end + +assert(identity<>(1) == 1) + +local function multipleReturns(x: T): (T, T) + return x, x +end + +local a, b = multipleReturns<>(1) +assert(a == 1 and b == 1) + +local function typePacks(...: T...): T... + return ... +end + +local a, b = typePacks<<(string, number)>>(1, "a") +assert(a == 1 and b == "a") + +local t = {} +function t:method(x: T): T + assert(self == t) + return x +end + +assert(t:method(1) == 1) + +function t:methodTypePacks(...: T...): T... + assert(self == t) + return ... +end + +local a, b = t:methodTypePacks<<(string, number)>>(1, "a") +assert(a == 1 and b == "a") + +return "OK" diff --git a/tests/conformance/math.luau b/tests/conformance/math.luau index 586023ed..7dec3682 100644 --- a/tests/conformance/math.luau +++ b/tests/conformance/math.luau @@ -421,6 +421,19 @@ assert(math.lerp(sq2, sq2, sq2 / 2) == sq2) -- consistent (fails for a*t + b*(1- assert(tostring(math.pow(-2, 0.5)) == "nan") +-- isnan, isinf, isfinite +assert(math.isnan(0/0)) +assert(math.isnan(10) == false) +assert(math.isnan(math.huge) == false) +assert(math.isinf(math.huge)) +assert(math.isinf(-math.huge)) +assert(math.isinf(10) == false) +assert(math.isinf(0/0) == false) +assert(math.isfinite(math.huge) == false) +assert(math.isfinite(-math.huge) == false) +assert(math.isfinite(0/0) == false) +assert(math.isfinite(123.45)) + -- test that fastcalls return correct number of results assert(select('#', math.floor(1.4)) == 1) assert(select('#', math.ceil(1.6)) == 1) @@ -482,5 +495,8 @@ assert(math.sign("-2") == -1) assert(math.sign("0") == 0) assert(math.round("1.8") == 2) assert(math.lerp("1", "5", 0.5) == 3) +assert(math.isnan("123.45") == false) +assert(math.isinf("123.45") == false) +assert(math.isfinite("123.45") == true) return('OK') diff --git a/tests/require/config_tests/with_config/chained_aliases/.luaurc b/tests/require/config_tests/with_config/chained_aliases/.luaurc index 42e61fcd..736bab48 100644 --- a/tests/require/config_tests/with_config/chained_aliases/.luaurc +++ b/tests/require/config_tests/with_config/chained_aliases/.luaurc @@ -4,6 +4,7 @@ "cyclicentry": "@cyclic1", "cyclic1": "@cyclic2", "cyclic2": "@cyclic3", - "cyclic3": "@cyclic1" + "cyclic3": "@cyclic1", + "dependoninner": "@passthroughinner" } } diff --git a/tests/require/config_tests/with_config/chained_aliases/subdirectory/failing_requirer_inner_dependency.luau b/tests/require/config_tests/with_config/chained_aliases/subdirectory/failing_requirer_inner_dependency.luau new file mode 100644 index 00000000..46bce858 --- /dev/null +++ b/tests/require/config_tests/with_config/chained_aliases/subdirectory/failing_requirer_inner_dependency.luau @@ -0,0 +1 @@ +return require("@dependoninner")