diff --git a/Sources/Vapor/Authentication/RedirectMiddleware.swift b/Sources/Vapor/Authentication/RedirectMiddleware.swift index 0d3dee6645..ebf3b90d76 100755 --- a/Sources/Vapor/Authentication/RedirectMiddleware.swift +++ b/Sources/Vapor/Authentication/RedirectMiddleware.swift @@ -4,7 +4,15 @@ extension Authenticatable { /// - parameters: /// - path: The path to redirect to if the request is not authenticated public static func redirectMiddleware(path: String) -> Middleware { - return RedirectMiddleware(Self.self, path: path) + self.redirectMiddleware(makePath: { _ in path }) + } + + /// Basic middleware to redirect unauthenticated requests to the supplied path + /// + /// - parameters: + /// - makePath: The closure that returns the redirect path based on the given `Request` object + public static func redirectMiddleware(makePath: @escaping (Request) -> String) -> Middleware { + RedirectMiddleware(Self.self, makePath: makePath) } } @@ -12,10 +20,10 @@ extension Authenticatable { private final class RedirectMiddleware: Middleware where A: Authenticatable { - let path: String - - init(_ authenticatableType: A.Type = A.self, path: String) { - self.path = path + let makePath: (Request) -> String + + init(_ authenticatableType: A.Type = A.self, makePath: @escaping (Request) -> String) { + self.makePath = makePath } /// See Middleware.respond @@ -23,7 +31,8 @@ private final class RedirectMiddleware: Middleware if req.auth.has(A.self) { return next.respond(to: req) } - let redirect = req.redirect(to: path) + + let redirect = req.redirect(to: self.makePath(req)) return req.eventLoop.makeSucceededFuture(redirect) } } diff --git a/Tests/VaporTests/AuthenticationTests.swift b/Tests/VaporTests/AuthenticationTests.swift index f0951867b7..7b14012db6 100755 --- a/Tests/VaporTests/AuthenticationTests.swift +++ b/Tests/VaporTests/AuthenticationTests.swift @@ -76,6 +76,50 @@ final class AuthenticationTests: XCTestCase { XCTAssertEqual(res.body.string, "Vapor") } } + + func testBasicAuthenticatorWithRedirect() throws { + struct Test: Authenticatable { + static func authenticator() -> Authenticator { + TestAuthenticator() + } + + var name: String + } + + struct TestAuthenticator: BasicAuthenticator { + typealias User = Test + + func authenticate(basic: BasicAuthorization, for request: Request) -> EventLoopFuture { + if basic.username == "test" && basic.password == "secret" { + let test = Test(name: "Vapor") + request.auth.login(test) + } + return request.eventLoop.makeSucceededFuture(()) + } + } + + let app = Application(.testing) + defer { app.shutdown() } + + let redirectMiddleware = Test.redirectMiddleware { req -> String in + return "/redirect?orig=\(req.url.path)" + } + + app.routes.grouped([ + Test.authenticator(), redirectMiddleware + ]).get("test") { req -> String in + return try req.auth.require(Test.self).name + } + + let basic = "test:secret".data(using: .utf8)!.base64EncodedString() + try app.testable().test(.GET, "/test") { res in + XCTAssertEqual(res.status, .seeOther) + XCTAssertEqual(res.headers["Location"].first, "/redirect?orig=/test") + }.test(.GET, "/test", headers: ["Authorization": "Basic \(basic)"]) { res in + XCTAssertEqual(res.status, .ok) + XCTAssertEqual(res.body.string, "Vapor") + } + } func testSessionAuthentication() throws { struct Test: Authenticatable, SessionAuthenticatable {