-
Notifications
You must be signed in to change notification settings - Fork 16
/
oauth.go
139 lines (113 loc) · 4.22 KB
/
oauth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package oauth
import (
"context"
"errors"
"time"
"github.com/palantir/stacktrace"
"github.com/patrickmn/go-cache"
uuid "github.com/satori/go.uuid"
"github.com/tnyim/jungletv/server/auth"
"github.com/tnyim/jungletv/types"
"github.com/tnyim/jungletv/utils/transaction"
"golang.org/x/oauth2"
)
// Manager manages oauth account association requests: it holds the OAuth2
// configuration of each connection service and remembers pending flows between
// BeginFlow and CompleteFlow.
type Manager struct {
	// OAuth2 client configuration per connection service, populated through
	// RegisterConnectionService
	oauthConfigs map[types.ConnectionService]*oauth2.Config
	// pending association requests, keyed by the OAuth state value generated in
	// BeginFlow and looked up again in CompleteFlow
	oauthStates *cache.Cache[string, oauthStateData]
}
// ServiceCallbackFunction is an optional hook invoked by CompleteFlow after the
// OAuth token has been obtained and before the new connection is saved. It
// receives the context of the database transaction CompleteFlow is running in,
// the obtained token, and the connection about to be stored (passed by pointer,
// so the hook may modify it). A non-nil error aborts the flow.
type ServiceCallbackFunction func(ctx context.Context, token *oauth2.Token, connection *types.Connection) error
// oauthStateData is a pending association request, stored by BeginFlow under
// the generated OAuth state value and retrieved by CompleteFlow.
type oauthStateData struct {
	// service the user is connecting
	Service types.ConnectionService
	// hook run by CompleteFlow before the connection is saved; may be nil
	OnCallback ServiceCallbackFunction
	// user who started the flow; their address becomes the connection's rewards address
	User auth.User
}
// NewManager returns a new oauth manager with no connection services
// registered. Pending flow states use a default expiration of two hours, and
// the state cache is cleaned up every 15 minutes (go-cache New semantics).
func NewManager() *Manager {
	const (
		stateLifetime        = 2 * time.Hour
		stateCleanupInterval = 15 * time.Minute
	)

	mgr := new(Manager)
	mgr.oauthConfigs = map[types.ConnectionService]*oauth2.Config{}
	mgr.oauthStates = cache.New[string, oauthStateData](stateLifetime, stateCleanupInterval)
	return mgr
}
// RegisterConnectionService associates config with service, replacing any
// configuration previously registered for it. BeginFlow and CompleteFlow fail
// for services that have no registered configuration.
//
// NOTE(review): the underlying map is not guarded by a mutex; presumably all
// services are registered at startup, before any flow runs — confirm.
func (m *Manager) RegisterConnectionService(service types.ConnectionService, config *oauth2.Config) {
	configs := m.oauthConfigs
	configs[service] = config
}
var (
	// ErrMaximumConnectionsReached is returned when a user has reached their
	// maximum number of connections to a service. BeginFlow and CompleteFlow
	// return it wrapped through stacktrace.Propagate.
	ErrMaximumConnectionsReached = errors.New("maximum number of connections to this service reached")
)
// BeginFlow starts an OAuth account association flow for user on service and
// returns the authorization URL the user should be redirected to.
//
// It fails with ErrMaximumConnectionsReached (wrapped) when the user already
// has as many connections to service as types.MaxConnectionsPerService allows,
// and with an error when no OAuth configuration is registered for service.
// Otherwise a fresh UUIDv4 state value is generated, and the pending request
// (service, user and callback) is stored under it with the state cache's
// default expiration, so that CompleteFlow can resume it. callback may be nil.
func (m *Manager) BeginFlow(ctxCtx context.Context, service types.ConnectionService, user auth.User, callback ServiceCallbackFunction) (string, error) {
	txCtx, err := transaction.Begin(ctxCtx)
	if err != nil {
		return "", stacktrace.Propagate(err, "")
	}
	// the transaction is only read from, so committing it on exit is harmless
	defer txCtx.Commit()

	connections, err := types.GetConnectionsForServiceAndRewardsAddress(txCtx, service, user.Address())
	if err != nil {
		return "", stacktrace.Propagate(err, "")
	}
	limit, limited := types.MaxConnectionsPerService[service]
	if limited && len(connections) >= limit {
		return "", stacktrace.Propagate(ErrMaximumConnectionsReached, "")
	}

	config, found := m.oauthConfigs[service]
	if !found {
		return "", stacktrace.NewError("oauth config missing for specified service")
	}

	state := uuid.NewV4().String()
	pending := oauthStateData{
		Service:    service,
		OnCallback: callback,
		User:       user,
	}
	m.oauthStates.SetDefault(state, pending)

	return config.AuthCodeURL(state), nil
}
// CompleteFlow finishes an OAuth account association flow previously started
// with BeginFlow. It works in these steps:
//   - recovers the pending request identified by state;
//   - exchanges code for an OAuth token;
//   - re-checks the per-service connection limit;
//   - runs the callback registered in BeginFlow (if any);
//   - persists the new connection.
//
// The limit check, callback and persistence all run in one database
// transaction. The transaction is only committed if every step succeeds.
//
// A state value is consumed as soon as it is looked up, so each flow can be
// completed at most once; a replayed callback fails with "state not found".
//
// Returns ErrMaximumConnectionsReached (wrapped) when the user reached the
// connection limit for the service after BeginFlow ran.
func (m *Manager) CompleteFlow(ctxCtx context.Context, state string, code string) error {
	// recover user and service
	stateData, ok := m.oauthStates.Get(state)
	if !ok {
		return stacktrace.NewError("state not found")
	}
	// OAuth state values are single-use; forget it so the flow cannot be replayed
	m.oauthStates.Delete(state)

	oauthConfig, ok := m.oauthConfigs[stateData.Service]
	if !ok {
		return stacktrace.NewError("oauth config missing for specified service")
	}

	// exchange the code before opening the database transaction, so that the
	// transaction is not held open for the duration of a network round trip
	exchangeCtx, cancelFn := context.WithTimeout(ctxCtx, 10*time.Second)
	token, err := oauthConfig.Exchange(exchangeCtx, code)
	cancelFn()
	if err != nil {
		return stacktrace.Propagate(err, "error exchanging OAuth authorization into token")
	}
	if !token.Valid() {
		// err is nil here, and stacktrace.Propagate(nil, ...) returns nil, which
		// would have reported success; a fresh error must be created instead
		return stacktrace.NewError("retrieved invalid OAuth token")
	}

	ctx, err := transaction.Begin(ctxCtx)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}
	// no-op once the transaction has been committed below
	defer ctx.Rollback()

	existingConnections, err := types.GetConnectionsForServiceAndRewardsAddress(ctx, stateData.Service, stateData.User.Address())
	if err != nil {
		return stacktrace.Propagate(err, "")
	}
	// re-check the limit: connections may have been added since BeginFlow ran
	if limit, hasLimit := types.MaxConnectionsPerService[stateData.Service]; hasLimit && len(existingConnections) >= limit {
		return stacktrace.Propagate(ErrMaximumConnectionsReached, "")
	}

	now := time.Now()
	newConnection := &types.Connection{
		ID:                state, // there should be no problem in reusing the nonce that was used for the OAuth state. Connection IDs are user-facing anyway
		Service:           stateData.Service,
		RewardsAddress:    stateData.User.Address(),
		CreatedAt:         now,
		UpdatedAt:         now,
		OAuthRefreshToken: &token.RefreshToken,
	}

	// the service-specific hook may inspect the token and adjust the connection
	// before it is saved; an error from it rolls the transaction back
	if stateData.OnCallback != nil {
		err = stateData.OnCallback(ctx, token, newConnection)
		if err != nil {
			return stacktrace.Propagate(err, "")
		}
	}

	err = newConnection.Update(ctx)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}

	return stacktrace.Propagate(ctx.Commit(), "")
}