-
-
Notifications
You must be signed in to change notification settings - Fork 17
/
oauthCallback.go
88 lines (76 loc) · 2.76 KB
/
oauthCallback.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package service
import (
"context"
"fmt"
"net/http"
"github.com/google/uuid"
"github.com/plgd-dev/hub/v2/pkg/security/oauth2"
)
// handleLinkedData exchanges authCode for an OAuth token and records it on the
// first still-missing side of the account link held in data.
//
// The origin side is filled first, via rh.provider. Only when the origin is
// already present is the target side filled, using the OAuth configuration of
// data.linkedCloud and the context returned by
// data.linkedCloud.CtxWithHTTPClient (presumably carrying the HTTP client to
// use for the exchange — confirm). When both sides are already present, data
// is returned unchanged and authCode is not used.
//
// On error, the returned ProvisionCacheData is the input data, unmodified.
func (rh *RequestHandler) handleLinkedData(ctx context.Context, data ProvisionCacheData, authCode string) (ProvisionCacheData, error) {
	switch {
	case !data.linkedAccount.Data.HasOrigin():
		originToken, err := rh.provider.Exchange(ctx, authCode)
		if err != nil {
			return data, fmt.Errorf("cannot exchange origin cloud authorization code for access token: %w", err)
		}
		data.linkedAccount.Data = data.linkedAccount.Data.SetOrigin(*originToken)
	case !data.linkedAccount.Data.HasTarget():
		targetCfg := data.linkedCloud.OAuth.ToDefaultOAuth2()
		targetToken, err := targetCfg.Exchange(data.linkedCloud.CtxWithHTTPClient(ctx, rh.tracerProvider), authCode)
		if err != nil {
			return data, fmt.Errorf("cannot exchange target cloud authorization code for access token: %w", err)
		}
		// Convert the library token into the project's token representation;
		// only the access token, expiry and refresh token are carried over.
		data.linkedAccount.Data = data.linkedAccount.Data.SetTarget(oauth2.Token{
			AccessToken:  oauth2.AccessToken(targetToken.AccessToken),
			Expiry:       targetToken.Expiry,
			RefreshToken: targetToken.RefreshToken,
		})
	}
	return data, nil
}
// oAuthCallback processes the redirect back from an OAuth2 authorization
// server during account linking.
//
// The "state" value selects a pending entry in rh.provisionCache. The entry is
// deleted as soon as it is found, so a state value can be used only once. The
// "code" value is then exchanged by handleLinkedData for a token on whichever
// side of the link (origin first, then target) is still missing. If the target
// side is still missing afterwards, the request is handed to rh.handleOAuth to
// continue with the target cloud. Once both sides are present, the linked
// account is stored under a freshly generated random ID. Unless the linked
// cloud's supported subscription events call for pulling devices, a
// TaskType_SubscribeToDevices task is then triggered.
//
// It returns the HTTP status code to report and a non-nil error on failure. On
// success the error is nil.
func (rh *RequestHandler) oAuthCallback(w http.ResponseWriter, r *http.Request) (int, error) {
	authCode := r.FormValue("code")
	state := r.FormValue("state")
	cacheData := rh.provisionCache.Load(state)
	if cacheData == nil {
		return http.StatusBadRequest, fmt.Errorf("invalid/expired OAuth state")
	}
	// Consume the state right away so it cannot be replayed, even if a later
	// step fails.
	rh.provisionCache.Delete(state)

	// RFC 6749 §4.1.2.1: a denied or failed authorization is reported through
	// the "error" parameter instead of "code". Surface it instead of trying to
	// exchange an empty code. The values are client-supplied, so quote them.
	if authErr := r.FormValue("error"); authErr != "" {
		return http.StatusBadRequest, fmt.Errorf("authorization server returned error %q: %q", authErr, r.FormValue("error_description"))
	}

	newData, err := rh.handleLinkedData(r.Context(), cacheData.Data(), authCode)
	if err != nil {
		return http.StatusBadRequest, err
	}
	// Defensive check: on success, handleLinkedData sets the origin when it was
	// missing, so this should not happen.
	if !newData.linkedAccount.Data.HasOrigin() {
		return http.StatusInternalServerError, fmt.Errorf("invalid linked data state(%v)", newData.linkedAccount.Data.State)
	}
	// Only the origin side is linked so far; continue with the target cloud's
	// authorization flow.
	if !newData.linkedAccount.Data.HasTarget() {
		return rh.handleOAuth(w, r, newData.linkedAccount, newData.linkedCloud)
	}

	id, err := uuid.NewRandom()
	if err != nil {
		return http.StatusInternalServerError, err
	}
	newData.linkedAccount.ID = id.String()
	if _, _, err = rh.store.LoadOrCreateLinkedAccount(r.Context(), newData.linkedAccount); err != nil {
		// Identify the account by ID only. The full struct carries the origin
		// and target OAuth tokens (access and refresh), and this error ends up
		// in logAndWriteErrorResponse.
		return http.StatusBadRequest, fmt.Errorf("cannot store linked account %q: %w", newData.linkedAccount.ID, err)
	}
	// Clouds whose devices are pulled do not need an event subscription.
	if newData.linkedCloud.SupportedSubscriptionEvents.NeedPullDevices() {
		return http.StatusOK, nil
	}
	rh.triggerTask(Task{
		taskType:      TaskType_SubscribeToDevices,
		linkedAccount: newData.linkedAccount,
		linkedCloud:   newData.linkedCloud,
	})
	return http.StatusOK, nil
}
// OAuthCallback is the HTTP handler for the OAuth2 redirect endpoint. It
// delegates to oAuthCallback. On failure, it wraps the error and reports it
// with the returned status code via logAndWriteErrorResponse.
func (rh *RequestHandler) OAuthCallback(w http.ResponseWriter, r *http.Request) {
	code, err := rh.oAuthCallback(w, r)
	if err == nil {
		return
	}
	logAndWriteErrorResponse(fmt.Errorf("cannot process oauth callback: %w", err), code, w)
}