Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow gotrue to work with multiple custom domains #999

Merged
merged 14 commits into from
May 12, 2023
Merged
3 changes: 2 additions & 1 deletion internal/api/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,8 @@ func (ts *AdminTestSuite) TestAdminUserCreateWithDisabledLogin() {
req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))

*ts.Config = *c.customConfig
ts.Config.JWT = c.customConfig.JWT
ts.Config.External = c.customConfig.External
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), c.expected, w.Code)
})
Expand Down
2 changes: 2 additions & 0 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati

r.Route("/callback", func(r *router) {
r.UseBypass(logger)
r.Use(api.isValidExternalHost)
r.Use(api.loadFlowState)

r.Get("/", api.ExternalProviderCallback)
Expand All @@ -93,6 +94,7 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati

r.Route("/", func(r *router) {
r.UseBypass(logger)
r.Use(api.isValidExternalHost)

r.Get("/settings", api.Settings)

Expand Down
14 changes: 14 additions & 0 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"context"
"net/url"

jwt "github.com/golang-jwt/jwt"
"github.com/supabase/gotrue/internal/models"
Expand All @@ -28,6 +29,7 @@ const (
oauthTokenKey = contextKey("oauth_token") // for OAuth1.0, also known as request token
oauthVerifierKey = contextKey("oauth_verifier")
ssoProviderKey = contextKey("sso_provider")
externalHostKey = contextKey("external_host")
flowStateKey = contextKey("flow_state_id")
)

Expand Down Expand Up @@ -235,3 +237,15 @@ func getSSOProvider(ctx context.Context) *models.SSOProvider {
}
return obj.(*models.SSOProvider)
}

func withExternalHost(ctx context.Context, u *url.URL) context.Context {
return context.WithValue(ctx, externalHostKey, u)
}

func getExternalHost(ctx context.Context) *url.URL {
obj := ctx.Value(externalHostKey)
if obj == nil {
return nil
}
return obj.(*url.URL)
}
21 changes: 20 additions & 1 deletion internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.
if !emailData.Verified && !config.Mailer.Autoconfirm {
mailer := a.Mailer(ctx)
referrer := a.getReferrer(r)
if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil {
externalURL := getExternalHost(ctx)
if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil {
if errors.Is(terr, MaxFrequencyLimitError) {
return nil, tooManyRequestsError("For security purposes, you can only request this once every minute")
}
Expand Down Expand Up @@ -510,41 +511,59 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont
func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, error) {
config := a.config
name = strings.ToLower(name)
callbackURL := getExternalHost(ctx).String() + "/callback"

switch name {
case "apple":
config.External.Apple.RedirectURI = callbackURL
return provider.NewAppleProvider(config.External.Apple)
case "azure":
config.External.Azure.RedirectURI = callbackURL
return provider.NewAzureProvider(config.External.Azure, scopes)
case "bitbucket":
config.External.Bitbucket.RedirectURI = callbackURL
return provider.NewBitbucketProvider(config.External.Bitbucket)
case "discord":
config.External.Discord.RedirectURI = callbackURL
return provider.NewDiscordProvider(config.External.Discord, scopes)
case "github":
config.External.Github.RedirectURI = callbackURL
return provider.NewGithubProvider(config.External.Github, scopes)
case "gitlab":
config.External.Gitlab.RedirectURI = callbackURL
return provider.NewGitlabProvider(config.External.Gitlab, scopes)
case "google":
config.External.Google.RedirectURI = callbackURL
return provider.NewGoogleProvider(config.External.Google, scopes)
case "keycloak":
config.External.Keycloak.RedirectURI = callbackURL
return provider.NewKeycloakProvider(config.External.Keycloak, scopes)
case "linkedin":
config.External.Linkedin.RedirectURI = callbackURL
return provider.NewLinkedinProvider(config.External.Linkedin, scopes)
case "facebook":
config.External.Facebook.RedirectURI = callbackURL
return provider.NewFacebookProvider(config.External.Facebook, scopes)
case "notion":
config.External.Notion.RedirectURI = callbackURL
return provider.NewNotionProvider(config.External.Notion)
case "spotify":
config.External.Spotify.RedirectURI = callbackURL
return provider.NewSpotifyProvider(config.External.Spotify, scopes)
case "slack":
config.External.Slack.RedirectURI = callbackURL
return provider.NewSlackProvider(config.External.Slack, scopes)
case "twitch":
config.External.Twitch.RedirectURI = callbackURL
return provider.NewTwitchProvider(config.External.Twitch, scopes)
case "twitter":
config.External.Twitter.RedirectURI = callbackURL
return provider.NewTwitterProvider(config.External.Twitter, scopes)
case "workos":
config.External.WorkOS.RedirectURI = callbackURL
return provider.NewWorkOSProvider(config.External.WorkOS)
case "zoom":
config.External.Zoom.RedirectURI = callbackURL
return provider.NewZoomProvider(config.External.Zoom)
default:
return nil, fmt.Errorf("Provider %s could not be found", name)
Expand Down
3 changes: 2 additions & 1 deletion internal/api/invite.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error {

mailer := a.Mailer(ctx)
referrer := a.getReferrer(r)
if err := sendInvite(tx, user, mailer, referrer, config.Mailer.OtpLength); err != nil {
externalURL := getExternalHost(ctx)
if err := sendInvite(tx, user, mailer, referrer, externalURL, config.Mailer.OtpLength); err != nil {
return internalServerError("Error inviting user").WithInternalError(err)
}
return nil
Expand Down
3 changes: 2 additions & 1 deletion internal/api/magic_link.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {

mailer := a.Mailer(ctx)
referrer := a.getReferrer(r)
return a.sendMagicLink(tx, user, mailer, config.SMTP.MaxFrequency, referrer, config.Mailer.OtpLength, flowType)
externalURL := getExternalHost(ctx)
return a.sendMagicLink(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType)
})
if err != nil {
if errors.Is(err, MaxFrequencyLimitError) {
Expand Down
24 changes: 13 additions & 11 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -205,7 +206,8 @@ func (a *API) GenerateLink(w http.ResponseWriter, r *http.Request) error {
return terr
}

url, terr = mailer.GetEmailActionLink(user, params.Type, referrer)
externalURL := getExternalHost(ctx)
url, terr = mailer.GetEmailActionLink(user, params.Type, referrer, externalURL)
if terr != nil {
return terr
}
Expand All @@ -228,7 +230,7 @@ func (a *API) GenerateLink(w http.ResponseWriter, r *http.Request) error {
return sendJSON(w, http.StatusOK, resp)
}

func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, otpLength int, flowType models.FlowType) error {
func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error {
var err error
if u.ConfirmationSentAt != nil && !u.ConfirmationSentAt.Add(maxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError
Expand All @@ -241,15 +243,15 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail
token := fmt.Sprintf("%x", sha256.Sum224([]byte(u.GetEmail()+otp)))
u.ConfirmationToken = addFlowPrefixToToken(token, flowType)
now := time.Now()
if err := mailer.ConfirmationMail(u, otp, referrerURL); err != nil {
if err := mailer.ConfirmationMail(u, otp, referrerURL, externalURL); err != nil {
u.ConfirmationToken = oldToken
return errors.Wrap(err, "Error sending confirmation email")
}
u.ConfirmationSentAt = &now
return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation")
}

func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string, otpLength int) error {
func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string, externalURL *url.URL, otpLength int) error {
var err error
oldToken := u.ConfirmationToken
otp, err := crypto.GenerateOtp(otpLength)
Expand All @@ -258,7 +260,7 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re
}
u.ConfirmationToken = fmt.Sprintf("%x", sha256.Sum224([]byte(u.GetEmail()+otp)))
now := time.Now()
if err := mailer.InviteMail(u, otp, referrerURL); err != nil {
if err := mailer.InviteMail(u, otp, referrerURL, externalURL); err != nil {
u.ConfirmationToken = oldToken
return errors.Wrap(err, "Error sending invite email")
}
Expand All @@ -267,7 +269,7 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re
return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite")
}

func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, otpLength int, flowType models.FlowType) error {
func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error {
var err error
if u.RecoverySentAt != nil && !u.RecoverySentAt.Add(maxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError
Expand All @@ -281,7 +283,7 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile
token := fmt.Sprintf("%x", sha256.Sum224([]byte(u.GetEmail()+otp)))
u.RecoveryToken = addFlowPrefixToToken(token, flowType)
now := time.Now()
if err := mailer.RecoveryMail(u, otp, referrerURL); err != nil {
if err := mailer.RecoveryMail(u, otp, referrerURL, externalURL); err != nil {
u.RecoveryToken = oldToken
return errors.Wrap(err, "Error sending recovery email")
}
Expand Down Expand Up @@ -313,7 +315,7 @@ func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, ma
return errors.Wrap(tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at"), "Database error updating user for reauthentication")
}

func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, otpLength int, flowType models.FlowType) error {
func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error {
var err error
// since Magic Link is just a recovery with a different template and behaviour
// around new users we will reuse the recovery db timer to prevent potential abuse
Expand All @@ -329,7 +331,7 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile
u.RecoveryToken = addFlowPrefixToToken(token, flowType)

now := time.Now()
if err := mailer.MagicLinkMail(u, otp, referrerURL); err != nil {
if err := mailer.MagicLinkMail(u, otp, referrerURL, externalURL); err != nil {
u.RecoveryToken = oldToken
return errors.Wrap(err, "Error sending magic link email")
}
Expand All @@ -338,7 +340,7 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile
}

// sendEmailChange sends out an email change token to the new email.
func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfiguration, u *models.User, mailer mailer.Mailer, email string, referrerURL string, otpLength int, flowType models.FlowType) error {
func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfiguration, u *models.User, mailer mailer.Mailer, email, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error {
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
var err error
if u.EmailChangeSentAt != nil && !u.EmailChangeSentAt.Add(config.SMTP.MaxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError
Expand Down Expand Up @@ -366,7 +368,7 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu

u.EmailChangeConfirmStatus = zeroConfirmation
now := time.Now()
if err := mailer.EmailChangeMail(u, otpNew, otpCurrent, referrerURL); err != nil {
if err := mailer.EmailChangeMail(u, otpNew, otpCurrent, referrerURL, externalURL); err != nil {
return err
}

Expand Down
16 changes: 15 additions & 1 deletion internal/api/mail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/golang-jwt/jwt"
Expand Down Expand Up @@ -39,6 +40,11 @@ func (ts *MailTestSuite) SetupTest() {
models.TruncateAll(ts.API.db)

ts.Config.Mailer.SecureEmailChangeEnabled = true

// Create User
u, err := models.NewUser("12345678", "test@example.com", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating new user model")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new user")
}

func (ts *MailTestSuite) TestGenerateLink() {
Expand Down Expand Up @@ -108,11 +114,14 @@ func (ts *MailTestSuite) TestGenerateLink() {
},
}

customDomainUrl, err := url.ParseRequestURI("https://example.gotrue.com")
require.NoError(ts.T(), err)

for _, c := range cases {
ts.Run(c.Desc, func() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.Body))
req := httptest.NewRequest(http.MethodPost, "/admin/generate_link", &buffer)
req := httptest.NewRequest(http.MethodPost, customDomainUrl.String()+"/admin/generate_link", &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
w := httptest.NewRecorder()

Expand All @@ -131,6 +140,11 @@ func (ts *MailTestSuite) TestGenerateLink() {

// check if hashed_token matches hash function of email and the raw otp
require.Equal(ts.T(), data["hashed_token"], fmt.Sprintf("%x", sha256.Sum224([]byte(c.Body.Email+data["email_otp"].(string)))))

// check if the host used in the email link matches the initial request host
u, err := url.ParseRequestURI(data["action_link"].(string))
require.NoError(ts.T(), err)
require.Equal(ts.T(), req.Host, u.Host)
})
}
}
28 changes: 28 additions & 0 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -179,6 +181,32 @@ func isIgnoreCaptchaRoute(req *http.Request) bool {
return false
}

func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (context.Context, error) {
ctx := req.Context()
config := a.config

var u *url.URL
var err error

baseUrl := config.API.ExternalURL
xForwardedHost := req.Header.Get("X-Forwarded-Host")
xForwardedProto := req.Header.Get("X-Forwarded-Proto")
if xForwardedHost != "" && xForwardedProto != "" {
baseUrl = fmt.Sprintf("%s://%s", xForwardedProto, xForwardedHost)
} else if req.URL.Scheme != "" && req.URL.Hostname() != "" {
baseUrl = fmt.Sprintf("%s://%s", req.URL.Scheme, req.URL.Hostname())
}
if u, err = url.ParseRequestURI(baseUrl); err != nil {
// fallback to the default hostname
log := observability.GetLogEntry(req)
log.WithField("request_url", baseUrl).Warn(err)
if u, err = url.ParseRequestURI(config.API.ExternalURL); err != nil {
return ctx, err
}
}
return withExternalHost(ctx, u), nil
}

func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) {
ctx := req.Context()
if !a.config.SAML.Enabled {
Expand Down
30 changes: 30 additions & 0 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"

jwt "github.com/golang-jwt/jwt"
Expand Down Expand Up @@ -229,6 +230,35 @@ func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() {
}
}

func (ts *MiddlewareTestSuite) TestIsValidExternalHost() {
cases := []struct {
desc string
requestURL string
expectedURL string
}{
{
desc: "Valid custom external url",
requestURL: "https://example.custom.com",
expectedURL: "https://example.custom.com",
},
}

_, err := url.ParseRequestURI("https://example.custom.com")
require.NoError(ts.T(), err)

for _, c := range cases {
ts.Run(c.desc, func() {
req := httptest.NewRequest(http.MethodPost, c.requestURL, nil)
w := httptest.NewRecorder()
ctx, err := ts.API.isValidExternalHost(w, req)
require.NoError(ts.T(), err)

externalURL := getExternalHost(ctx)
require.Equal(ts.T(), c.expectedURL, externalURL.String())
})
}
}

func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() {
cases := []struct {
desc string
Expand Down
Loading
Loading