diff --git a/include/swift/AST/ConstTypeInfo.h b/include/swift/AST/ConstTypeInfo.h index 12c66f2c29071..fca370e7a79d4 100644 --- a/include/swift/AST/ConstTypeInfo.h +++ b/include/swift/AST/ConstTypeInfo.h @@ -39,6 +39,7 @@ class CompileTimeValue { Type, KeyPath, FunctionCall, + StaticFunctionCall, MemberReference, InterpolatedString, Runtime @@ -392,6 +393,30 @@ class FunctionCallValue : public CompileTimeValue { std::optional> Parameters; }; +/// A static function reference representation such as +/// let foo = MyStruct.bar(item: "") +/// let foo = MyStruct.bar() +class StaticFunctionCallValue : public CompileTimeValue { +public: + StaticFunctionCallValue(std::string Label, swift::Type Type, + std::vector Parameters) + : CompileTimeValue(ValueKind::StaticFunctionCall), Label(Label), + Type(Type), Parameters(Parameters) {} + + static bool classof(const CompileTimeValue *T) { + return T->getKind() == ValueKind::StaticFunctionCall; + } + + std::string getLabel() const { return Label; } + swift::Type getType() const { return Type; } + std::vector getParameters() const { return Parameters; } + +private: + std::string Label; + swift::Type Type; + std::vector Parameters; +}; + /// A member reference representation such as /// let foo = MyStruct.bar class MemberReferenceValue : public CompileTimeValue { diff --git a/lib/ConstExtract/ConstExtract.cpp b/lib/ConstExtract/ConstExtract.cpp index 73a9e3b4cdb6c..cb08f19c33587 100644 --- a/lib/ConstExtract/ConstExtract.cpp +++ b/lib/ConstExtract/ConstExtract.cpp @@ -294,12 +294,12 @@ static std::shared_ptr extractCompileTimeValue(Expr *expr) { if (functionKind == ExprKind::DeclRef) { auto declRefExpr = cast(callExpr->getFn()); - auto caseName = + auto identifier = declRefExpr->getDecl()->getName().getBaseIdentifier().str().str(); std::vector parameters = extractFunctionArguments(callExpr->getArgs()); - return std::make_shared(caseName, parameters); + return std::make_shared(identifier, parameters); } if (functionKind == ExprKind::ConstructorRefCall) { @@ -313,12 +313,33 @@ static std::shared_ptr extractCompileTimeValue(Expr *expr) { auto fn = dotSyntaxCallExpr->getFn(); if (fn->getKind() == ExprKind::DeclRef) { auto declRefExpr = cast(fn); - auto caseName = + auto baseIdentifierName = declRefExpr->getDecl()->getName().getBaseIdentifier().str().str(); std::vector parameters = extractFunctionArguments(callExpr->getArgs()); - return std::make_shared(caseName, parameters); + + auto declRef = dotSyntaxCallExpr->getFn()->getReferencedDecl(); + switch (declRef.getDecl()->getKind()) { + case DeclKind::EnumElement: { + return std::make_shared(baseIdentifierName, parameters); + } + + case DeclKind::Func: { + auto identifier = declRefExpr->getDecl() + ->getName() + .getBaseIdentifier() + .str() + .str(); + + return std::make_shared( + identifier, callExpr->getType(), parameters); + } + + default: { + break; + } + } } } @@ -836,6 +857,27 @@ void writeValue(llvm::json::OStream &JSON, break; } + case CompileTimeValue::ValueKind::StaticFunctionCall: { + auto staticFunctionCallValue = cast(value); + + JSON.attribute("valueKind", "StaticFunctionCall"); + JSON.attributeObject("value", [&]() { + JSON.attribute("type", toFullyQualifiedTypeNameString( + staticFunctionCallValue->getType())); + JSON.attribute("memberLabel", staticFunctionCallValue->getLabel()); + JSON.attributeArray("arguments", [&] { + for (auto FP : staticFunctionCallValue->getParameters()) { + JSON.object([&] { + JSON.attribute("label", FP.Label); + JSON.attribute("type", toFullyQualifiedTypeNameString(FP.Type)); + writeValue(JSON, FP.Value); + }); + } + }); + }); + break; + } + case CompileTimeValue::ValueKind::MemberReference: { auto memberReferenceValue = cast(value); JSON.attribute("valueKind", "MemberReference"); @@ -846,6 +888,7 @@ void writeValue(llvm::json::OStream &JSON, }); break; } + case CompileTimeValue::ValueKind::InterpolatedString: { auto interpolatedStringValue = cast(value); JSON.attribute("valueKind", "InterpolatedStringLiteral"); @@ -1049,14 +1092,11 @@ createBuilderCompileTimeValue(CustomAttr *AttachedResultBuilder, void writeSingleBuilderMemberElement( llvm::json::OStream &JSON, std::shared_ptr Element) { switch (Element.get()->getKind()) { - case CompileTimeValue::ValueKind::Enum: { - auto enumValue = cast(Element.get()); - if (enumValue->getIdentifier() == "buildExpression") { - if (enumValue->getParameters().has_value()) { - auto params = enumValue->getParameters().value(); - for (auto FP : params) { - writeValue(JSON, FP.Value); - } + case CompileTimeValue::ValueKind::StaticFunctionCall: { + auto staticFunctionCallValue = cast(Element.get()); + if (staticFunctionCallValue->getLabel() == "buildExpression") { + for (auto FP : staticFunctionCallValue->getParameters()) { + writeValue(JSON, FP.Value); } } break; diff --git a/test/ConstExtraction/ExtractCalls.swift b/test/ConstExtraction/ExtractCalls.swift index c02b98289118c..9eae9c95f843f 100644 --- a/test/ConstExtraction/ExtractCalls.swift +++ b/test/ConstExtraction/ExtractCalls.swift @@ -115,8 +115,8 @@ public struct Bat { // CHECK-NEXT: "name": "adder", // CHECK-NEXT: "arguments": [ // CHECK-NEXT: { -// CHECK-NEXT: "label": "", -// CHECK-NEXT: "type": "Swift.Int", +// CHECK-NEXT: "label": "", +// CHECK-NEXT: "type": "Swift.Int", // CHECK-NEXT: "valueKind": "RawLiteral", // CHECK-NEXT: "value": "2" // CHECK-NEXT: }, diff --git a/test/ConstExtraction/ExtractStaticFunctions.swift b/test/ConstExtraction/ExtractStaticFunctions.swift new file mode 100644 index 0000000000000..d58d761e43404 --- /dev/null +++ b/test/ConstExtraction/ExtractStaticFunctions.swift @@ -0,0 +1,85 @@ +// RUN: %empty-directory(%t) +// RUN: echo "[MyProto]" > %t/protocols.json + +// RUN: %target-swift-frontend -typecheck -emit-const-values-path %t/ExtractStaticFunctions.swiftconstvalues -const-gather-protocols-file %t/protocols.json -primary-file %s +// RUN: cat %t/ExtractStaticFunctions.swiftconstvalues 2>&1 | %FileCheck %s + +protocol MyProto {} + +enum Bar { + case one + case two(item: String) +} + +struct Baz { + static var one: Baz { + Baz() + } + + static func two(item: String) -> Baz { + return Baz() + } + + static func three() -> Baz { + return Baz() + } +} + +struct Statics: MyProto { + var bar1 = Bar.one + var bar2 = Bar.two(item: "bar") + var baz1 = Baz.one + var baz2 = Baz.two(item: "baz") + var baz3 = Baz.three() +} + +// CHECK: "label": "bar1", +// CHECK-NEXT: "type": "ExtractStaticFunctions.Bar", +// CHECK: "valueKind": "Enum", +// CHECK-NEXT: "value": { +// CHECK-NEXT: "name": "one" +// CHECK-NEXT: } +// CHECK: "label": "bar2", +// CHECK-NEXT: "type": "ExtractStaticFunctions.Bar", +// CHECK: "valueKind": "Enum", +// CHECK-NEXT: "value": { +// CHECK-NEXT: "name": "two", +// CHECK-NEXT: "arguments": [ +// CHECK-NEXT: { +// CHECK-NEXT: "label": "item", +// CHECK-NEXT: "type": "Swift.String", +// CHECK-NEXT: "valueKind": "RawLiteral", +// CHECK-NEXT: "value": "bar" +// CHECK-NEXT: } +// CHECK-NEXT: ] +// CHECK-NEXT: } +// CHECK: "label": "baz1", +// CHECK-NEXT: "type": "ExtractStaticFunctions.Baz", +// CHECK: "valueKind": "MemberReference" +// CHECK-NEXT: "value": { +// CHECK-NEXT: "baseType": "ExtractStaticFunctions.Baz", +// CHECK-NEXT: "memberLabel": "one" +// CHECK-NEXT: } +// CHECK: "label": "baz2", +// CHECK-NEXT: "type": "ExtractStaticFunctions.Baz", +// CHECK: "valueKind": "StaticFunctionCall", +// CHECK-NEXT: "value": { +// CHECK-NEXT: "type": "ExtractStaticFunctions.Baz", +// CHECK-NEXT: "memberLabel": "two", +// CHECK-NEXT: "arguments": [ +// CHECK-NEXT: { +// CHECK-NEXT: "label": "item", +// CHECK-NEXT: "type": "Swift.String", +// CHECK-NEXT: "valueKind": "RawLiteral", +// CHECK-NEXT: "value": "baz" +// CHECK-NEXT: } +// CHECK-NEXT: ] +// CHECK-NEXT: } +// CHECK: "label": "baz3", +// CHECK-NEXT: "type": "ExtractStaticFunctions.Baz", +// CHECK: "valueKind": "StaticFunctionCall", +// CHECK-NEXT: "value": { +// CHECK-NEXT: "type": "ExtractStaticFunctions.Baz", +// CHECK-NEXT: "memberLabel": "three", +// CHECK-NEXT: "arguments": [] +// CHECK-NEXT: }