Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 129 additions & 44 deletions lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import SwiftSyntax
import SwiftSyntaxBuilder
import SwiftSyntaxMacros

// Disable emitting 'MutableSpan' until it has landed
let enableMutableSpan = false

// avoids depending on SwiftifyImport.swift
// all instances are reparsed and reinstantiated by the macro anyways,
// so linking is irrelevant
Expand Down Expand Up @@ -213,22 +216,26 @@ func replaceBaseType(_ type: TypeSyntax, _ base: TypeSyntax) -> TypeSyntax {

// C++ type qualifiers, `const T` and `volatile T`, are encoded as fake generic
// types, `__cxxConst<T>` and `__cxxVolatile<T>` respectively. Remove those.
func dropQualifierGenerics(_ type: TypeSyntax) -> TypeSyntax {
guard let identifier = type.as(IdentifierTypeSyntax.self) else { return type }
guard let generic = identifier.genericArgumentClause else { return type }
guard let genericArg = generic.arguments.first else { return type }
guard case .type(let argType) = genericArg.argument else { return type }
// Second return value is true if __cxxConst was stripped.
func dropQualifierGenerics(_ type: TypeSyntax) -> (TypeSyntax, Bool) {
guard let identifier = type.as(IdentifierTypeSyntax.self) else { return (type, false) }
guard let generic = identifier.genericArgumentClause else { return (type, false) }
guard let genericArg = generic.arguments.first else { return (type, false) }
guard case .type(let argType) = genericArg.argument else { return (type, false) }
switch identifier.name.text {
case "__cxxConst", "__cxxVolatile":
case "__cxxConst":
let (retType, _) = dropQualifierGenerics(argType)
return (retType, true)
case "__cxxVolatile":
return dropQualifierGenerics(argType)
default:
return type
return (type, false)
}
}

// The generated type names for template instantiations sometimes contain
// encoded qualifiers for disambiguation purposes. We need to remove those.
func dropCxxQualifiers(_ type: TypeSyntax) -> TypeSyntax {
func dropCxxQualifiers(_ type: TypeSyntax) -> (TypeSyntax, Bool) {
if let attributed = type.as(AttributedTypeSyntax.self) {
return dropCxxQualifiers(attributed.baseType)
}
Expand Down Expand Up @@ -272,12 +279,20 @@ func getUnqualifiedStdName(_ type: String) -> String? {
func getSafePointerName(mut: Mutability, generateSpan: Bool, isRaw: Bool) -> TokenSyntax {
switch (mut, generateSpan, isRaw) {
case (.Immutable, true, true): return "RawSpan"
case (.Mutable, true, true): return "MutableRawSpan"
case (.Mutable, true, true): return if enableMutableSpan {
"MutableRawSpan"
} else {
"RawSpan"
}
case (.Immutable, false, true): return "UnsafeRawBufferPointer"
case (.Mutable, false, true): return "UnsafeMutableRawBufferPointer"

case (.Immutable, true, false): return "Span"
case (.Mutable, true, false): return "MutableSpan"
case (.Mutable, true, false): return if enableMutableSpan {
"MutableSpan"
} else {
"Span"
}
case (.Immutable, false, false): return "UnsafeBufferPointer"
case (.Mutable, false, false): return "UnsafeMutableBufferPointer"
}
Expand Down Expand Up @@ -317,6 +332,28 @@ func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool)
return try replaceTypeName(prev, token)
}

func isMutablePointerType(_ type: TypeSyntax) -> Bool {
if let optType = type.as(OptionalTypeSyntax.self) {
return isMutablePointerType(optType.wrappedType)
}
if let impOptType = type.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) {
return isMutablePointerType(impOptType.wrappedType)
}
if let attrType = type.as(AttributedTypeSyntax.self) {
return isMutablePointerType(attrType.baseType)
}
do {
let name = try getTypeName(type)
let text = name.text
guard let kind: Mutability = getPointerMutability(text: text) else {
return false
}
return kind == .Mutable
} catch _ {
return false
}
}

protocol BoundsCheckedThunkBuilder {
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax
func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item]
Expand Down Expand Up @@ -401,7 +438,7 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
}
}

struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
public let base: BoundsCheckedThunkBuilder
public let index: Int
public let signature: FunctionSignatureSyntax
Expand All @@ -417,17 +454,7 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
-> (FunctionSignatureSyntax, Bool) {
var types = argTypes
let typeName = getUnattributedType(oldType).description
guard let desugaredType = typeMappings[typeName] else {
throw DiagnosticError(
"unable to desugar type with name '\(typeName)'", node: node)
}

let parsedDesugaredType = TypeSyntax("\(raw: getUnqualifiedStdName(desugaredType)!)")
let genericArg = TypeSyntax(parsedDesugaredType.as(IdentifierTypeSyntax.self)!
.genericArgumentClause!.arguments.first!.argument)!
types[index] = replaceBaseType(param.type,
TypeSyntax("Span<\(raw: dropCxxQualifiers(genericArg))>"))
types[index] = try newType
return try base.buildFunctionSignature(types, returnType)
}

Expand All @@ -440,44 +467,100 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
}
}

