diff --git a/Examples/Sources/ExamplePlugin/ExamplePlugin.swift b/Examples/Sources/ExamplePlugin/ExamplePlugin.swift index c3953f95930..9eab0b9b94b 100644 --- a/Examples/Sources/ExamplePlugin/ExamplePlugin.swift +++ b/Examples/Sources/ExamplePlugin/ExamplePlugin.swift @@ -10,6 +10,7 @@ struct ThePlugin: CompilerPlugin { PeerValueWithSuffixNameMacro.self, MemberDeprecatedMacro.self, EquatableConformanceMacro.self, + SendableExtensionMacro.self, DidSetPrintMacro.self, PrintAnyMacro.self, ] diff --git a/Examples/Sources/ExamplePlugin/Macros.swift b/Examples/Sources/ExamplePlugin/Macros.swift index 234d533d9e3..51f2179338b 100644 --- a/Examples/Sources/ExamplePlugin/Macros.swift +++ b/Examples/Sources/ExamplePlugin/Macros.swift @@ -86,6 +86,26 @@ struct EquatableConformanceMacro: ConformanceMacro { } } +public struct SendableExtensionMacro: ExtensionMacro { + public static func expansion( + of node: AttributeSyntax, + attachedTo: some DeclGroupSyntax, + providingExtensionsOf type: some TypeSyntaxProtocol, + in context: some MacroExpansionContext + ) throws -> [ExtensionDeclSyntax] { + let sendableExtension: DeclSyntax = + """ + extension \(type.trimmed): Sendable {} + """ + + guard let extensionDecl = sendableExtension.as(ExtensionDeclSyntax.self) else { + return [] + } + + return [extensionDecl] + } +} + /// Add 'didSet' printing the new value. struct DidSetPrintMacro: AccessorMacro { static func expansion( diff --git a/Sources/SwiftCompilerPluginMessageHandling/CompilerPluginMessageHandler.swift b/Sources/SwiftCompilerPluginMessageHandling/CompilerPluginMessageHandler.swift index a59bf726fca..2b9e718b4f7 100644 --- a/Sources/SwiftCompilerPluginMessageHandling/CompilerPluginMessageHandler.swift +++ b/Sources/SwiftCompilerPluginMessageHandling/CompilerPluginMessageHandler.swift @@ -122,14 +122,15 @@ extension CompilerPluginMessageHandler { expandingSyntax: expandingSyntax ) - case .expandAttachedMacro(let macro, let macroRole, let discriminator, let attributeSyntax, let declSyntax, let parentDeclSyntax): + case .expandAttachedMacro(let macro, let macroRole, let discriminator, let attributeSyntax, let declSyntax, let parentDeclSyntax, let extendedTypeSyntax): try expandAttachedMacro( macro: macro, macroRole: macroRole, discriminator: discriminator, attributeSyntax: attributeSyntax, declSyntax: declSyntax, - parentDeclSyntax: parentDeclSyntax + parentDeclSyntax: parentDeclSyntax, + extendedTypeSyntax: extendedTypeSyntax ) case .loadPluginLibrary(let libraryPath, let moduleName): diff --git a/Sources/SwiftCompilerPluginMessageHandling/Macros.swift b/Sources/SwiftCompilerPluginMessageHandling/Macros.swift index 5923744bc96..f181d846c58 100644 --- a/Sources/SwiftCompilerPluginMessageHandling/Macros.swift +++ b/Sources/SwiftCompilerPluginMessageHandling/Macros.swift @@ -86,7 +86,8 @@ extension CompilerPluginMessageHandler { discriminator: String, attributeSyntax: PluginMessage.Syntax, declSyntax: PluginMessage.Syntax, - parentDeclSyntax: PluginMessage.Syntax? + parentDeclSyntax: PluginMessage.Syntax?, + extendedTypeSyntax: PluginMessage.Syntax? ) throws { let sourceManager = SourceManager() let context = PluginMacroExpansionContext( @@ -100,6 +101,9 @@ extension CompilerPluginMessageHandler { ).cast(AttributeSyntax.self) let declarationNode = sourceManager.add(declSyntax).cast(DeclSyntax.self) let parentDeclNode = parentDeclSyntax.map { sourceManager.add($0).cast(DeclSyntax.self) } + let extendedType = extendedTypeSyntax.map { + sourceManager.add($0).cast(TypeSyntax.self) + } // TODO: Make this a 'String?' and remove non-'hasExpandMacroResult' branches. let expandedSources: [String]? @@ -115,6 +119,7 @@ extension CompilerPluginMessageHandler { attributeNode: attributeNode, declarationNode: declarationNode, parentDeclNode: parentDeclNode, + extendedType: extendedType, in: context ) if let expansions, hostCapability.hasExpandMacroResult { @@ -159,6 +164,7 @@ private extension MacroRole { case .peer: self = .peer case .conformance: self = .conformance case .codeItem: self = .codeItem + case .extension: self = .extension } } } diff --git a/Sources/SwiftCompilerPluginMessageHandling/PluginMessages.swift b/Sources/SwiftCompilerPluginMessageHandling/PluginMessages.swift index 2075240488a..0b01d495604 100644 --- a/Sources/SwiftCompilerPluginMessageHandling/PluginMessages.swift +++ b/Sources/SwiftCompilerPluginMessageHandling/PluginMessages.swift @@ -34,7 +34,8 @@ internal enum HostToPluginMessage: Codable { discriminator: String, attributeSyntax: PluginMessage.Syntax, declSyntax: PluginMessage.Syntax, - parentDeclSyntax: PluginMessage.Syntax? + parentDeclSyntax: PluginMessage.Syntax?, + extendedTypeSyntax: PluginMessage.Syntax? ) /// Optionally implemented message to load a dynamic link library. @@ -76,7 +77,7 @@ internal enum PluginToHostMessage: Codable { } /*namespace*/ internal enum PluginMessage { - static var PROTOCOL_VERSION_NUMBER: Int { 5 } // Added 'expandMacroResult'. + static var PROTOCOL_VERSION_NUMBER: Int { 6 } // Added 'expandMacroResult'. struct HostCapability: Codable { var protocolVersion: Int @@ -107,6 +108,7 @@ internal enum PluginToHostMessage: Codable { case peer case conformance case codeItem + case `extension` } struct SourceLocation: Codable { diff --git a/Sources/SwiftParser/Attributes.swift b/Sources/SwiftParser/Attributes.swift index 8fb100b8849..e4d793797a9 100644 --- a/Sources/SwiftParser/Attributes.swift +++ b/Sources/SwiftParser/Attributes.swift @@ -332,14 +332,44 @@ extension Parser { ) ) case nil: + let isAttached = self.peek().isAttachedKeyword return parseAttribute(argumentMode: .customAttribute) { parser in - let arguments = parser.parseArgumentListElements(pattern: .none) + let arguments: [RawTupleExprElementSyntax] + if isAttached { + arguments = parser.parseAttachedArguments() + } else { + arguments = parser.parseArgumentListElements(pattern: .none) + } + return .argumentList(RawTupleExprElementListSyntax(elements: arguments, arena: parser.arena)) } } } } +extension Parser { + mutating func parseAttachedArguments() -> [RawTupleExprElementSyntax] { + let (unexpectedBeforeRole, role) = self.expect(.identifier, TokenSpec(.extension, remapping: .identifier), default: .identifier) + let roleTrailingComma = self.consume(if: .comma) + let roleElement = RawTupleExprElementSyntax( + label: nil, + colon: nil, + expression: RawExprSyntax( + RawIdentifierExprSyntax( + unexpectedBeforeRole, + identifier: role, + declNameArguments: nil, + arena: self.arena + ) + ), + trailingComma: roleTrailingComma, + arena: self.arena + ) + let additionalArgs = self.parseArgumentListElements(pattern: .none) + return [roleElement] + additionalArgs + } +} + extension Parser { mutating func parseDifferentiableAttribute() -> RawAttributeSyntax { let (unexpectedBeforeAtSign, atSign) = self.expect(.atSign) diff --git a/Sources/SwiftParser/Types.swift b/Sources/SwiftParser/Types.swift index e3453b61f22..346f1fa2723 100644 --- a/Sources/SwiftParser/Types.swift +++ b/Sources/SwiftParser/Types.swift @@ -1076,6 +1076,10 @@ extension Lexer.Lexeme { || self.rawTokenKind == .prefixOperator } + var isAttachedKeyword: Bool { + return self.rawTokenKind == .identifier && self.tokenText == "attached" + } + var isEllipsis: Bool { return self.isAnyOperator && self.tokenText == "..." } diff --git a/Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift b/Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift index 08a4671eace..a0226454289 100644 --- a/Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift +++ b/Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift @@ -22,6 +22,7 @@ public enum MacroRole { case peer case conformance case codeItem + case `extension` } extension MacroRole { @@ -35,6 +36,7 @@ extension MacroRole { case .peer: return "PeerMacro" case .conformance: return "ConformanceMacro" case .codeItem: return "CodeItemMacro" + case .extension: return "ExtensionMacro" } } } @@ -45,6 +47,7 @@ private enum MacroExpansionError: Error, CustomStringConvertible { case parentDeclGroupNil case declarationNotDeclGroup case declarationNotIdentified + case noExtendedTypeSyntax case noFreestandingMacroRoles(Macro.Type) var description: String { @@ -61,6 +64,9 @@ private enum MacroExpansionError: Error, CustomStringConvertible { case .declarationNotIdentified: return "declaration is not a 'Identified' syntax" + case .noExtendedTypeSyntax: + return "no extended type for extension macro" + case .noFreestandingMacroRoles(let type): return "macro implementation type '\(type)' does not conform to any freestanding macro protocol" @@ -113,7 +119,7 @@ public func expandFreestandingMacro( let rewritten = try codeItemMacroDef.expansion(of: node, in: context) expandedSyntax = Syntax(CodeBlockItemListSyntax(rewritten)) - case (.accessor, _), (.memberAttribute, _), (.member, _), (.peer, _), (.conformance, _), (.expression, _), (.declaration, _), + case (.accessor, _), (.memberAttribute, _), (.member, _), (.peer, _), (.conformance, _), (.extension, _), (.expression, _), (.declaration, _), (.codeItem, _): throw MacroExpansionError.unmatchedMacroRole(definition, macroRole) } @@ -178,6 +184,7 @@ public func expandAttachedMacroWithoutCollapsing attributeNode: AttributeSyntax, declarationNode: DeclSyntax, parentDeclNode: DeclSyntax?, + extendedType: TypeSyntax?, in context: Context ) -> [String]? { do { @@ -295,6 +302,39 @@ public func expandAttachedMacroWithoutCollapsing return "extension \(typeName) : \(protocolName) \(whereClause) {}" } + case (let attachedMacro as ExtensionMacro.Type, .extension): + guard let declGroup = declarationNode.asProtocol(DeclGroupSyntax.self) else { + // Compiler error: type mismatch. + throw MacroExpansionError.declarationNotDeclGroup + } + + guard let extendedType = extendedType else { + throw MacroExpansionError.noExtendedTypeSyntax + } + + // Local function to expand an extension macro once we've opened up + // the existential. + func expandExtensionMacro( + _ node: some DeclGroupSyntax + ) throws -> [ExtensionDeclSyntax] { + return try attachedMacro.expansion( + of: attributeNode, + attachedTo: node, + providingExtensionsOf: extendedType, + in: context + ) + } + + let extensions = try _openExistential( + declGroup, + do: expandExtensionMacro + ) + + // Form a buffer of peer declarations to return to the caller. + return extensions.map { + $0.formattedExpansion(definition.formatMode) + } + default: throw MacroExpansionError.unmatchedMacroRole(definition, macroRole) } @@ -323,6 +363,7 @@ public func expandAttachedMacro( attributeNode: AttributeSyntax, declarationNode: DeclSyntax, parentDeclNode: DeclSyntax?, + extendedType: TypeSyntax?, in context: Context ) -> String? { let expandedSources = expandAttachedMacroWithoutCollapsing( @@ -331,6 +372,7 @@ public func expandAttachedMacro( attributeNode: attributeNode, declarationNode: declarationNode, parentDeclNode: parentDeclNode, + extendedType: extendedType, in: context ) return expandedSources.map { diff --git a/Sources/SwiftSyntaxMacros/CMakeLists.txt b/Sources/SwiftSyntaxMacros/CMakeLists.txt index e3380d556ec..97cdca91ae4 100644 --- a/Sources/SwiftSyntaxMacros/CMakeLists.txt +++ b/Sources/SwiftSyntaxMacros/CMakeLists.txt @@ -13,6 +13,7 @@ add_swift_host_library(SwiftSyntaxMacros MacroProtocols/ConformanceMacro.swift MacroProtocols/DeclarationMacro.swift MacroProtocols/ExpressionMacro.swift + MacroProtocols/ExtensionMacro.swift MacroProtocols/FreestandingMacro.swift MacroProtocols/Macro.swift MacroProtocols/Macro+Format.swift diff --git a/Sources/SwiftSyntaxMacros/MacroProtocols/ExtensionMacro.swift b/Sources/SwiftSyntaxMacros/MacroProtocols/ExtensionMacro.swift new file mode 100644 index 00000000000..8b3c8bd9b39 --- /dev/null +++ b/Sources/SwiftSyntaxMacros/MacroProtocols/ExtensionMacro.swift @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +import SwiftSyntax + +/// Describes a macro that can add extensions to the declaration it's +/// attached to. +public protocol ExtensionMacro: AttachedMacro { + /// Expand an attached extension macro to produce a set of extensions. + /// + /// - Parameters: + /// - node: The custom attribute describing the attached macro. + /// - declaration: The declaration the macro attribute is attached to. + /// - type: The type to provide extensions of. + /// - context: The context in which to perform the macro expansion. + /// + /// - Returns: the set of extension declarations introduced by the macro, + /// which are always inserted at top-level scope. Each extension must extend + /// the `type` parameter. + static func expansion( + of node: AttributeSyntax, + attachedTo declaration: some DeclGroupSyntax, + providingExtensionsOf type: some TypeSyntaxProtocol, + in context: some MacroExpansionContext + ) throws -> [ExtensionDeclSyntax] +} diff --git a/Sources/SwiftSyntaxMacros/MacroSystem.swift b/Sources/SwiftSyntaxMacros/MacroSystem.swift index 5d88efc2fb8..3f7ba87f8c7 100644 --- a/Sources/SwiftSyntaxMacros/MacroSystem.swift +++ b/Sources/SwiftSyntaxMacros/MacroSystem.swift @@ -115,7 +115,8 @@ class MacroApplication: SyntaxRewriter { || macro is MemberMacro.Type || macro is AccessorMacro.Type || macro is MemberAttributeMacro.Type - || macro is ConformanceMacro.Type) + || macro is ConformanceMacro.Type + || macro is ExtensionMacro.Type) } if newAttributes.isEmpty { @@ -438,6 +439,23 @@ extension MacroApplication { } } + let extensionMacroAttrs = getMacroAttributes(attachedTo: decl.as(DeclSyntax.self)!, ofType: ExtensionMacro.Type.self) + let extendedTypeSyntax = TypeSyntax("\(extendedType.trimmed)") + for (attribute, extensionMacro) in extensionMacroAttrs { + do { + let newExtensions = try extensionMacro.expansion( + of: attribute, + attachedTo: decl, + providingExtensionsOf: extendedTypeSyntax, + in: context + ) + + extensions.append(contentsOf: newExtensions.map(DeclSyntax.init)) + } catch { + context.addDiagnostics(from: error, node: attribute) + } + } + return extensions } diff --git a/Tests/SwiftParserTest/AttributeTests.swift b/Tests/SwiftParserTest/AttributeTests.swift index 6295daa112c..8e09c1c5464 100644 --- a/Tests/SwiftParserTest/AttributeTests.swift +++ b/Tests/SwiftParserTest/AttributeTests.swift @@ -629,4 +629,20 @@ final class AttributeTests: XCTestCase { """ ) } + + func testAttachedExtensionAttribute() { + assertParse( + """ + @attached(extension) + macro m() + """ + ) + + assertParse( + """ + @attached(extension, names: named(test)) + macro m() + """ + ) + } } diff --git a/Tests/SwiftSyntaxMacrosTest/MacroSystemTests.swift b/Tests/SwiftSyntaxMacrosTest/MacroSystemTests.swift index 0a467e8f534..dda1cf940ca 100644 --- a/Tests/SwiftSyntaxMacrosTest/MacroSystemTests.swift +++ b/Tests/SwiftSyntaxMacrosTest/MacroSystemTests.swift @@ -682,6 +682,26 @@ public struct SendableConformanceMacro: ConformanceMacro { } } +public struct SendableExtensionMacro: ExtensionMacro { + public static func expansion( + of node: AttributeSyntax, + attachedTo: some DeclGroupSyntax, + providingExtensionsOf type: some TypeSyntaxProtocol, + in context: some MacroExpansionContext + ) throws -> [ExtensionDeclSyntax] { + let sendableExtension: DeclSyntax = + """ + extension \(type.trimmed): Sendable {} + """ + + guard let extensionDecl = sendableExtension.as(ExtensionDeclSyntax.self) else { + return [] + } + + return [extensionDecl] + } +} + public struct DeclsFromStringsMacroNoAttrs: DeclarationMacro { public static var propagateFreestandingMacroAttributes: Bool { false } public static var propagateFreestandingMacroModifiers: Bool { false } @@ -726,6 +746,7 @@ public let testMacros: [String: Macro.Type] = [ "customTypeWrapper": CustomTypeWrapperMacro.self, "unwrap": UnwrapMacro.self, "AddSendable": SendableConformanceMacro.self, + "AddSendableExtension": SendableExtensionMacro.self, ] final class MacroSystemTests: XCTestCase { @@ -1196,4 +1217,23 @@ final class MacroSystemTests: XCTestCase { indentationWidth: indentationWidth ) } + + func testExtensionExpansion() { + assertMacroExpansion( + """ + @AddSendableExtension + struct MyType { + } + """, + expandedSource: """ + + struct MyType { + } + extension MyType: Sendable { + } + """, + macros: testMacros, + indentationWidth: indentationWidth + ) + } } diff --git a/lit_tests/compiler_plugin_basic.swift b/lit_tests/compiler_plugin_basic.swift index bb7c4200793..4799c4c7dba 100644 --- a/lit_tests/compiler_plugin_basic.swift +++ b/lit_tests/compiler_plugin_basic.swift @@ -4,6 +4,7 @@ // // RUN: %swift-frontend -typecheck -verify -swift-version 5 \ // RUN: -enable-experimental-feature CodeItemMacros \ +// RUN: -enable-experimental-feature ExtensionMacros \ // RUN: -dump-macro-expansions \ // RUN: -load-plugin-executable %examples_bin_path/ExamplePlugin#ExamplePlugin \ // RUN: -parse-as-library \