diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index 2b56d02e..58dbd473 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -123,25 +123,14 @@ public struct Socket: Sendable, Hashable { // enable return of ip_pktinfo/ipv6_pktinfo on recvmsg() private func setPktInfo(domain: Int32) throws { - var enable = Int32(1) - let level: Int32 - let name: Int32 - switch domain { case AF_INET: - level = Socket.ipproto_ip - name = Self.ip_pktinfo + try setValue(true, for: .packetInfoIP) case AF_INET6: - level = Socket.ipproto_ipv6 - name = Self.ipv6_recvpktinfo + try setValue(true, for: .packetInfoIPv6) default: return } - - let result = Socket.setsockopt(file.rawValue, level, name, &enable, socklen_t(MemoryLayout.size)) - guard result >= 0 else { - throw SocketError.makeFailed("SetPktInfoOption") - } } public func setValue(_ value: O.Value, for option: O) throws { @@ -573,6 +562,14 @@ public extension SocketOption where Self == BoolSocketOption { BoolSocketOption(name: SO_REUSEADDR) } + static var packetInfoIP: Self { + BoolSocketOption(level: Socket.ipproto_ip, name: Socket.ip_pktinfo) + } + + static var packetInfoIPv6: Self { + BoolSocketOption(level: Socket.ipproto_ipv6, name: Socket.ipv6_recvpktinfo) + } + #if canImport(Darwin) // Prevents SIG_TRAP when app is paused / running in background. static var noSIGPIPE: Self { diff --git a/FlyingSocks/Tests/SocketTests.swift b/FlyingSocks/Tests/SocketTests.swift index 148e8e69..218ab923 100644 --- a/FlyingSocks/Tests/SocketTests.swift +++ b/FlyingSocks/Tests/SocketTests.swift @@ -321,6 +321,24 @@ struct SocketTests { try Socket.inet_ntop(AF_INET6, &addr.sin6_addr, buffer, maxLength) } } + + @Test + func makes_datagram_ip4() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET)), type: .datagram) + + #expect( + try socket.getValue(for: .packetInfoIP) == true + ) + } + + @Test + func makes_datagram_ip6() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram) + + #expect( + try socket.getValue(for: .packetInfoIPv6) == true + ) + } } extension Socket.Flags { diff --git a/FlyingSocks/XCTests/SocketTests.swift b/FlyingSocks/XCTests/SocketTests.swift index e555e459..e4962720 100644 --- a/FlyingSocks/XCTests/SocketTests.swift +++ b/FlyingSocks/XCTests/SocketTests.swift @@ -258,6 +258,20 @@ final class SocketTests: XCTestCase { let buffer = UnsafeMutablePointer.allocate(capacity: Int(maxLength)) XCTAssertThrowsError(try Socket.inet_ntop(AF_INET6, &addr.sin6_addr, buffer, maxLength)) } + + func testMakes_datagram_ip4() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET)), type: .datagram) + XCTAssertTrue( + try socket.getValue(for: .packetInfoIP) + ) + } + + func testMakes_datagram_ip6() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram) + XCTAssertTrue( + try socket.getValue(for: .packetInfoIPv6) + ) + } } extension Socket.Flags {