1
- import Crypto
2
1
import Vapor
3
2
4
- public typealias TokenRetrievalHandler = ( ( Request ) throws -> Future < String > )
3
+ public typealias TokenRetrievalHandler = ( ( Request ) -> EventLoopFuture < String > )
5
4
6
5
/// Middleware to protect against cross-site request forgery attacks.
7
- public struct CSRF : Middleware , Service {
6
+ public struct CSRF : Middleware {
8
7
private let ignoredMethods : [ HTTPMethod ]
9
8
private var tokenRetrieval : TokenRetrievalHandler
10
9
@@ -18,77 +17,78 @@ public struct CSRF: Middleware, Service {
18
17
self . tokenRetrieval = tokenRetrieval ?? CSRF . defaultTokenRetrieval
19
18
}
20
19
21
- public func respond( to request: Request , chainingTo next: Responder ) throws -> Future < Response > {
22
- let method = request. http . method
20
+ public func respond( to request: Request , chainingTo next: Responder ) -> EventLoopFuture < Response > {
21
+ let method = request. method
23
22
24
23
if ignoredMethods. contains ( method) {
25
- return try next. respond ( to: request)
24
+ return next. respond ( to: request)
26
25
}
27
26
28
- let secret = try createSecret ( from: request)
27
+ let secret = createSecret ( from: request)
29
28
30
- return try tokenRetrieval ( request) . flatMap ( to: Response . self) { token in
31
- let valid = try self . validate ( token, with: secret)
32
- guard valid else {
33
- throw Abort ( . forbidden, reason: " Invalid CSRF token. " )
29
+ return tokenRetrieval ( request) . flatMap { token in
30
+ do {
31
+ let valid = try self . validate ( token, with: secret)
32
+ guard valid else {
33
+ return request. eventLoop. makeFailedFuture ( Abort ( . forbidden, reason: " Invalid CSRF token. " ) )
34
+ }
35
+ return next. respond ( to: request)
36
+ } catch {
37
+ return request. eventLoop. makeFailedFuture ( error)
34
38
}
35
- return try next. respond ( to: request)
36
39
}
37
40
}
38
41
39
42
/// Creates a token from a given `Request`. Call this method to generate a CSRF token to assign to your key of choice in the header and pass the token back to the caller via the response.
40
43
/// - parameter request: The `Request` used to either find the secret in, or the request used to generate the secret.
41
44
/// - returns: `Bytes` representing the generated token.
42
45
/// - throws: An error that may arise from either creating the secret from the request or from generating the token.
43
- public func createToken( from request: Request ) throws -> String {
44
- let secret = try createSecret ( from: request)
45
- let saltBytes = try CryptoRandom ( ) . generateData ( count: 8 )
46
- let saltString = saltBytes. hexEncodedString ( )
47
- return try generateToken ( from: secret, with: saltString)
46
+ public func createToken( from request: Request ) -> String {
47
+ let secret = createSecret ( from: request)
48
+ let saltBytes = [ UInt8 ] . random ( count: 8 )
49
+ let saltString = saltBytes. description
50
+ return generateToken ( from: secret, with: saltString)
48
51
}
49
52
50
- private func generateToken( from secret: String , with salt: String ) throws -> String {
51
- let saltPlusSecret = salt + " - " + secret
52
- let token = try MD5 . hash ( saltPlusSecret) . hexEncodedString ( )
53
+ private func generateToken( from secret: String , with salt: String ) -> String {
54
+ let saltPlusSecret = ( salt + " - " + secret)
55
+ let digest = Insecure . MD5. hash ( data: [ UInt8] ( saltPlusSecret. utf8) )
56
+ let token = digest. description
53
57
return salt + " - " + token
54
58
}
55
59
56
60
private func validate( _ token: String , with secret: String ) throws -> Bool {
57
61
guard let salt = token. components ( separatedBy: " - " ) . first else {
58
62
throw Abort ( . forbidden, reason: " The provided CSRF token is in the wrong format. " )
59
63
}
60
- let expectedToken = try generateToken ( from: secret, with: salt)
64
+ let expectedToken = generateToken ( from: secret, with: salt)
61
65
return expectedToken == token
62
66
}
63
67
64
- private func createSecret( from request: Request ) throws -> String {
65
-
66
- let session = try request. session ( )
67
-
68
- guard let secret = session [ " CSRFSecret " ] else {
69
- let random = CryptoRandom ( )
70
- let secretData = try random. generateData ( count: 16 )
71
- let secret = secretData. hexEncodedString ( )
72
- session [ " CSRFSecret " ] = secret
68
+ private func createSecret( from request: Request ) -> String {
69
+ guard let secret = request. session. data [ " CSRFSecret " ] else {
70
+ let secretData = [ UInt8 ] . random ( count: 16 )
71
+ let secret = secretData. description
72
+ request. session. data [ " CSRFSecret " ] = secret
73
73
return secret
74
74
}
75
-
76
75
return secret
77
76
}
78
77
79
- private static func defaultTokenRetrieval( from request: Request ) throws -> Future < String > {
78
+ private static func defaultTokenRetrieval( from request: Request ) -> EventLoopFuture < String > {
80
79
81
80
let csrfKeys : Set < String > = [ " _csrf " , " csrf-token " , " xsrf-token " , " x-csrf-token " , " x-xsrf-token " , " x-csrftoken " ]
82
- let requestHeaderKeys = Set ( request. http . headers. map { $0. name } )
81
+ let requestHeaderKeys = Set ( request. headers. map { $0. name } )
83
82
let intersection = csrfKeys. intersection ( requestHeaderKeys)
84
83
85
- if let matchingKey = intersection. first, let token = request. http . headers [ matchingKey] . first {
86
- return request. future ( token)
84
+ if let matchingKey = intersection. first, let token = request. headers [ matchingKey] . first {
85
+ return request. eventLoop . makeSucceededFuture ( token)
87
86
}
88
87
89
- return request. content. get ( at: " _csrf " )
90
- . catchMap { error in
91
- throw Abort ( . forbidden, reason: " No CSRF token provided. " )
92
- }
88
+ do {
89
+ return try request. eventLoop. makeSucceededFuture ( request. content. get ( String . self, at: " _csrf " ) )
90
+ } catch {
91
+ return request. eventLoop. makeFailedFuture ( Abort ( . forbidden, reason: " No CSRF token provided. " ) )
92
+ }
93
93
}
94
94
}
0 commit comments