Skip to content

Commit

Permalink
Async Serve Command (#3190)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xTim committed May 13, 2024
1 parent d9fa0d3 commit 5bc1dfa
Show file tree
Hide file tree
Showing 20 changed files with 260 additions and 384 deletions.
2 changes: 1 addition & 1 deletion Sources/Vapor/Application.swift
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ public final class Application: Sendable {
self.servers.use(.http)
self.clients.initialize()
self.clients.use(.http)
self.commands.use(self.servers.command, as: "serve", isDefault: true)
self.asyncCommands.use(self.servers.command, as: "serve", isDefault: true)
self.asyncCommands.use(RoutesCommand(), as: "routes")
}

Expand Down
18 changes: 9 additions & 9 deletions Sources/Vapor/Commands/ServeCommand.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import NIOConcurrencyHelpers
/// $ swift run Run serve
/// Server starting on http://localhost:8080
///
public final class ServeCommand: Command, Sendable {
public final class ServeCommand: AsyncCommand, Sendable {
public struct Signature: CommandSignature, Sendable {
@Option(name: "hostname", short: "H", help: "Set the hostname the server will run on.")
var hostname: String?
Expand All @@ -30,10 +30,10 @@ public final class ServeCommand: Command, Sendable {
case incompatibleFlags
}

/// See `Command`.
// See `AsyncCommand`.
public let signature = Signature()

/// See `Command`.
// See `AsyncCommand`.
public var help: String {
return "Begins serving the app over HTTP."
}
Expand All @@ -53,23 +53,23 @@ public final class ServeCommand: Command, Sendable {
self.box = .init(box)
}

/// See `Command`.
public func run(using context: CommandContext, signature: Signature) throws {
// See `AsyncCommand`.
public func run(using context: CommandContext, signature: Signature) async throws {
switch (signature.hostname, signature.port, signature.bind, signature.socketPath) {
case (.none, .none, .none, .none): // use defaults
try context.application.server.start(address: nil)
try await context.application.server.start(address: nil)

case (.none, .none, .none, .some(let socketPath)): // unix socket
try context.application.server.start(address: .unixDomainSocket(path: socketPath))
try await context.application.server.start(address: .unixDomainSocket(path: socketPath))

case (.none, .none, .some(let address), .none): // bind ("hostname:port")
let hostname = address.split(separator: ":").first.flatMap(String.init)
let port = address.split(separator: ":").last.flatMap(String.init).flatMap(Int.init)

try context.application.server.start(address: .hostname(hostname, port: port))
try await context.application.server.start(address: .hostname(hostname, port: port))

case (let hostname, let port, .none, .none): // hostname / port
try context.application.server.start(address: .hostname(hostname, port: port))
try await context.application.server.start(address: .hostname(hostname, port: port))

default: throw Error.incompatibleFlags
}
Expand Down
26 changes: 15 additions & 11 deletions Sources/Vapor/HTTP/Server/HTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -382,17 +382,8 @@ public final class HTTPServer: Server, Sendable {
configuration.address = address!
}

/// Print starting message.
let scheme = configuration.tlsConfiguration == nil ? "http" : "https"
let addressDescription: String
switch configuration.address {
case .hostname(let hostname, let port):
addressDescription = "\(scheme)://\(hostname ?? configuration.hostname):\(port ?? configuration.port)"
case .unixDomainSocket(let socketPath):
addressDescription = "\(scheme)+unix: \(socketPath)"
}

self.configuration.logger.notice("Server starting on \(addressDescription)")
/// Log starting message for debugging before attempting to start the server.
configuration.logger.debug("Server starting on \(configuration.addressDescription)")

/// Start the actual `HTTPServer`.
let serverConnection = try await HTTPServerConnection.start(
Expand All @@ -407,6 +398,19 @@ public final class HTTPServer: Server, Sendable {
precondition($0 == nil, "You can't start the server connection twice")
$0 = serverConnection
}

/// Overwrite configuration with actual address, if applicable.
/// They may differ from the provided configuation if port 0 was provided, for example.
if let localAddress = self.localAddress {
if let hostname = localAddress.hostname, let port = localAddress.port {
configuration.address = .hostname(hostname, port: port)
} else if let pathname = localAddress.pathname {
configuration.address = .unixDomainSocket(path: pathname)
}
}

/// Log started message with the actual configuration.
configuration.logger.notice("Server started on \(configuration.addressDescription)")

self.configuration = configuration
self.didStart.withLockedValue { $0 = true }
Expand Down
32 changes: 32 additions & 0 deletions Sources/Vapor/Server/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@ public protocol Server: Sendable {
/// Start the server with the specified address.
/// - Parameters:
/// - address: The address to start the server with.
@available(*, noasync, message: "Use the async start() method instead.")
func start(address: BindAddress?) throws

/// Start the server with the specified address.
/// - Parameters:
/// - address: The address to start the server with.
func start(address: BindAddress?) async throws

/// Start the server with the specified hostname and port, if provided. If left blank, the server will be started with its default configuration.
/// - Deprecated: Please use `start(address: .hostname(hostname, port: port))` instead.
/// - Parameters:
Expand All @@ -17,7 +23,12 @@ public protocol Server: Sendable {
@available(*, deprecated, renamed: "start(address:)", message: "Please use `start(address: .hostname(hostname, port: port))` instead")
func start(hostname: String?, port: Int?) throws

/// Shut the server down.
@available(*, noasync, message: "Use the async start() method instead.")
func shutdown()

/// Shut the server down.
func shutdown() async
}

public enum BindAddress: Equatable, Sendable {
Expand Down Expand Up @@ -54,6 +65,27 @@ extension Server {
public func start(hostname: String?, port: Int?) throws {
try self.start(address: .hostname(hostname, port: port))
}

/// A default implementation for those servers that haven't migrated yet
@available(*, deprecated, message: "Implement an async version of this yourself")
public func start(address: BindAddress?) async throws {
try self.syncStart(address: address)
}

/// A default implementation for those servers that haven't migrated yet
@available(*, deprecated, message: "Implement an async version of this yourself")
public func shutdown() async {
self.syncShutdown()
}

// Trick the compiler
private func syncStart(address: BindAddress?) throws {
try self.start(address: address)
}

private func syncShutdown() {
self.shutdown()
}
}

/// Errors that may be thrown when starting a server
Expand Down
68 changes: 37 additions & 31 deletions Sources/XCTVapor/XCTApplication.swift
Original file line number Diff line number Diff line change
Expand Up @@ -89,41 +89,47 @@ extension Application {
}

func performTest(request: XCTHTTPRequest) async throws -> XCTHTTPResponse {
try app.server.start(address: .hostname(self.hostname, port: self.port))
defer { app.server.shutdown() }

try await app.server.start(address: .hostname(self.hostname, port: self.port))
let client = HTTPClient(eventLoopGroup: MultiThreadedEventLoopGroup.singleton)
defer { try! client.syncShutdown() }
var path = request.url.path
path = path.hasPrefix("/") ? path : "/\(path)"

let actualPort: Int

if self.port == 0 {
guard let portAllocated = app.http.server.shared.localAddress?.port else {
throw Abort(.internalServerError, reason: "Failed to get port from local address")
do {
var path = request.url.path
path = path.hasPrefix("/") ? path : "/\(path)"

let actualPort: Int

if self.port == 0 {
guard let portAllocated = app.http.server.shared.localAddress?.port else {
throw Abort(.internalServerError, reason: "Failed to get port from local address")
}
actualPort = portAllocated
} else {
actualPort = self.port
}
actualPort = portAllocated
} else {
actualPort = self.port
}

var url = "http://\(self.hostname):\(actualPort)\(path)"
if let query = request.url.query {
url += "?\(query)"

var url = "http://\(self.hostname):\(actualPort)\(path)"
if let query = request.url.query {
url += "?\(query)"
}
var clientRequest = HTTPClientRequest(url: url)
clientRequest.method = request.method
clientRequest.headers = request.headers
clientRequest.body = .bytes(request.body)
let response = try await client.execute(clientRequest, timeout: .seconds(30))
// Collect up to 1MB
let responseBody = try await response.body.collect(upTo: 1024 * 1024)
try await client.shutdown()
await app.server.shutdown()
return XCTHTTPResponse(
status: response.status,
headers: response.headers,
body: responseBody
)
} catch {
try? await client.shutdown()
await app.server.shutdown()
throw error
}
var clientRequest = try HTTPClient.Request(
url: url,
method: request.method,
headers: request.headers
)
clientRequest.body = .byteBuffer(request.body)
let response = try await client.execute(request: clientRequest).get()
return XCTHTTPResponse(
status: response.status,
headers: response.headers,
body: response.body ?? ByteBufferAllocator().buffer(capacity: 0)
)
}
}

Expand Down
28 changes: 12 additions & 16 deletions Tests/VaporTests/AsyncAuthTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@ import Vapor
import XCTest

final class AsyncAuthenticationTests: XCTestCase {

var app: Application!

override func setUp() async throws {
app = try await Application.make(.testing)
}

override func tearDown() async throws {
try await app.asyncShutdown()
}

func testBearerAuthenticator() async throws {
struct Test: Authenticatable {
static func authenticator() -> AsyncAuthenticator {
Expand All @@ -21,9 +32,6 @@ final class AsyncAuthenticationTests: XCTestCase {
}
}

let app = try await Application.make(.testing)
defer { app.shutdown() }

app.routes.grouped([
Test.authenticator(), Test.guardMiddleware()
]).get("test") { req -> String in
Expand Down Expand Up @@ -63,9 +71,6 @@ final class AsyncAuthenticationTests: XCTestCase {
}
}

let app = try await Application.make(.testing)
defer { app.shutdown() }

app.routes.grouped([
Test.authenticator(), Test.guardMiddleware()
]).get("test") { req -> String in
Expand Down Expand Up @@ -103,10 +108,7 @@ final class AsyncAuthenticationTests: XCTestCase {
}
}
}

let app = try await Application.make(.testing)
defer { app.shutdown() }


app.routes.grouped([
Test.authenticator(), Test.guardMiddleware()
]).get("test") { req -> String in
Expand Down Expand Up @@ -142,9 +144,6 @@ final class AsyncAuthenticationTests: XCTestCase {
}
}

let app = try await Application.make(.testing)
defer { app.shutdown() }

let redirectMiddleware = Test.redirectMiddleware { req -> String in
return "/redirect?orig=\(req.url.path)"
}
Expand Down Expand Up @@ -199,9 +198,6 @@ final class AsyncAuthenticationTests: XCTestCase {
}
}

let app = try await Application.make(.testing)
defer { app.shutdown() }

app.routes.grouped([
app.sessions.middleware,
Test.sessionAuthenticator(),
Expand Down
15 changes: 10 additions & 5 deletions Tests/VaporTests/AsyncCacheTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@ import Vapor
import NIOCore

final class AsyncCacheTests: XCTestCase {
var app: Application!

override func setUp() async throws {
app = try await Application.make(.testing)
}

override func tearDown() async throws {
try await app.asyncShutdown()
}

func testInMemoryCache() async throws {
let app = try await Application.make(.testing)
defer { app.shutdown() }

let value1 = try await app.cache.get("foo", as: String.self)
XCTAssertNil(value1)
try await app.cache.set("foo", to: "bar")
Expand All @@ -33,8 +40,6 @@ final class AsyncCacheTests: XCTestCase {
}

func testCustomCache() async throws {
let app = try await Application.make(.testing)
defer { app.shutdown() }
app.caches.use(.foo)
try await app.cache.set("1", to: "2")
let value = try await app.cache.get("foo", as: String.self)
Expand Down

0 comments on commit 5bc1dfa

Please sign in to comment.