diff --git a/lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift b/lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift index 32123e1de0aaf..3a8e6da18dab6 100644 --- a/lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift +++ b/lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift @@ -1237,11 +1237,10 @@ func parseCxxSpansInSignature( } func parseMacroParam( - _ paramAST: LabeledExprSyntax, _ signature: FunctionSignatureSyntax, _ rewriter: CountExprRewriter, + _ paramExpr: ExprSyntax, _ signature: FunctionSignatureSyntax, _ rewriter: CountExprRewriter, nonescapingPointers: inout Set, lifetimeDependencies: inout [SwiftifyExpr: [LifetimeDependence]] ) throws -> ParamInfo? { - let paramExpr = paramAST.expression guard let enumConstructorExpr = paramExpr.as(FunctionCallExprSyntax.self) else { throw DiagnosticError( "expected _SwiftifyInfo enum literal as argument, got '\(paramExpr)'", node: paramExpr) @@ -1567,6 +1566,121 @@ func deconstructFunction(_ declaration: some DeclSyntaxProtocol) throws -> Funct throw DiagnosticError("@_SwiftifyImport only works on functions and initializers", node: declaration) } +func constructOverloadFunction(forDecl declaration: some DeclSyntaxProtocol, leadingTrivia: Trivia, + args arguments: [ExprSyntax], spanAvailability: String?, + typeMappings: [String: String]?) throws -> DeclSyntax { + let origFuncComponents = try deconstructFunction(declaration) + let (funcComponents, rewriter) = renameParameterNamesIfNeeded(origFuncComponents) + + var nonescapingPointers = Set() + var lifetimeDependencies: [SwiftifyExpr: [LifetimeDependence]] = [:] + var parsedArgs = try arguments.compactMap { + try parseMacroParam( + $0, funcComponents.signature, rewriter, nonescapingPointers: &nonescapingPointers, + lifetimeDependencies: &lifetimeDependencies) + } + parsedArgs.append( + contentsOf: try parseCxxSpansInSignature(funcComponents.signature, typeMappings)) + setNonescapingPointers(&parsedArgs, nonescapingPointers) + setLifetimeDependencies(&parsedArgs, lifetimeDependencies) + // We only transform non-escaping spans. + parsedArgs = parsedArgs.filter { + if let cxxSpanArg = $0 as? CxxSpan { + return cxxSpanArg.nonescaping || cxxSpanArg.pointerIndex == .return + } else { + return true + } + } + try checkArgs(parsedArgs, funcComponents) + parsedArgs.sort { a, b in + // make sure return value cast to Span happens last so that withUnsafeBufferPointer + // doesn't return a ~Escapable type + if a.pointerIndex != .return && b.pointerIndex == .return { + return true + } + if a.pointerIndex == .return && b.pointerIndex != .return { + return false + } + return paramOrReturnIndex(a.pointerIndex) < paramOrReturnIndex(b.pointerIndex) + } + let baseBuilder = FunctionCallBuilder(funcComponents) + + let builder: BoundsCheckedThunkBuilder = parsedArgs.reduce( + baseBuilder, + { (prev, parsedArg) in + parsedArg.getBoundsCheckedThunkBuilder(prev, funcComponents) + }) + let newSignature = try builder.buildFunctionSignature([:], nil) + var eliminatedArgs = Set() + let basicChecks = try builder.buildBasicBoundsChecks(&eliminatedArgs) + let compoundChecks = try builder.buildCompoundBoundsChecks() + let checks = (basicChecks + compoundChecks).map { e in + CodeBlockItemSyntax(leadingTrivia: "\n", item: e) + } + let call: CodeBlockItemSyntax = + if declaration.is(InitializerDeclSyntax.self) { + CodeBlockItemSyntax( + item: CodeBlockItemSyntax.Item( + try builder.buildFunctionCall([:]))) + } else { + CodeBlockItemSyntax( + item: CodeBlockItemSyntax.Item( + ReturnStmtSyntax( + returnKeyword: .keyword(.return, trailingTrivia: " "), + expression: try builder.buildFunctionCall([:])))) + } + let body = CodeBlockSyntax(statements: CodeBlockItemListSyntax(checks + [call])) + let returnLifetimeAttribute = getReturnLifetimeAttribute(funcComponents, lifetimeDependencies) + let lifetimeAttrs = + returnLifetimeAttribute + paramLifetimeAttributes(newSignature, funcComponents.attributes) + let availabilityAttr = try getAvailability(newSignature, spanAvailability) + let disfavoredOverload: [AttributeListSyntax.Element] = + [ + .attribute( + AttributeSyntax( + atSign: .atSignToken(), + attributeName: IdentifierTypeSyntax(name: "_disfavoredOverload"))) + ] + let attributes = + funcComponents.attributes.filter { e in + switch e { + case .attribute(let attr): + // don't apply this macro recursively, and avoid dupe _alwaysEmitIntoClient + let name = attr.attributeName.as(IdentifierTypeSyntax.self)?.name.text + return name == nil || (name != "_SwiftifyImport" && name != "_alwaysEmitIntoClient") + default: return true + } + } + [ + .attribute( + AttributeSyntax( + atSign: .atSignToken(), + attributeName: IdentifierTypeSyntax(name: "_alwaysEmitIntoClient"))) + ] + + availabilityAttr + + lifetimeAttrs + + disfavoredOverload + let trivia = + leadingTrivia + .docLineComment("/// This is an auto-generated wrapper for safer interop\n") + if let origFuncDecl = declaration.as(FunctionDeclSyntax.self) { + return DeclSyntax( + origFuncDecl + .with(\.signature, newSignature) + .with(\.body, body) + .with(\.attributes, AttributeListSyntax(attributes)) + .with(\.leadingTrivia, trivia)) + } + if let origInitDecl = declaration.as(InitializerDeclSyntax.self) { + return DeclSyntax( + origInitDecl + .with(\.signature, newSignature) + .with(\.body, body) + .with(\.attributes, AttributeListSyntax(attributes)) + .with(\.leadingTrivia, trivia)) + } + throw DiagnosticError( + "Expected function decl or initializer decl, found: \(declaration.kind)", node: declaration) +} + /// A macro that adds safe(r) wrappers for functions with unsafe pointer types. /// Depends on bounds, escapability and lifetime information for each pointer. /// Intended to map to C attributes like __counted_by, __ended_by and __no_escape, @@ -1580,9 +1694,6 @@ public struct SwiftifyImportMacro: PeerMacro { in context: some MacroExpansionContext ) throws -> [DeclSyntax] { do { - let origFuncComponents = try deconstructFunction(declaration) - let (funcComponents, rewriter) = renameParameterNamesIfNeeded(origFuncComponents) - let argumentList = node.arguments!.as(LabeledExprListSyntax.self)! var arguments = [LabeledExprSyntax](argumentList) let typeMappings = try parseTypeMappingParam(arguments.last) @@ -1593,107 +1704,12 @@ public struct SwiftifyImportMacro: PeerMacro { if spanAvailability != nil { arguments = arguments.dropLast() } - var nonescapingPointers = Set() - var lifetimeDependencies: [SwiftifyExpr: [LifetimeDependence]] = [:] - var parsedArgs = try arguments.compactMap { - try parseMacroParam( - $0, funcComponents.signature, rewriter, nonescapingPointers: &nonescapingPointers, - lifetimeDependencies: &lifetimeDependencies) - } - parsedArgs.append(contentsOf: try parseCxxSpansInSignature(funcComponents.signature, typeMappings)) - setNonescapingPointers(&parsedArgs, nonescapingPointers) - setLifetimeDependencies(&parsedArgs, lifetimeDependencies) - // We only transform non-escaping spans. - parsedArgs = parsedArgs.filter { - if let cxxSpanArg = $0 as? CxxSpan { - return cxxSpanArg.nonescaping || cxxSpanArg.pointerIndex == .return - } else { - return true - } - } - try checkArgs(parsedArgs, funcComponents) - parsedArgs.sort { a, b in - // make sure return value cast to Span happens last so that withUnsafeBufferPointer - // doesn't return a ~Escapable type - if a.pointerIndex != .return && b.pointerIndex == .return { - return true - } - if a.pointerIndex == .return && b.pointerIndex != .return { - return false - } - return paramOrReturnIndex(a.pointerIndex) < paramOrReturnIndex(b.pointerIndex) - } - let baseBuilder = FunctionCallBuilder(funcComponents) - - let builder: BoundsCheckedThunkBuilder = parsedArgs.reduce( - baseBuilder, - { (prev, parsedArg) in - parsedArg.getBoundsCheckedThunkBuilder(prev, funcComponents) - }) - let newSignature = try builder.buildFunctionSignature([:], nil) - var eliminatedArgs = Set() - let basicChecks = try builder.buildBasicBoundsChecks(&eliminatedArgs) - let compoundChecks = try builder.buildCompoundBoundsChecks() - let checks = (basicChecks + compoundChecks).map { e in - CodeBlockItemSyntax(leadingTrivia: "\n", item: e) - } - var call : CodeBlockItemSyntax - if declaration.is(InitializerDeclSyntax.self) { - call = CodeBlockItemSyntax( - item: CodeBlockItemSyntax.Item( - try builder.buildFunctionCall([:]))) - } else { - call = CodeBlockItemSyntax( - item: CodeBlockItemSyntax.Item( - ReturnStmtSyntax( - returnKeyword: .keyword(.return, trailingTrivia: " "), - expression: try builder.buildFunctionCall([:])))) - } - let body = CodeBlockSyntax(statements: CodeBlockItemListSyntax(checks + [call])) - let returnLifetimeAttribute = getReturnLifetimeAttribute(funcComponents, lifetimeDependencies) - let lifetimeAttrs = - returnLifetimeAttribute + paramLifetimeAttributes(newSignature, funcComponents.attributes) - let availabilityAttr = try getAvailability(newSignature, spanAvailability) - let disfavoredOverload: [AttributeListSyntax.Element] = - [ - .attribute( - AttributeSyntax( - atSign: .atSignToken(), - attributeName: IdentifierTypeSyntax(name: "_disfavoredOverload"))) - ] - let attributes = funcComponents.attributes.filter { e in - switch e { - case .attribute(let attr): - // don't apply this macro recursively, and avoid dupe _alwaysEmitIntoClient - let name = attr.attributeName.as(IdentifierTypeSyntax.self)?.name.text - return name == nil || (name != "_SwiftifyImport" && name != "_alwaysEmitIntoClient") - default: return true - } - } + [ - .attribute( - AttributeSyntax( - atSign: .atSignToken(), - attributeName: IdentifierTypeSyntax(name: "_alwaysEmitIntoClient"))) - ] - + availabilityAttr - + lifetimeAttrs - + disfavoredOverload - let trivia = node.leadingTrivia + .docLineComment("/// This is an auto-generated wrapper for safer interop\n") - if let origFuncDecl = declaration.as(FunctionDeclSyntax.self) { - return [DeclSyntax(origFuncDecl - .with(\.signature, newSignature) - .with(\.body, body) - .with(\.attributes, AttributeListSyntax(attributes)) - .with(\.leadingTrivia, trivia))] - } - if let origInitDecl = declaration.as(InitializerDeclSyntax.self) { - return [DeclSyntax(origInitDecl - .with(\.signature, newSignature) - .with(\.body, body) - .with(\.attributes, AttributeListSyntax(attributes)) - .with(\.leadingTrivia, trivia))] - } - return [] + let args = arguments.map { $0.expression } + return [ + try constructOverloadFunction( + forDecl: declaration, leadingTrivia: node.leadingTrivia, args: args, + spanAvailability: spanAvailability, + typeMappings: typeMappings)] } catch let error as DiagnosticError { context.diagnose( Diagnostic(