Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Services.global #2062

Merged
merged 4 commits into from Oct 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 39 additions & 6 deletions Sources/Development/configure.swift
@@ -1,24 +1,57 @@
import Vapor

public func configure(_ s: inout Services) throws {
public func configure(_ s: inout Services) {
s.extend(Routes.self) { r, c in
try routes(r, c)
}

s.global(MemoryCache.self) { _ in
return .init()
}

s.register(HTTPServer.Configuration.self) { c in
switch c.environment {
case .tls:
return .init(hostname: "127.0.0.1", port: 8443, tlsConfiguration: tls)
return try .init(
hostname: "127.0.0.1",
port: 8443,
tlsConfiguration: .forServer(
certificateChain: [
.certificate(.init(
file: "/Users/tanner0101/dev/vapor/net-kit/certs/cert.pem",
format: .pem
))
],
privateKey: .file("/Users/tanner0101/dev/vapor/net-kit/certs/key.pem")
)
)
default:
return .init(hostname: "127.0.0.1", port: 8080)
}
}
}

let tls = TLSConfiguration.forServer(
certificateChain: [.file("/Users/tanner0101/dev/vapor/net-kit/certs/cert.pem")],
privateKey: .file("/Users/tanner0101/dev/vapor/net-kit/certs/key.pem")
)
final class MemoryCache {
var storage: [String: String]
var lock: Lock

init() {
self.storage = [:]
self.lock = .init()
}

func get(_ key: String) -> String? {
self.lock.lock()
defer { self.lock.unlock() }
return self.storage[key]
}

func set(_ key: String, to value: String?) {
self.lock.lock()
defer { self.lock.unlock() }
self.storage[key] = value
}
}

extension Environment {
static var tls: Environment {
Expand Down
20 changes: 19 additions & 1 deletion Sources/Development/routes.swift
Expand Up @@ -82,13 +82,31 @@ public func routes(_ r: Routes, _ c: Container) throws {
}

r.get("shutdown") { req -> HTTPStatus in
guard let running = try c.make(Application.self).running else {
guard let running = c.application.running else {
throw Abort(.internalServerError)
}
_ = running.stop()
return .ok
}

let cache = try c.make(MemoryCache.self)
r.get("cache", "get", ":key") { req -> String in
guard let key = req.parameters.get("key") else {
throw Abort(.internalServerError)
}
return "\(key) = \(cache.get(key) ?? "nil")"
}
r.get("cache", "set", ":key", ":value") { req -> String in
guard let key = req.parameters.get("key") else {
throw Abort(.internalServerError)
}
guard let value = req.parameters.get("value") else {
throw Abort(.internalServerError)
}
cache.set(key, to: value)
return "\(key) = \(value)"
}

r.get("hello", ":name") { req in
return req.parameters.get("name") ?? "<nil>"
}
Expand Down
47 changes: 15 additions & 32 deletions Sources/Vapor/Application.swift
Expand Up @@ -11,7 +11,7 @@ public final class Application {

private let configure: (inout Services) throws -> ()

private let threadPool: NIOThreadPool
public let threadPool: NIOThreadPool

private var didShutdown: Bool

Expand All @@ -32,7 +32,9 @@ public final class Application {

public var logger: Logger

private var _services: Services!
public var services: Services

internal var cache: ServiceCache

public struct Running {
public var onStop: EventLoopFuture<Void>
Expand All @@ -46,7 +48,7 @@ public final class Application {

public init(
environment: Environment = .development,
configure: @escaping (inout Services) throws -> () = { _ in }
configure: @escaping (inout Services) -> () = { _ in }
) {
self.environment = environment
self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
Expand All @@ -57,43 +59,25 @@ public final class Application {
self.threadPool = .init(numberOfThreads: 1)
self.threadPool.start()
self.logger = .init(label: "codes.vapor.application")
var services = Services.default()
configure(&services)
self.services = services
self.cache = .init()
}

public func makeServices() throws -> Services {
var s = Services.default()
try self.configure(&s)
s.register(Application.self) { c in
return self
}
s.register(NIOThreadPool.self) { c in
return self.threadPool
}
return s
}


public func makeContainer() -> EventLoopFuture<Container> {
return self.makeContainer(on: self.eventLoopGroup.next())
}

public func makeContainer(on eventLoop: EventLoop) -> EventLoopFuture<Container> {
do {
return try _makeContainer(on: eventLoop)
} catch {
return self.eventLoopGroup.next().makeFailedFuture(error)
}
}

private func _makeContainer(on eventLoop: EventLoop) throws -> EventLoopFuture<Container> {
let s = try self.makeServices()
return Container.boot(environment: self.environment, services: s, on: eventLoop)
return Container.boot(application: self, on: eventLoop)
}

// MARK: Run

public func boot() throws {
self._services = try self.makeServices()
try self._services.providers.forEach { try $0.willBoot(self) }
try self._services.providers.forEach { try $0.didBoot(self) }
try self.services.providers.forEach { try $0.willBoot(self) }
try self.services.providers.forEach { try $0.didBoot(self) }
}

public func run() throws {
Expand Down Expand Up @@ -133,9 +117,8 @@ public final class Application {

public func shutdown() {
self.logger.debug("Application shutting down")
if self._services != nil {
self._services.providers.forEach { $0.willShutdown(self) }
}
self.services.providers.forEach { $0.willShutdown(self) }
self.cache.shutdown()
do {
try self.eventLoopGroup.syncShutdownGracefully()
} catch {
Expand Down
64 changes: 36 additions & 28 deletions Sources/Vapor/Server/HTTPServer.swift
Expand Up @@ -52,7 +52,7 @@ public final class HTTPServer: Server {
public var serverName: String?

/// Any uncaught server or responder errors will go here.
public var errorHandler: (Error) -> ()
public var logger: Logger

/// Creates a new `HTTPServerConfig`.
///
Expand All @@ -71,7 +71,7 @@ public final class HTTPServer: Server {
/// - supportPipelining: When `true`, HTTP server will support pipelined requests.
/// - serverName: If set, this name will be serialized as the `Server` header in outgoing responses.
/// - upgraders: An array of `HTTPProtocolUpgrader` to check for with each request.
/// - errorHandler: Any uncaught server or responder errors will go here.
/// - logger: Any uncaught server or responder errors will be logged here.
public init(
hostname: String = "127.0.0.1",
port: Int = 8080,
Expand All @@ -85,8 +85,8 @@ public final class HTTPServer: Server {
supportVersions: Set<HTTPVersionMajor>? = nil,
tlsConfiguration: TLSConfiguration? = nil,
serverName: String? = nil,
errorHandler: @escaping (Error) -> () = { _ in }
) {
logger: Logger? = nil
) {
self.hostname = hostname
self.port = port
self.backlog = backlog
Expand All @@ -103,7 +103,7 @@ public final class HTTPServer: Server {
}
self.tlsConfiguration = tlsConfiguration
self.serverName = serverName
self.errorHandler = errorHandler
self.logger = logger ?? Logger(label: "codes.vapor.http-server")
}
}

Expand Down Expand Up @@ -141,7 +141,6 @@ public final class HTTPServer: Server {
let scheme = self.configuration.tlsConfiguration == nil ? "http" : "https"
let address = "\(scheme)://\(configuration.hostname):\(configuration.port)"
self.application.logger.info("Server starting on \(address)")

// start the actual HTTPServer
self.connection = try HTTPServerConnection.start(
responder: self.responder,
Expand Down Expand Up @@ -238,7 +237,6 @@ private final class HTTPServerConnection {
configuration: HTTPServer.Configuration,
on eventLoopGroup: EventLoopGroup
) -> EventLoopFuture<HTTPServerConnection> {
let logger = Logger(label: "codes.vapor.http-server")
let quiesce = ServerQuiescingHelper(group: eventLoopGroup)
let bootstrap = ServerBootstrap(group: eventLoopGroup)
// Specify backlog and enable SO_REUSEADDR for the server itself
Expand Down Expand Up @@ -267,18 +265,26 @@ private final class HTTPServerConnection {
sslContext = try NIOSSLContext(configuration: tlsConfiguration)
tlsHandler = try NIOSSLServerHandler(context: sslContext)
} catch {
logger.error("Could not configure TLS: \(error)")
configuration.logger.error("Could not configure TLS: \(error)")
return channel.close(mode: .all)
}
return channel.pipeline.addHandler(tlsHandler).flatMap { (_) -> EventLoopFuture<Void> in
return channel.pipeline.configureHTTP2SecureUpgrade(h2PipelineConfigurator: { (pipeline) -> EventLoopFuture<Void> in
return channel.configureHTTP2Pipeline(mode: .server, inboundStreamStateInitializer: { (channel, streamID) -> EventLoopFuture<Void> in
return channel.pipeline.addVaporHTTP2Handlers(responder: responder, configuration: configuration, streamID: streamID)
}).flatMap { (_) -> EventLoopFuture<Void> in
return channel.pipeline.addHandler(HTTPServerErrorHandler(logger: logger))
}
}, http1PipelineConfigurator: { (pipeline) -> EventLoopFuture<Void> in
return pipeline.addVaporHTTP1Handlers(responder: responder, configuration: configuration)
return channel.pipeline.addHandler(tlsHandler).flatMap { _ in
return channel.pipeline.configureHTTP2SecureUpgrade(h2PipelineConfigurator: { pipeline in
return channel.configureHTTP2Pipeline(
mode: .server,
inboundStreamStateInitializer: { (channel, streamID) in
return channel.pipeline.addVaporHTTP2Handlers(
responder: responder,
configuration: configuration,
streamID: streamID
)
}
).map { _ in }
}, http1PipelineConfigurator: { pipeline in
return pipeline.addVaporHTTP1Handlers(
responder: responder,
configuration: configuration
)
})
}
} else {
Expand Down Expand Up @@ -338,7 +344,11 @@ final class HTTPServerErrorHandler: ChannelInboundHandler {
}

private extension ChannelPipeline {
func addVaporHTTP2Handlers(responder: Responder, configuration: HTTPServer.Configuration, streamID: HTTP2StreamID) -> EventLoopFuture<Void> {
func addVaporHTTP2Handlers(
responder: Responder,
configuration: HTTPServer.Configuration,
streamID: HTTP2StreamID
) -> EventLoopFuture<Void> {
// create server pipeline array
var handlers: [ChannelHandler] = []

Expand All @@ -359,13 +369,12 @@ private extension ChannelPipeline {
handlers.append(serverResEncoder)

// add server request -> response delegate
let handler = HTTPServerHandler(
responder: responder,
errorHandler: configuration.errorHandler
)
let handler = HTTPServerHandler(responder: responder)
handlers.append(handler)

return self.addHandlers(handlers)
return self.addHandlers(handlers).flatMap {
self.addHandler(HTTPServerErrorHandler(logger: configuration.logger))
}
}

func addVaporHTTP1Handlers(responder: Responder, configuration: HTTPServer.Configuration) -> EventLoopFuture<Void> {
Expand Down Expand Up @@ -405,10 +414,7 @@ private extension ChannelPipeline {
)
handlers.append(serverResEncoder)
// add server request -> response delegate
let handler = HTTPServerHandler(
responder: responder,
errorHandler: configuration.errorHandler
)
let handler = HTTPServerHandler(responder: responder)

// add HTTP upgrade handler
let upgrader = HTTPServerUpgradeHandler(
Expand All @@ -420,6 +426,8 @@ private extension ChannelPipeline {
handlers.append(handler)

// wait to add delegate as final step
return self.addHandlers(handlers)
return self.addHandlers(handlers).flatMap {
self.addHandler(HTTPServerErrorHandler(logger: configuration.logger))
}
}
}
7 changes: 2 additions & 5 deletions Sources/Vapor/Server/HTTPServerHandler.swift
Expand Up @@ -5,11 +5,9 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias OutboundOut = Response

let responder: Responder
let errorHandler: (Error) -> ()

init(responder: Responder, errorHandler: @escaping (Error) -> ()) {
init(responder: Responder) {
self.responder = responder
self.errorHandler = errorHandler
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
Expand All @@ -19,8 +17,7 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {
self.responder.respond(to: request).whenComplete { response in
switch response {
case .failure(let error):
self.errorHandler(error)
context.close(promise: nil)
self.errorCaught(context: context, error: error)
case .success(let response):
let contentLength = response.headers.firstValue(name: .contentLength)
if request.method == .HEAD {
Expand Down