/
sso.go
146 lines (120 loc) · 4.2 KB
/
sso.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package api
import (
"net/http"
"github.com/crewjam/saml"
"github.com/gofrs/uuid"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
)
type (
	// SingleSignOnParams holds the request parameters accepted by the
	// SingleSignOn handler. Exactly one of ProviderID or Domain must be set;
	// validate enforces this.
	SingleSignOnParams struct {
		// ProviderID selects the SSO provider by its ID.
		ProviderID uuid.UUID `json:"provider_id"`

		// Domain selects the SSO provider assigned to this domain.
		Domain string `json:"domain"`

		// RedirectTo is copied into the SAML relay state.
		RedirectTo string `json:"redirect_to"`

		// SkipHTTPRedirect, when true, makes the handler return the IdP URL
		// as JSON instead of issuing an HTTP redirect. Nil is treated as false.
		SkipHTTPRedirect *bool `json:"skip_http_redirect"`

		// Optional PKCE parameters; a non-empty challenge enables the PKCE flow.
		CodeChallenge       string `json:"code_challenge"`
		CodeChallengeMethod string `json:"code_challenge_method"`
	}

	// SingleSignOnResponse is the JSON body SingleSignOn returns when the
	// caller asked to skip the HTTP redirect.
	SingleSignOnResponse struct {
		// URL is the identity-provider redirect URL the client should visit.
		URL string `json:"url"`
	}
)
// validate checks that the request names exactly one SSO provider, either
// by provider_id or by domain, and not both.
//
// The boolean result reports whether provider_id was supplied, so the caller
// knows which lookup to perform. It is returned alongside the error as well.
// The error is a bad-request error when neither or both selectors are set.
func (p *SingleSignOnParams) validate() (bool, error) {
	byID := p.ProviderID != uuid.Nil
	byDomain := p.Domain != ""

	switch {
	case byID && byDomain:
		return byID, badRequestError("Only one of provider_id or domain supported")
	case !byID && !byDomain:
		return byID, badRequestError("A provider_id or domain needs to be provided")
	default:
		return byID, nil
	}
}
// SingleSignOn handles the single-sign-on flow for a provided SSO domain or provider.
//
// The request must identify exactly one SSO provider, by provider_id or by
// domain (see SingleSignOnParams.validate). A SAML AuthnRequest is built for
// that provider's IdP and a SAML relay state is stored. If a PKCE code
// challenge was supplied, a flow state is stored as well, in the same
// transaction.
//
// On success the client is sent to the IdP's SSO endpoint in one of two ways:
//   - by default, with a 303 See Other redirect;
//   - when skip_http_redirect is true, as a 200 JSON body {"url": "..."}.
func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
	ctx := r.Context()
	db := a.db.WithContext(ctx)

	params := &SingleSignOnParams{}
	if err := retrieveRequestParams(r, params); err != nil {
		return err
	}

	hasProviderID, err := params.validate()
	if err != nil {
		return err
	}

	codeChallengeMethod := params.CodeChallengeMethod
	codeChallenge := params.CodeChallenge
	if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil {
		return err
	}
	flowType := getFlowFromChallenge(codeChallenge)

	var ssoProvider *models.SSOProvider
	if hasProviderID {
		ssoProvider, err = models.FindSSOProviderByID(db, params.ProviderID)
		switch {
		case models.IsNotFoundError(err):
			return notFoundError("No such SSO provider")
		case err != nil:
			return internalServerError("Unable to find SSO provider by ID").WithInternalError(err)
		}
	} else {
		ssoProvider, err = models.FindSSOProviderByDomain(db, params.Domain)
		switch {
		case models.IsNotFoundError(err):
			return notFoundError("No SSO provider assigned for this domain")
		case err != nil:
			return internalServerError("Unable to find SSO provider by domain").WithInternalError(err)
		}
	}

	entityDescriptor, err := ssoProvider.SAMLProvider.EntityDescriptor()
	if err != nil {
		return internalServerError("Error parsing SAML Metadata for SAML provider").WithInternalError(err)
	}

	// TODO: fetch new metadata if validUntil < time.Now()

	serviceProvider := a.getSAMLServiceProvider(entityDescriptor, false /* <- idpInitiated */)

	authnRequest, err := serviceProvider.MakeAuthenticationRequest(
		serviceProvider.GetSSOBindingLocation(saml.HTTPRedirectBinding),
		saml.HTTPRedirectBinding,
		saml.HTTPPostBinding,
	)
	if err != nil {
		return internalServerError("Error creating SAML Authentication Request").WithInternalError(err)
	}

	relayState := models.SAMLRelayState{
		SSOProviderID: ssoProvider.ID,
		RequestID:     authnRequest.ID,
		RedirectTo:    params.RedirectTo,
	}

	// Persist the PKCE flow state (if any) together with the relay state.
	// Both rows are written only after the provider lookup and AuthnRequest
	// construction have succeeded, and in one transaction on the
	// request-scoped connection. A failure at any step therefore no longer
	// leaves an orphaned flow state behind.
	if err := db.Transaction(func(tx *storage.Connection) error {
		if isPKCEFlow(flowType) {
			flowState, gerr := generateFlowState(models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil)
			if gerr != nil {
				return gerr
			}
			if cerr := tx.Create(flowState); cerr != nil {
				return cerr
			}
			relayState.FlowStateID = &flowState.ID
		}

		if terr := tx.Create(&relayState); terr != nil {
			// Wrap terr, the failing insert. The enclosing err is nil here.
			return internalServerError("Error creating SAML relay state from sign up").WithInternalError(terr)
		}
		return nil
	}); err != nil {
		return err
	}

	ssoRedirectURL, err := authnRequest.Redirect(relayState.ID.String(), serviceProvider)
	if err != nil {
		return internalServerError("Error creating SAML authentication request redirect URL").WithInternalError(err)
	}

	// A nil SkipHTTPRedirect means the client did not ask to skip the redirect.
	if params.SkipHTTPRedirect != nil && *params.SkipHTTPRedirect {
		return sendJSON(w, http.StatusOK, SingleSignOnResponse{
			URL: ssoRedirectURL.String(),
		})
	}

	http.Redirect(w, r, ssoRedirectURL.String(), http.StatusSeeOther)
	return nil
}