/
exchangeCache.go
71 lines (60 loc) · 1.6 KB
/
exchangeCache.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
package service
import (
"context"
"errors"
"sync"
"github.com/plgd-dev/hub/v2/pkg/security/oauth2"
"github.com/plgd-dev/hub/v2/pkg/sync/task/future"
)
// ExchangeCache is a thread-safe cache for the Exchange operation.
//
// Exchange takes an authorization code and returns an access token. The cache
// keeps track of the last (code, oauth2.Token) pair: if the authorization code
// of the next Exchange call is the same as the cached one, the call is skipped
// and the stored token is returned instead.
type ExchangeCache struct {
// token is the future holding the result recorded for code; nil when nothing is cached.
token *future.Future
// code is the authorization code that token was created for.
code string
// mutex guards token and code.
mutex sync.Mutex
}
// NewExchangeCache returns a pointer to a zero-valued ExchangeCache, i.e. a
// cache with no authorization code and no token stored.
func NewExchangeCache() *ExchangeCache {
	return new(ExchangeCache)
}
// getFutureToken returns the future associated with authorizationCode.
//
// When a future for the same code is already stored, it is returned together
// with a nil SetFunc. Otherwise a freshly created future replaces whatever was
// stored before and is returned along with its SetFunc, so the caller can
// supply the result. The whole lookup-or-replace step runs under the mutex.
func (e *ExchangeCache) getFutureToken(authorizationCode string) (*future.Future, future.SetFunc) {
	e.mutex.Lock()
	defer e.mutex.Unlock()

	// Cache hit: a future exists and it belongs to the requested code.
	if e.token != nil && e.code == authorizationCode {
		return e.token, nil
	}

	pending, resolve := future.New()
	e.token, e.code = pending, authorizationCode
	return pending, resolve
}
// Execute exchanges authorizationCode for a token through provider, or returns
// the result already recorded in the cache for the same code.
//
// An empty authorizationCode is rejected with an error before the cache is
// consulted. The call that finds no future cached for the code performs
// provider.Exchange itself and hands the outcome (token and error alike) to the
// future's SetFunc. Calls that find the code already cached do not contact the
// provider; they read the outcome from the cached future via Get, passing
// their own ctx.
//
// It returns the token on success, or the error reported by the provider or by
// the future. If the future yields a value that is not a *oauth2.Token, an
// error is returned instead of panicking.
func (e *ExchangeCache) Execute(ctx context.Context, provider *oauth2.PlgdProvider, authorizationCode string) (*oauth2.Token, error) {
	if authorizationCode == "" {
		return nil, errors.New("invalid authorization code")
	}

	f, set := e.getFutureToken(authorizationCode)
	if set != nil {
		// A non-nil SetFunc means the future was just created for this code, so
		// this call owns the exchange and must publish its outcome.
		token, err := provider.Exchange(ctx, authorizationCode)
		set(token, err)
		if err != nil {
			return nil, err
		}
		return token, nil
	}

	v, err := f.Get(ctx)
	if err != nil {
		return nil, err
	}
	// The future carries an untyped value; check the assertion so that an
	// unexpected payload surfaces as an error rather than a runtime panic.
	token, ok := v.(*oauth2.Token)
	if !ok {
		return nil, errors.New("invalid cached token type")
	}
	return token, nil
}
// Clear drops the cached entry: the stored authorization code is reset to the
// empty string and the stored future to nil, both while holding the mutex.
func (e *ExchangeCache) Clear() {
	e.mutex.Lock()
	defer e.mutex.Unlock()

	e.code, e.token = "", nil
}