struct CxxSpanReturnThunkBuilder: BoundsCheckedThunkBuilder {
struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
public let base: BoundsCheckedThunkBuilder
public let signature: FunctionSignatureSyntax
public let typeMappings: [String: String]
public let node: SyntaxProtocol

var oldType: TypeSyntax {
return signature.returnClause!.type
}

func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] {
return []
}

func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
-> (FunctionSignatureSyntax, Bool) {
assert(returnType == nil)
let typeName = getUnattributedType(signature.returnClause!.type).description
guard let desugaredType = typeMappings[typeName] else {
throw DiagnosticError(
"unable to desugar type with name '\(typeName)'", node: node)
}
let parsedDesugaredType = TypeSyntax("\(raw: getUnqualifiedStdName(desugaredType)!)")
let genericArg = TypeSyntax(parsedDesugaredType.as(IdentifierTypeSyntax.self)!
.genericArgumentClause!.arguments.first!.argument)!
let newType = replaceBaseType(signature.returnClause!.type,
TypeSyntax("Span<\(raw: dropCxxQualifiers(genericArg))>"))
return try base.buildFunctionSignature(argTypes, newType)
}

func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax {
let call = try base.buildFunctionCall(pointerArgs)
return "_cxxOverrideLifetime(Span(_unsafeCxxSpan: \(call)), copying: ())"
let (_, isConst) = dropCxxQualifiers(try genericArg)
let cast = if isConst || !enableMutableSpan {
"Span"
} else {
"MutableSpan"
}
return "_cxxOverrideLifetime(\(raw: cast)(_unsafeCxxSpan: \(call)), copying: ())"
}
}

protocol PointerBoundsThunkBuilder: BoundsCheckedThunkBuilder {
protocol BoundsThunkBuilder: BoundsCheckedThunkBuilder {
var oldType: TypeSyntax { get }
var newType: TypeSyntax { get throws }
var nullable: Bool { get }
var signature: FunctionSignatureSyntax { get }
var nonescaping: Bool { get }
}

protocol SpanBoundsThunkBuilder: BoundsThunkBuilder {
var typeMappings: [String: String] { get }
var node: SyntaxProtocol { get }
}
extension SpanBoundsThunkBuilder {
var desugaredType: TypeSyntax {
get throws {
let typeName = try getUnattributedType(oldType).description
guard let desugaredTypeName = typeMappings[typeName] else {
throw DiagnosticError(
"unable to desugar type with name '\(typeName)'", node: node)
}
return TypeSyntax("\(raw: getUnqualifiedStdName(desugaredTypeName)!)")
}
}
var genericArg: TypeSyntax {
get throws {
guard let idType = try desugaredType.as(IdentifierTypeSyntax.self) else {
throw DiagnosticError(
"unexpected non-identifier type '\(try desugaredType)', expected a std::span type",
node: try desugaredType)
}
guard let genericArgumentClause = idType.genericArgumentClause else {
throw DiagnosticError(
"missing generic type argument clause expected after \(idType)", node: idType)
}
guard let firstArg = genericArgumentClause.arguments.first else {
throw DiagnosticError(
"expected at least 1 generic type argument for std::span type '\(idType)', found '\(genericArgumentClause)'",
node: genericArgumentClause.arguments)
}
guard let arg = TypeSyntax(firstArg.argument) else {
throw DiagnosticError(
"invalid generic type argument '\(firstArg.argument)'",
node: firstArg.argument)
}
return arg
}
}
var newType: TypeSyntax {
get throws {
let (strippedArg, isConst) = dropCxxQualifiers(try genericArg)
let mutablePrefix = if isConst || !enableMutableSpan {
""
} else {
"Mutable"
}
return replaceBaseType(
oldType,
TypeSyntax("\(raw: mutablePrefix)Span<\(raw: strippedArg)>"))
}
}
}

protocol PointerBoundsThunkBuilder: BoundsThunkBuilder {
var nullable: Bool { get }
var isSizedBy: Bool { get }
var generateSpan: Bool { get }
}
Expand All @@ -490,13 +573,12 @@ extension PointerBoundsThunkBuilder {
}
}

