Skip to content

Commit

Permalink
std.net: enable forcing non-blocking mode for accept
Browse files Browse the repository at this point in the history
Justification: It is common for non-CPU bound short routines to do
non-blocking accept to eliminate unnecessary delays before subscribing
to data, for example in hardware integration tests.
  • Loading branch information
Jan Philipp Hafer authored and Vexu committed Nov 21, 2023
1 parent 40b8c99 commit 27b34a5
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
21 changes: 15 additions & 6 deletions lib/std/net.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1871,6 +1871,7 @@ pub const StreamServer = struct {
kernel_backlog: u31,
reuse_address: bool,
reuse_port: bool,
force_nonblocking: bool,

/// `undefined` until `listen` returns successfully.
listen_address: Address,
Expand All @@ -1888,6 +1889,9 @@ pub const StreamServer = struct {

/// Enable SO.REUSEPORT on the socket.
reuse_port: bool = false,

/// Force non-blocking mode.
force_nonblocking: bool = false,
};

/// After this call succeeds, resources have been acquired and must
Expand All @@ -1898,6 +1902,7 @@ pub const StreamServer = struct {
.kernel_backlog = options.kernel_backlog,
.reuse_address = options.reuse_address,
.reuse_port = options.reuse_port,
.force_nonblocking = options.force_nonblocking,
.listen_address = undefined,
};
}
Expand All @@ -1911,9 +1916,11 @@ pub const StreamServer = struct {
pub fn listen(self: *StreamServer, address: Address) !void {
const nonblock = if (std.io.is_async) os.SOCK.NONBLOCK else 0;
const sock_flags = os.SOCK.STREAM | os.SOCK.CLOEXEC | nonblock;
var use_sock_flags: u32 = sock_flags;
if (self.force_nonblocking) use_sock_flags |= os.SOCK.NONBLOCK;
const proto = if (address.any.family == os.AF.UNIX) @as(u32, 0) else os.IPPROTO.TCP;

const sockfd = try os.socket(address.any.family, sock_flags, proto);
const sockfd = try os.socket(address.any.family, use_sock_flags, proto);
self.sockfd = sockfd;
errdefer {
os.closeSocket(sockfd);
Expand Down Expand Up @@ -1963,15 +1970,18 @@ pub const StreamServer = struct {
/// The system-wide limit on the total number of open files has been reached.
SystemFdQuotaExceeded,

/// Not enough free memory. This often means that the memory allocation is limited
/// by the socket buffer limits, not by the system memory.
/// Not enough free memory. This often means that the memory allocation
/// is limited by the socket buffer limits, not by the system memory.
SystemResources,

/// Socket is not listening for new connections.
SocketNotListening,

ProtocolFailure,

/// Socket is in non-blocking mode and there is no connection to accept.
WouldBlock,

/// Firewall rules forbid connection.
BlockedByFirewall,

Expand Down Expand Up @@ -2007,9 +2017,8 @@ pub const StreamServer = struct {
.stream = Stream{ .handle = fd },
.address = accepted_addr,
};
} else |err| switch (err) {
error.WouldBlock => unreachable,
else => |e| return e,
} else |err| {
return err;
}
}
};
Expand Down
24 changes: 24 additions & 0 deletions lib/std/net/test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,27 @@ fn generateFileName(base_name: []const u8) ![]const u8 {
_ = std.fs.base64_encoder.encode(&sub_path, &random_bytes);
return std.fmt.allocPrint(testing.allocator, "{s}-{s}", .{ sub_path[0..], base_name });
}

test "non-blocking tcp server" {
if (builtin.os.tag == .wasi) return error.SkipZigTest;

const localhost = try net.Address.parseIp("127.0.0.1", 0);
var server = net.StreamServer.init(.{ .force_nonblocking = true });
defer server.deinit();
try server.listen(localhost);

const accept_err = server.accept();
try testing.expectError(error.WouldBlock, accept_err);

const socket_file = try net.tcpConnectToAddress(server.listen_address);
defer socket_file.close();

var client = try server.accept();
const stream = client.stream.writer();
try stream.print("hello from server\n", .{});

var buf: [100]u8 = undefined;
const len = try socket_file.read(&buf);
const msg = buf[0..len];
try testing.expect(mem.eql(u8, msg, "hello from server\n"));
}

0 comments on commit 27b34a5

Please sign in to comment.