Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 8 additions & 26 deletions FlyingSocks/Sources/Socket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ public struct Socket: Sendable, Hashable {
throw SocketError.makeFailed("CreateSocket")
}
self.file = descriptor
if type == .datagram {
try setPktInfo(domain: domain)
}
}

public var socketType: SocketType {
Expand All @@ -121,29 +118,6 @@ public struct Socket: Sendable, Hashable {
}
}

// enable return of ip_pktinfo/ipv6_pktinfo on recvmsg()
private func setPktInfo(domain: Int32) throws {
var enable = Int32(1)
let level: Int32
let name: Int32

switch domain {
case AF_INET:
level = Socket.ipproto_ip
name = Self.ip_pktinfo
case AF_INET6:
level = Socket.ipproto_ipv6
name = Self.ipv6_pktinfo
default:
return
}

let result = Socket.setsockopt(file.rawValue, level, name, &enable, socklen_t(MemoryLayout<Int32>.size))
guard result >= 0 else {
throw SocketError.makeFailed("SetPktInfoOption")
}
}

public func setValue<O: SocketOption>(_ value: O.Value, for option: O) throws {
var value = option.makeSocketValue(from: value)
let result = withUnsafeBytes(of: &value) {
Expand Down Expand Up @@ -573,6 +547,14 @@ public extension SocketOption where Self == BoolSocketOption {
BoolSocketOption(name: SO_REUSEADDR)
}

static var packetInfoIP: Self {
BoolSocketOption(level: Socket.ipproto_ip, name: Socket.ip_pktinfo)
}

static var packetInfoIPv6: Self {
BoolSocketOption(level: Socket.ipproto_ipv6, name: Socket.ipv6_pktinfo)
}

#if canImport(Darwin)
// Prevents SIG_TRAP when app is paused / running in background.
static var noSIGPIPE: Self {
Expand Down
39 changes: 39 additions & 0 deletions FlyingSocks/Tests/SocketTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,45 @@ struct SocketTests {
try Socket.inet_ntop(AF_INET6, &addr.sin6_addr, buffer, maxLength)
}
}

@Test
func makes_datagram_ip4() throws {
#expect(throws: Never.self) {
try Socket(domain: Int32(sa_family_t(AF_INET)), type: .datagram)
}
}

@Test
func makes_datagram_ip6() throws {
#expect(throws: Never.self) {
try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram)
}
}

@Test
func packetInfoIP() throws {
let socket = try Socket(domain: Int32(sa_family_t(AF_INET)), type: .datagram)
#expect(
try socket.getValue(for: .packetInfoIP) == false
)

try socket.setValue(true, for: .packetInfoIP)
#expect(
try socket.getValue(for: .packetInfoIP) == true
)
}

@Test
func packetInfoIPv6() throws {
let socket = try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram)

withKnownIssue("Permission denied error is thrown") {
try socket.setValue(true, for: .packetInfoIPv6)
#expect(
try socket.getValue(for: .packetInfoIPv6) == true
)
}
}
}

extension Socket.Flags {
Expand Down
54 changes: 54 additions & 0 deletions FlyingSocks/XCTests/SocketTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,60 @@ final class SocketTests: XCTestCase {
let buffer = UnsafeMutablePointer<CChar>.allocate(capacity: Int(maxLength))
XCTAssertThrowsError(try Socket.inet_ntop(AF_INET6, &addr.sin6_addr, buffer, maxLength))
}

func testMakes_datagram_ip4() throws {
XCTAssertNoThrow(
try Socket(domain: Int32(sa_family_t(AF_INET)), type: .datagram)
)
}

func testMakes_datagram_ip6() throws {
XCTAssertNoThrow(
try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram)
)
}

func testPacketInfoIP() throws {
let socket = try Socket(domain: Int32(sa_family_t(AF_INET)), type: .datagram)
XCTAssertFalse(
try socket.getValue(for: .packetInfoIP)
)

try socket.setValue(true, for: .packetInfoIP)
XCTAssertTrue(
try socket.getValue(for: .packetInfoIP)
)
}

#if canImport(Darwin)
func testPacketInfoIPv6() throws {
let socket = try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram)
XCTAssertFalse(
try socket.getValue(for: .packetInfoIPv6)
)

try XCTExpectFailure("Permission denied error is thrown") {
try socket.setValue(true, for: .packetInfoIPv6)
XCTAssertTrue(
try socket.getValue(for: .packetInfoIPv6)
)
}
}
#else
// Linux does not support XCTExpectFailure
func testPacketInfoIPv6() throws {
let socket = try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram)
do {
try socket.setValue(true, for: .packetInfoIPv6)
XCTAssertTrue(
try socket.getValue(for: .packetInfoIPv6)
)
XCTFail("Expected Failure")
} catch {
() // expected failure
}
}
#endif
}

extension Socket.Flags {
Expand Down