-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
166 lines (137 loc) · 4.61 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
package main
import (
"fmt"
"log"
"net/http"
"os"
"time"
"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambda"
"github.com/dgrijalva/jwt-go"
"github.com/getsentry/sentry-go"
acmeserverless "github.com/retgits/acme-serverless"
wflambda "github.com/wavefronthq/wavefront-lambda-go"
)
// handler handles the API Gateway events and returns an error if anything goes wrong.
func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
// Initiialize a connection to Sentry to capture errors and traces
sentry.Init(sentry.ClientOptions{
Dsn: os.Getenv("SENTRY_DSN"),
Transport: &sentry.HTTPSyncTransport{
Timeout: time.Second * 3,
},
ServerName: os.Getenv("FUNCTION_NAME"),
Release: os.Getenv("VERSION"),
Environment: os.Getenv("STAGE"),
})
// Create headers if they don't exist and add
// the CORS required headers, otherwise the response
// will not be accepted by browsers.
headers := request.Headers
if headers == nil {
headers = make(map[string]string)
}
headers["Access-Control-Allow-Origin"] = "*"
// Create the key attributes
login, err := acmeserverless.UnmarshalLoginResponse(request.Body)
if err != nil {
return handleError("unmarshalling login", headers, err)
}
valid, id, _, err := ValidateToken(login.RefreshToken)
if !valid || id == "" {
res := acmeserverless.VerifyTokenResponse{
Message: "Invalid Key. User Not Authorized",
Status: http.StatusForbidden,
}
payload, err := res.Marshal()
if err != nil {
return handleError("marshalling response", headers, err)
}
response := events.APIGatewayProxyResponse{
StatusCode: http.StatusOK,
Body: string(payload),
Headers: headers,
}
return response, nil
}
newToken, _ := GenerateAccessToken("eric", id)
res := acmeserverless.LoginResponse{
AccessToken: newToken,
RefreshToken: login.RefreshToken,
Status: http.StatusOK,
}
payload, err := res.Marshal()
if err != nil {
return handleError("marshalling response", headers, err)
}
response := events.APIGatewayProxyResponse{
StatusCode: http.StatusOK,
Body: string(payload),
Headers: headers,
}
return response, nil
}
// handleError takes the activity where the error occured and the error object and sends a message to sentry.
// The original error, together with the appropriate API Gateway Proxy Response, is returned so it can be thrown.
func handleError(area string, headers map[string]string, err error) (events.APIGatewayProxyResponse, error) {
sentry.CaptureException(fmt.Errorf("error %s: %s", area, err.Error()))
msg := fmt.Sprintf("error %s: %s", area, err.Error())
log.Println(msg)
return events.APIGatewayProxyResponse{
StatusCode: http.StatusBadRequest,
Body: msg,
Headers: headers,
}, nil
}
// The main method is executed by AWS Lambda and points to the handler
func main() {
lambda.Start(wflambda.Wrapper(handler))
}
func GenerateAccessToken(username string, uuid string) (string, error) {
// Declare the expiration time of the access token
// Here the expiration is 5 minutes
expirationTimeAccessToken := time.Now().Add(5 * time.Minute).Unix()
// Declare the token with the algorithm used for signing, and the claims
token := jwt.New(jwt.SigningMethodHS256)
token.Header["kid"] = "signin_1"
claims := token.Claims.(jwt.MapClaims)
claims["Username"] = username
claims["exp"] = expirationTimeAccessToken
claims["sub"] = uuid
// Create the JWT string
tokenString, err := token.SignedString([]byte("my_secret_key"))
if err != nil {
return "", err
}
return tokenString, nil
}
// ValidateToken is used to validate both access_token and refresh_token. It is done based on the "Key ID" provided by the JWT
func ValidateToken(tokenString string) (bool, string, string, error) {
var key []byte
var keyID string
claims := jwt.MapClaims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
keyID = token.Header["kid"].(string)
// If the "kid" (Key ID) is equal to signin_1, then it is compared against access_token secret key, else if it
// is equal to signin_2 , it is compared against refresh_token secret key.
if keyID == "signin_1" {
key = []byte("my_secret_key")
} else if keyID == "signin_2" {
key = []byte("my_secret_key_2")
}
return key, nil
})
// Check if signatures are valid.
if err != nil {
if err == jwt.ErrSignatureInvalid {
log.Printf("Invalid Token Signature")
return false, "", keyID, err
}
return false, "", keyID, err
}
if !token.Valid {
log.Printf("Invalid Token")
return false, "", keyID, err
}
return true, claims["sub"].(string), keyID, nil
}