Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 108 additions & 34 deletions lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,11 @@ protocol ParamInfo: CustomStringConvertible {
) -> BoundsCheckedThunkBuilder
}

func getParamName(_ param: FunctionParameterSyntax, _ paramIndex: Int) -> TokenSyntax {
let name = param.secondName ?? param.firstName
if name.trimmed.text == "_" {
return "_param\(raw: paramIndex)"
}
return name
}

func tryGetParamName(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> TokenSyntax? {
switch expr {
case .param(let i):
let funcParam = getParam(funcDecl, i - 1)
return getParamName(funcParam, i - 1)
return funcParam.name
case .`self`:
return .keyword(.self)
default: return nil
Expand Down Expand Up @@ -427,12 +419,7 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
// filter out deleted parameters, i.e. ones where argTypes[i] _contains_ nil
return type == nil || type! != nil
}.map { (i: Int, e: FunctionParameterSyntax) in
let param = e.with(\.type, (argTypes[i] ?? e.type)!)
let name = param.secondName ?? param.firstName
if name.trimmed.text == "_" {
return param.with(\.secondName, getParamName(param, i))
}
return param
e.with(\.type, (argTypes[i] ?? e.type)!)
}
if let last = newParams.popLast() {
newParams.append(last.with(\.trailingComma, nil))
Expand All @@ -450,9 +437,7 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
let functionRef = DeclReferenceExprSyntax(baseName: base.name)
let args: [ExprSyntax] = base.signature.parameterClause.parameters.enumerated()
.map { (i: Int, param: FunctionParameterSyntax) in
let name = getParamName(param, i)
let declref = DeclReferenceExprSyntax(baseName: name)
return pointerArgs[i] ?? ExprSyntax(declref)
return pointerArgs[i] ?? ExprSyntax("\(param.name)")
}
let labels: [TokenSyntax?] = base.signature.parameterClause.parameters.map { param in
let firstName = param.firstName.trimmed
Expand All @@ -468,7 +453,8 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
comma = .commaToken()
}
let colon: TokenSyntax? = label != nil ? .colonToken() : nil
return LabeledExprSyntax(label: label, colon: colon, expression: arg, trailingComma: comma)
// The compiler emits warnings if you unnecessarily escape labels in function calls
return LabeledExprSyntax(label: label?.withoutBackticks, colon: colon, expression: arg, trailingComma: comma)
}
let call = ExprSyntax(
FunctionCallExprSyntax(
Expand Down Expand Up @@ -510,7 +496,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
args[index] = ExprSyntax("\(raw: typeName)(\(raw: name))")
return try base.buildFunctionCall(args)
} else {
let unwrappedName = TokenSyntax("_\(name)Ptr")
let unwrappedName = TokenSyntax("_\(name.withoutBackticks)Ptr")
args[index] = ExprSyntax("\(raw: typeName)(\(unwrappedName))")
let call = try base.buildFunctionCall(args)

Expand Down Expand Up @@ -663,7 +649,7 @@ extension ParamBoundsThunkBuilder {
}

var name: TokenSyntax {
getParamName(param, index)
param.name
}
}

Expand Down Expand Up @@ -796,7 +782,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
}

func buildUnwrapCall(_ argOverrides: [Int: ExprSyntax]) throws -> ExprSyntax {
let unwrappedName = TokenSyntax("_\(name)Ptr")
let unwrappedName = TokenSyntax("_\(name.withoutBackticks)Ptr").escapeIfNeeded
var args = argOverrides
let argExpr = ExprSyntax("\(unwrappedName).baseAddress")
assert(args[index] == nil)
Expand All @@ -809,7 +795,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
}
}
let call = try base.buildFunctionCall(args)
let ptrRef = unwrapIfNullable(ExprSyntax(DeclReferenceExprSyntax(baseName: name)))
let ptrRef = unwrapIfNullable("\(name)")

let funcName =
switch (isSizedBy, isMutablePointerType(oldType)) {
Expand Down Expand Up @@ -1004,7 +990,7 @@ func parseSwiftifyExpr(_ expr: ExprSyntax) throws -> SwiftifyExpr {
}

func parseCountedByEnum(
_ enumConstructorExpr: FunctionCallExprSyntax, _ signature: FunctionSignatureSyntax
_ enumConstructorExpr: FunctionCallExprSyntax, _ signature: FunctionSignatureSyntax, _ rewriter: CountExprRewriter
) throws -> ParamInfo {
let argumentList = enumConstructorExpr.arguments
let pointerExprArg = try getArgumentByName(argumentList, "pointer")
Expand All @@ -1015,7 +1001,8 @@ func parseCountedByEnum(
"expected string literal for 'count' parameter, got \(countExprArg)", node: countExprArg)
}
let unwrappedCountExpr = ExprSyntax(stringLiteral: countExprStringLit.representedLiteralValue!)
if let countVar = unwrappedCountExpr.as(DeclReferenceExprSyntax.self) {
let rewrittenCountExpr = rewriter.visit(unwrappedCountExpr)
if let countVar = rewrittenCountExpr.as(DeclReferenceExprSyntax.self) {
// Perform this lookup here so we can override the position to point to the string literal
// instead of line 1, column 1
do {
Expand All @@ -1025,11 +1012,11 @@ func parseCountedByEnum(
}
}
return CountedBy(
pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: false,
pointerIndex: pointerExpr, count: rewrittenCountExpr, sizedBy: false,
nonescaping: false, dependencies: [], original: ExprSyntax(enumConstructorExpr))
}

func parseSizedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> ParamInfo {
func parseSizedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax, _ rewriter: CountExprRewriter) throws -> ParamInfo {
let argumentList = enumConstructorExpr.arguments
let pointerExprArg = try getArgumentByName(argumentList, "pointer")
let pointerExpr: SwiftifyExpr = try parseSwiftifyExpr(pointerExprArg)
Expand All @@ -1039,8 +1026,9 @@ func parseSizedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> P
"expected string literal for 'size' parameter, got \(sizeExprArg)", node: sizeExprArg)
}
let unwrappedCountExpr = ExprSyntax(stringLiteral: sizeExprStringLit.representedLiteralValue!)
let rewrittenCountExpr = rewriter.visit(unwrappedCountExpr)
return CountedBy(
pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: true, nonescaping: false,
pointerIndex: pointerExpr, count: rewrittenCountExpr, sizedBy: true, nonescaping: false,
dependencies: [], original: ExprSyntax(enumConstructorExpr))
}

Expand Down Expand Up @@ -1177,7 +1165,7 @@ func parseCxxSpansInSignature(
}

func parseMacroParam(
_ paramAST: LabeledExprSyntax, _ signature: FunctionSignatureSyntax,
_ paramAST: LabeledExprSyntax, _ signature: FunctionSignatureSyntax, _ rewriter: CountExprRewriter,
nonescapingPointers: inout Set<Int>,
lifetimeDependencies: inout [SwiftifyExpr: [LifetimeDependence]]
) throws -> ParamInfo? {
Expand All @@ -1188,8 +1176,8 @@ func parseMacroParam(
}
let enumName = try parseEnumName(paramExpr)
switch enumName {
case "countedBy": return try parseCountedByEnum(enumConstructorExpr, signature)
case "sizedBy": return try parseSizedByEnum(enumConstructorExpr)
case "countedBy": return try parseCountedByEnum(enumConstructorExpr, signature, rewriter)
case "sizedBy": return try parseSizedByEnum(enumConstructorExpr, rewriter)
case "endedBy": return try parseEndedByEnum(enumConstructorExpr)
case "nonescaping":
let index = try parseNonEscaping(enumConstructorExpr)
Expand Down Expand Up @@ -1438,7 +1426,7 @@ func paramLifetimeAttributes(
if !isMutableSpan(param.type) {
continue
}
let paramName = param.secondName ?? param.firstName
let paramName = param.name
if containsLifetimeAttr(oldAttrs, for: paramName) {
continue
}
Expand All @@ -1456,6 +1444,61 @@ func paramLifetimeAttributes(
return defaultLifetimes
}

class CountExprRewriter: SyntaxRewriter {
public let nameMap: [String: String]

init(_ renamedParams: [String: String]) {
nameMap = renamedParams
}

override func visit(_ node: DeclReferenceExprSyntax) -> ExprSyntax {
if let newName = nameMap[node.baseName.trimmed.text] {
return ExprSyntax(
node.with(
\.baseName,
.identifier(
newName, leadingTrivia: node.baseName.leadingTrivia,
trailingTrivia: node.baseName.trailingTrivia)))
}
return escapeIfNeeded(node)
}
}

func renameParameterNamesIfNeeded(_ funcDecl: FunctionDeclSyntax) -> (FunctionDeclSyntax, CountExprRewriter) {
let params = funcDecl.signature.parameterClause.parameters
let funcName = funcDecl.name.withoutBackticks.trimmed.text
let shouldRename = params.contains(where: { param in
let paramName = param.name.trimmed.text
return paramName == "_" || paramName == funcName || "`\(paramName)`" == funcName
})
var renamedParams: [String: String] = [:]
let newParams = params.enumerated().map { (i, param) in
let secondName = if shouldRename {
// Including funcName in name prevents clash with function name.
// Renaming all parameters if one requires renaming guarantees that other parameters don't clash with the renamed one.
TokenSyntax("_\(raw: funcName)_param\(raw: i)")
} else {
param.secondName?.escapeIfNeeded
}
let firstName = param.firstName.escapeIfNeeded
let newParam = param.with(\.secondName, secondName)
.with(\.firstName, firstName)
let newName = newParam.name.trimmed.text
let oldName = param.name.trimmed.text
if newName != oldName {
renamedParams[oldName] = newName
}
return newParam
}
let newDecl = if renamedParams.count > 0 {
funcDecl.with(\.signature.parameterClause.parameters, FunctionParameterListSyntax(newParams))
} else {
// Keeps source locations for diagnostics, in the common case where nothing was renamed
funcDecl
}
return (newDecl, CountExprRewriter(renamedParams))
}

/// A macro that adds safe(r) wrappers for functions with unsafe pointer types.
/// Depends on bounds, escapability and lifetime information for each pointer.
/// Intended to map to C attributes like __counted_by, __ended_by and __no_escape,
Expand All @@ -1469,9 +1512,10 @@ public struct SwiftifyImportMacro: PeerMacro {
in context: some MacroExpansionContext
) throws -> [DeclSyntax] {
do {
guard let funcDecl = declaration.as(FunctionDeclSyntax.self) else {
guard let origFuncDecl = declaration.as(FunctionDeclSyntax.self) else {
throw DiagnosticError("@_SwiftifyImport only works on functions", node: declaration)
}
let (funcDecl, rewriter) = renameParameterNamesIfNeeded(origFuncDecl)

let argumentList = node.arguments!.as(LabeledExprListSyntax.self)!
var arguments = [LabeledExprSyntax](argumentList)
Expand All @@ -1487,7 +1531,7 @@ public struct SwiftifyImportMacro: PeerMacro {
var lifetimeDependencies: [SwiftifyExpr: [LifetimeDependence]] = [:]
var parsedArgs = try arguments.compactMap {
try parseMacroParam(
$0, funcDecl.signature, nonescapingPointers: &nonescapingPointers,
$0, funcDecl.signature, rewriter, nonescapingPointers: &nonescapingPointers,
lifetimeDependencies: &lifetimeDependencies)
}
parsedArgs.append(contentsOf: try parseCxxSpansInSignature(funcDecl.signature, typeMappings))
Expand Down Expand Up @@ -1627,3 +1671,33 @@ extension TypeSyntaxProtocol {
return false
}
}

extension FunctionParameterSyntax {
var name: TokenSyntax {
self.secondName ?? self.firstName
}
}

extension TokenSyntax {
public var withoutBackticks: TokenSyntax {
return .identifier(self.identifier!.name)
}

public var escapeIfNeeded: TokenSyntax {
var parser = Parser("let \(self)")
let decl = DeclSyntax.parse(from: &parser)
if !decl.hasError {
return self
} else {
return self.copyTrivia(to: "`\(raw: self.trimmed.text)`")
}
}

public func copyTrivia(to other: TokenSyntax) -> TokenSyntax {
return .identifier(other.text, leadingTrivia: self.leadingTrivia, trailingTrivia: self.trailingTrivia)
}
}

func escapeIfNeeded(_ identifier: DeclReferenceExprSyntax) -> ExprSyntax {
return "\(identifier.baseName.escapeIfNeeded)"
}
55 changes: 55 additions & 0 deletions test/Interop/C/swiftify-import/Inputs/counted-by-noescape.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,58 @@ int * __counted_by(len) __noescape returnPointer(int len);
int * __counted_by(len1) returnLifetimeBound(int len1, int len2, int * __counted_by(len2) p __lifetimebound);

void anonymous(int len, int * __counted_by(len) _Nullable __noescape);

void keyword(int len, int * __counted_by(len) _Nullable func __noescape,
int extension,
int init,
int open,
int var,
int is,
int as,
int in,
int guard,
int where
);

void pointerName(int len, int * __counted_by(len) _Nullable pointerName __noescape);

void lenName(int lenName, int size, int * __counted_by(lenName * size) _Nullable p __noescape);

void func(int len, int * __counted_by(len) _Nullable func __noescape);

void *funcRenameKeyword(int len, int * __counted_by(len) _Nullable func __noescape,
int extension __lifetimebound,
int init,
int open,
int var,
int is,
int as,
int in,
int guard,
int where) __attribute__((swift_name("funcRenamed(len:func:extension:init:open:var:is:as:in:guard:where:)")));

void *funcRenameKeywordAnonymous(int len, int * __counted_by(len) _Nullable __noescape,
int __lifetimebound,
int,
int,
int,
int,
int,
int,
int,
int) __attribute__((swift_name("funcRenamedAnon(len:func:extension:init:open:var:is:as:in:guard:where:)")));

void funcRenameClash(int len, int * __counted_by(len) _Nullable func __noescape, int where)
__attribute__((swift_name("clash(len:func:clash:)")));

void funcRenameClashKeyword(int len, int * __counted_by(len) _Nullable func __noescape, int where)
__attribute__((swift_name("open(len:func:open:)")));

void funcRenameClashAnonymous(int len, int * __counted_by(len) _Nullable func __noescape, int)
__attribute__((swift_name("clash2(len:func:clash2:)")));

void funcRenameClashKeywordAnonymous(int len, int * __counted_by(len) _Nullable func __noescape, int)
__attribute__((swift_name("in(len:func:in:)")));

typedef struct actor_ *actor;
actor _Nonnull keywordType(int len, actor * __counted_by(len) __noescape p, actor _Nonnull p2);
Loading