protocol ParamPointerBoundsThunkBuilder: PointerBoundsThunkBuilder {
protocol ParamBoundsThunkBuilder: BoundsThunkBuilder {
var index: Int { get }
var nonescaping: Bool { get }
}

extension ParamPointerBoundsThunkBuilder {
var generateSpan: Bool { nonescaping }

extension ParamBoundsThunkBuilder {
var param: FunctionParameterSyntax {
return getParam(signature, index)
}
Expand All @@ -518,7 +600,7 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
public let isSizedBy: Bool
public let dependencies: [LifetimeDependence]

var generateSpan: Bool { !dependencies.isEmpty }
var generateSpan: Bool { !dependencies.isEmpty && (!isMutablePointerType(oldType) || enableMutableSpan)}

var oldType: TypeSyntax {
return signature.returnClause!.type
Expand All @@ -531,7 +613,7 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
}

func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] {
return []
return try base.buildBoundsChecks()
}

func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax {
Expand All @@ -548,7 +630,8 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
}
}

struct CountedOrSizedPointerThunkBuilder: ParamPointerBoundsThunkBuilder {

struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBoundsThunkBuilder {
public let base: BoundsCheckedThunkBuilder
public let index: Int
public let countExpr: ExprSyntax
Expand All @@ -557,6 +640,8 @@ struct CountedOrSizedPointerThunkBuilder: ParamPointerBoundsThunkBuilder {
public let isSizedBy: Bool
public let skipTrivialCount: Bool

var generateSpan: Bool { nonescaping && (!isMutablePointerType(oldType) || enableMutableSpan) }

func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
-> (FunctionSignatureSyntax, Bool) {
var types = argTypes
Expand Down
16 changes: 9 additions & 7 deletions test/Interop/C/swiftify-import/counted-by-noescape.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// REQUIRES: swift_feature_SafeInteropWrappers
// REQUIRES: swift_feature_LifetimeDependence

// This emits UnsafeMutableBufferPointer until MutableSpan has landed

// RUN: %target-swift-ide-test -print-module -module-to-print=CountedByNoEscapeClang -plugin-path %swift-plugin-dir -I %S/Inputs -source-filename=x -enable-experimental-feature SafeInteropWrappers -enable-experimental-feature LifetimeDependence | %FileCheck %s

// swift-ide-test doesn't currently typecheck the macro expansions, so run the compiler as well
Expand All @@ -11,15 +13,15 @@

import CountedByNoEscapeClang

// CHECK: @_alwaysEmitIntoClient public func complexExpr(_ len: Int{{.*}}, _ offset: Int{{.*}}, _ p: MutableSpan<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func nonnull(_ p: MutableSpan<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func nullUnspecified(_ p: MutableSpan<Int{{.*}}>)
// CHECK: @_alwaysEmitIntoClient public func complexExpr(_ len: Int{{.*}}, _ offset: Int{{.*}}, _ p: UnsafeMutableBufferPointer<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func nonnull(_ p: UnsafeMutableBufferPointer<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func nullUnspecified(_ p: UnsafeMutableBufferPointer<Int{{.*}}>)
// CHECK-NEXT: @lifetime(copy p)
// CHECK-NEXT: @_alwaysEmitIntoClient public func returnLifetimeBound(_ len1: Int32, _ p: MutableSpan<Int32>) -> MutableSpan<Int32>
// CHECK-NEXT: @_alwaysEmitIntoClient public func returnLifetimeBound(_ len1: Int32, _ p: UnsafeMutableBufferPointer<Int32>) -> UnsafeMutableBufferPointer<Int32>
// CHECK-NEXT: @_alwaysEmitIntoClient @_disfavoredOverload public func returnPointer(_ len: Int{{.*}}) -> UnsafeMutableBufferPointer<Int{{.*}}>
// CHECK-NEXT: @_alwaysEmitIntoClient public func shared(_ len: Int{{.*}}, _ p1: MutableSpan<Int{{.*}}>, _ p2: MutableSpan<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func simple(_ p: MutableSpan<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func swiftAttr(_ p: MutableSpan<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func shared(_ len: Int{{.*}}, _ p1: UnsafeMutableBufferPointer<Int{{.*}}>, _ p2: UnsafeMutableBufferPointer<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func simple(_ p: UnsafeMutableBufferPointer<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func swiftAttr(_ p: UnsafeMutableBufferPointer<Int{{.*}}>)

@inlinable
public func callReturnPointer() {
Expand Down
11 changes: 5 additions & 6 deletions test/Macros/SwiftifyImport/CountedBy/MutableSpan.swift
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
// REQUIRES: swift_swift_parser

// RUN: not %target-swift-frontend %s -swift-version 5 -module-name main -disable-availability-checking -typecheck -plugin-path %swift-plugin-dir -strict-memory-safety -warnings-as-errors -dump-macro-expansions > %t.log 2>&1
// RUN: %target-swift-frontend %s -swift-version 5 -module-name main -disable-availability-checking -typecheck -plugin-path %swift-plugin-dir -strict-memory-safety -warnings-as-errors -dump-macro-expansions > %t.log 2>&1
// RUN: %FileCheck --match-full-lines %s < %t.log

@_SwiftifyImport(.countedBy(pointer: .param(1), count: "len"), .nonescaping(pointer: .param(1)))
func myFunc(_ ptr: UnsafeMutablePointer<CInt>, _ len: CInt) {
}

// Emits UnsafeMutableBufferPointer until MutableSpan has landed

// CHECK: @_alwaysEmitIntoClient
// CHECK-NEXT: func myFunc(_ ptr: MutableSpan<CInt>) {
// CHECK-NEXT: return unsafe ptr.withUnsafeBufferPointer { _ptrPtr in
// CHECK-NEXT: return unsafe myFunc(_ptrPtr.baseAddress!, CInt(exactly: ptr.count)!)
// CHECK-NEXT: }
// CHECK-NEXT: func myFunc(_ ptr: UnsafeMutableBufferPointer<CInt>) {
// CHECK-NEXT: return unsafe myFunc(ptr.baseAddress!, CInt(exactly: ptr.count)!)
// CHECK-NEXT: }

10 changes: 5 additions & 5 deletions test/Macros/SwiftifyImport/SizedBy/MutableRawSpan.swift
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
// REQUIRES: swift_swift_parser

// RUN: not %target-swift-frontend %s -swift-version 5 -module-name main -disable-availability-checking -typecheck -plugin-path %swift-plugin-dir -strict-memory-safety -warnings-as-errors -dump-macro-expansions > %t.log 2>&1
// RUN: %target-swift-frontend %s -swift-version 5 -module-name main -disable-availability-checking -typecheck -plugin-path %swift-plugin-dir -strict-memory-safety -warnings-as-errors -dump-macro-expansions > %t.log 2>&1
// RUN: %FileCheck --match-full-lines %s < %t.log

@_SwiftifyImport(.sizedBy(pointer: .param(1), size: "size"), .nonescaping(pointer: .param(1)))
func myFunc(_ ptr: UnsafeMutableRawPointer, _ size: CInt) {
}

// Emits UnsafeMutableRawBufferPointer until MutableRawSpan has landed

// CHECK: @_alwaysEmitIntoClient
// CHECK-NEXT: func myFunc(_ ptr: MutableRawSpan) {
// CHECK-NEXT: return unsafe ptr.withUnsafeBytes { _ptrPtr in
// CHECK-NEXT: return unsafe myFunc(_ptrPtr.baseAddress!, CInt(exactly: ptr.byteCount)!)
// CHECK-NEXT: }
// CHECK-NEXT: func myFunc(_ ptr: UnsafeMutableRawBufferPointer) {
// CHECK-NEXT: return unsafe myFunc(ptr.baseAddress!, CInt(exactly: ptr.count)!)
// CHECK-NEXT: }