diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index 48ee2407..840394bd 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -94,9 +94,6 @@ public struct Socket: Sendable, Hashable { throw SocketError.makeFailed("CreateSocket") } self.file = descriptor - if type == .datagram { - try setPktInfo(domain: domain) - } } public var socketType: SocketType { @@ -121,29 +118,6 @@ public struct Socket: Sendable, Hashable { } } - // enable return of ip_pktinfo/ipv6_pktinfo on recvmsg() - private func setPktInfo(domain: Int32) throws { - var enable = Int32(1) - let level: Int32 - let name: Int32 - - switch domain { - case AF_INET: - level = Socket.ipproto_ip - name = Self.ip_pktinfo - case AF_INET6: - level = Socket.ipproto_ipv6 - name = Self.ipv6_pktinfo - default: - return - } - - let result = Socket.setsockopt(file.rawValue, level, name, &enable, socklen_t(MemoryLayout.size)) - guard result >= 0 else { - throw SocketError.makeFailed("SetPktInfoOption") - } - } - public func setValue(_ value: O.Value, for option: O) throws { var value = option.makeSocketValue(from: value) let result = withUnsafeBytes(of: &value) { @@ -573,6 +547,14 @@ public extension SocketOption where Self == BoolSocketOption { BoolSocketOption(name: SO_REUSEADDR) } + static var packetInfoIP: Self { + BoolSocketOption(level: Socket.ipproto_ip, name: Socket.ip_pktinfo) + } + + static var packetInfoIPv6: Self { + BoolSocketOption(level: Socket.ipproto_ipv6, name: Socket.ipv6_pktinfo) + } + #if canImport(Darwin) // Prevents SIG_TRAP when app is paused / running in background. static var noSIGPIPE: Self { diff --git a/FlyingSocks/Tests/SocketTests.swift b/FlyingSocks/Tests/SocketTests.swift index 148e8e69..83163ff3 100644 --- a/FlyingSocks/Tests/SocketTests.swift +++ b/FlyingSocks/Tests/SocketTests.swift @@ -321,6 +321,45 @@ struct SocketTests { try Socket.inet_ntop(AF_INET6, &addr.sin6_addr, buffer, maxLength) } } + + @Test + func makes_datagram_ip4() throws { + #expect(throws: Never.self) { + try Socket(domain: Int32(sa_family_t(AF_INET)), type: .datagram) + } + } + + @Test + func makes_datagram_ip6() throws { + #expect(throws: Never.self) { + try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram) + } + } + + @Test + func packetInfoIP() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET)), type: .datagram) + #expect( + try socket.getValue(for: .packetInfoIP) == false + ) + + try socket.setValue(true, for: .packetInfoIP) + #expect( + try socket.getValue(for: .packetInfoIP) == true + ) + } + + @Test + func packetInfoIPv6() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram) + + withKnownIssue("Permission denied error is thrown") { + try socket.setValue(true, for: .packetInfoIPv6) + #expect( + try socket.getValue(for: .packetInfoIPv6) == true + ) + } + } } extension Socket.Flags { diff --git a/FlyingSocks/XCTests/SocketTests.swift b/FlyingSocks/XCTests/SocketTests.swift index e555e459..4e11cd65 100644 --- a/FlyingSocks/XCTests/SocketTests.swift +++ b/FlyingSocks/XCTests/SocketTests.swift @@ -258,6 +258,60 @@ final class SocketTests: XCTestCase { let buffer = UnsafeMutablePointer.allocate(capacity: Int(maxLength)) XCTAssertThrowsError(try Socket.inet_ntop(AF_INET6, &addr.sin6_addr, buffer, maxLength)) } + + func testMakes_datagram_ip4() throws { + XCTAssertNoThrow( + try Socket(domain: Int32(sa_family_t(AF_INET)), type: .datagram) + ) + } + + func testMakes_datagram_ip6() throws { + XCTAssertNoThrow( + try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram) + ) + } + + func testPacketInfoIP() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET)), type: .datagram) + XCTAssertFalse( + try socket.getValue(for: .packetInfoIP) + ) + + try socket.setValue(true, for: .packetInfoIP) + XCTAssertTrue( + try socket.getValue(for: .packetInfoIP) + ) + } + + #if canImport(Darwin) + func testPacketInfoIPv6() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram) + XCTAssertFalse( + try socket.getValue(for: .packetInfoIPv6) + ) + + try XCTExpectFailure("Permission denied error is thrown") { + try socket.setValue(true, for: .packetInfoIPv6) + XCTAssertTrue( + try socket.getValue(for: .packetInfoIPv6) + ) + } + } + #else + // Linux does not support XCTExpectFailure + func testPacketInfoIPv6() throws { + let socket = try Socket(domain: Int32(sa_family_t(AF_INET6)), type: .datagram) + do { + try socket.setValue(true, for: .packetInfoIPv6) + XCTAssertTrue( + try socket.getValue(for: .packetInfoIPv6) + ) + XCTFail("Expected Failure") + } catch { + () // expected failure + } + } + #endif } extension Socket.Flags {