Skip to content

Commit

Permalink
fix: refactor email sending functions (#1495)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

The overall goal of this to expose a unified interface for emails. So
that we can potentially implement the Hook as a Custom Mailer. This is
to ensure that the impact on the existing code flow is minimal and that
we can turn off the Hook easily if needed.

After this change, we can do something similar to:

```
mailer. := a.Mailer()
if a.config.Hook.Enabled {
    mailer =   a.CustomMailer
} 
```

 and have all Hook logic live  in the custom Mailer



Specific changes are:

- Removes context from mailer as it is currently unused
- pushes down mailer into respective email sending methods
- Adds remaining send methods as API methods
- Fetch OTP Length and MaxFrequency from config
- Add convenience function for checking if an email was sent within
frequency limit
- push down `externalURL` and `referrer` into send function

---------

Co-authored-by: Kang Ming <kang.ming1996@gmail.com>
  • Loading branch information
J0 and kangmingtay committed Mar 26, 2024
1 parent 96f7a68 commit 285c290
Show file tree
Hide file tree
Showing 13 changed files with 95 additions and 83 deletions.
2 changes: 1 addition & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ func (a *API) HealthCheck(w http.ResponseWriter, r *http.Request) error {
}

// Mailer returns NewMailer with the current tenant config
func (a *API) Mailer(ctx context.Context) mailer.Mailer {
func (a *API) Mailer() mailer.Mailer {
config := a.config
return mailer.NewMailer(config)
}
5 changes: 1 addition & 4 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.
} else {
emailConfirmationSent := false
if decision.CandidateEmail.Email != "" {
mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)
if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil {
if terr = a.sendConfirmation(r, tx, user, models.ImplicitFlow); terr != nil {
if errors.Is(terr, MaxFrequencyLimitError) {
return nil, tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "For security purposes, you can only request this once every minute")
}
Expand Down
6 changes: 1 addition & 5 deletions internal/api/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/utilities"
)

func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {
Expand Down Expand Up @@ -132,10 +131,7 @@ func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *stora
return nil, terr
}
if !userData.Metadata.EmailVerified {
mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, a.config)
externalURL := getExternalHost(ctx)
if terr := sendConfirmation(tx, targetUser, mailer, a.config.SMTP.MaxFrequency, referrer, externalURL, a.config.Mailer.OtpLength, models.ImplicitFlow); terr != nil {
if terr := a.sendConfirmation(r, tx, targetUser, models.ImplicitFlow); terr != nil {
if errors.Is(terr, MaxFrequencyLimitError) {
return nil, tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "For security purposes, you can only request this once every minute")
}
Expand Down
7 changes: 1 addition & 6 deletions internal/api/invite.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/utilities"
)

// InviteParams are the parameters the Signup endpoint accepts
Expand All @@ -20,7 +19,6 @@ type InviteParams struct {
func (a *API) Invite(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config
adminUser := getAdminUser(ctx)
params := &InviteParams{}
if err := retrieveRequestParams(r, params); err != nil {
Expand Down Expand Up @@ -81,10 +79,7 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error {
return terr
}

mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)
if err := sendInvite(tx, user, mailer, referrer, externalURL, config.Mailer.OtpLength); err != nil {
if err := a.sendInvite(r, tx, user); err != nil {
return internalServerError("Error inviting user").WithInternalError(err)
}
return nil
Expand Down
6 changes: 1 addition & 5 deletions internal/api/magic_link.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/sethvargo/go-password/password"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/utilities"
)

// MagicLinkParams holds the parameters for a magic link request
Expand Down Expand Up @@ -139,10 +138,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {
if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil {
return terr
}
mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)
return a.sendMagicLink(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType)
return a.sendMagicLink(r, tx, user, flowType)
})
if err != nil {
if errors.Is(err, MaxFrequencyLimitError) {
Expand Down
96 changes: 70 additions & 26 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"net/http"
"net/url"
"strings"
"time"

Expand All @@ -11,9 +10,7 @@ import (
"github.com/pkg/errors"
"github.com/sethvargo/go-password/password"
"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/crypto"
"github.com/supabase/auth/internal/mailer"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/utilities"
Expand Down Expand Up @@ -45,7 +42,7 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config
mailer := a.Mailer(ctx)
mailer := a.Mailer()
adminUser := getAdminUser(ctx)
params := &GenerateLinkParams{}
if err := retrieveRequestParams(r, params); err != nil {
Expand Down Expand Up @@ -263,10 +260,17 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
return sendJSON(w, http.StatusOK, resp)
}

func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error {
func (a *API) sendConfirmation(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error {
ctx := r.Context()
mailer := a.Mailer()
config := a.config
otpLength := config.Mailer.OtpLength
maxFrequency := config.SMTP.MaxFrequency
referrerURL := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)
var err error
if u.ConfirmationSentAt != nil && !u.ConfirmationSentAt.Add(maxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError
if err := validateSentWithinFrequencyLimit(u.ConfirmationSentAt, maxFrequency); err != nil {
return err
}
oldToken := u.ConfirmationToken
otp, err := crypto.GenerateOtp(otpLength)
Expand All @@ -277,7 +281,7 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail
token := crypto.GenerateTokenHash(u.GetEmail(), otp)
u.ConfirmationToken = addFlowPrefixToToken(token, flowType)
now := time.Now()
if err := mailer.ConfirmationMail(u, otp, referrerURL, externalURL); err != nil {
if err := mailer.ConfirmationMail(r, u, otp, referrerURL, externalURL); err != nil {
u.ConfirmationToken = oldToken
return errors.Wrap(err, "Error sending confirmation email")
}
Expand All @@ -290,7 +294,13 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail
return nil
}

func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string, externalURL *url.URL, otpLength int) error {
func (a *API) sendInvite(r *http.Request, tx *storage.Connection, u *models.User) error {
ctx := r.Context()
mailer := a.Mailer()
config := a.config
otpLength := config.Mailer.OtpLength
referrerURL := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)
var err error
oldToken := u.ConfirmationToken
otp, err := crypto.GenerateOtp(otpLength)
Expand All @@ -300,7 +310,7 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re
}
u.ConfirmationToken = crypto.GenerateTokenHash(u.GetEmail(), otp)
now := time.Now()
if err := mailer.InviteMail(u, otp, referrerURL, externalURL); err != nil {
if err := mailer.InviteMail(r, u, otp, referrerURL, externalURL); err != nil {
u.ConfirmationToken = oldToken
return errors.Wrap(err, "Error sending invite email")
}
Expand All @@ -314,10 +324,17 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re
return nil
}

func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error {
func (a *API) sendPasswordRecovery(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error {
ctx := r.Context()
config := a.config
maxFrequency := config.SMTP.MaxFrequency
otpLength := config.Mailer.OtpLength
referrerURL := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)
mailer := a.Mailer()
var err error
if u.RecoverySentAt != nil && !u.RecoverySentAt.Add(maxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError
if err := validateSentWithinFrequencyLimit(u.RecoverySentAt, maxFrequency); err != nil {
return err
}

oldToken := u.RecoveryToken
Expand All @@ -329,7 +346,7 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile
token := crypto.GenerateTokenHash(u.GetEmail(), otp)
u.RecoveryToken = addFlowPrefixToToken(token, flowType)
now := time.Now()
if err := mailer.RecoveryMail(u, otp, referrerURL, externalURL); err != nil {
if err := mailer.RecoveryMail(r, u, otp, referrerURL, externalURL); err != nil {
u.RecoveryToken = oldToken
return errors.Wrap(err, "Error sending recovery email")
}
Expand All @@ -342,10 +359,15 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile
return nil
}

func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, otpLength int) error {
func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u *models.User) error {
config := a.config
maxFrequency := config.SMTP.MaxFrequency
otpLength := config.Mailer.OtpLength
mailer := a.Mailer()
var err error
if u.ReauthenticationSentAt != nil && !u.ReauthenticationSentAt.Add(maxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError

if err := validateSentWithinFrequencyLimit(u.ReauthenticationSentAt, maxFrequency); err != nil {
return err
}

oldToken := u.ReauthenticationToken
Expand All @@ -356,7 +378,7 @@ func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, ma
}
u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), otp)
now := time.Now()
if err := mailer.ReauthenticateMail(u, otp); err != nil {
if err := mailer.ReauthenticateMail(r, u, otp); err != nil {
u.ReauthenticationToken = oldToken
return errors.Wrap(err, "Error sending reauthentication email")
}
Expand All @@ -369,13 +391,21 @@ func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, ma
return nil
}

func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error {
func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error {
ctx := r.Context()
mailer := a.Mailer()
config := a.config
otpLength := config.Mailer.OtpLength
maxFrequency := config.SMTP.MaxFrequency
referrerURL := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)
var err error
// since Magic Link is just a recovery with a different template and behaviour
// around new users we will reuse the recovery db timer to prevent potential abuse
if u.RecoverySentAt != nil && !u.RecoverySentAt.Add(maxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError
if err := validateSentWithinFrequencyLimit(u.RecoverySentAt, maxFrequency); err != nil {
return err
}

oldToken := u.RecoveryToken
otp, err := crypto.GenerateOtp(otpLength)
if err != nil {
Expand All @@ -386,7 +416,7 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile
u.RecoveryToken = addFlowPrefixToToken(token, flowType)

now := time.Now()
if err := mailer.MagicLinkMail(u, otp, referrerURL, externalURL); err != nil {
if err := mailer.MagicLinkMail(r, u, otp, referrerURL, externalURL); err != nil {
u.RecoveryToken = oldToken
return errors.Wrap(err, "Error sending magic link email")
}
Expand All @@ -400,11 +430,18 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile
}

// sendEmailChange sends out an email change token to the new email.
func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfiguration, u *models.User, mailer mailer.Mailer, email, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error {
func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models.User, email string, flowType models.FlowType) error {
ctx := r.Context()
config := a.config
otpLength := config.Mailer.OtpLength
var err error
if u.EmailChangeSentAt != nil && !u.EmailChangeSentAt.Add(config.SMTP.MaxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError
mailer := a.Mailer()
if err := validateSentWithinFrequencyLimit(u.EmailChangeSentAt, config.SMTP.MaxFrequency); err != nil {
return err
}
referrerURL := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)

otpNew, err := crypto.GenerateOtp(otpLength)
if err != nil {
// OTP generation must succeed
Expand All @@ -427,7 +464,7 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu

u.EmailChangeConfirmStatus = zeroConfirmation
now := time.Now()
if err := mailer.EmailChangeMail(u, otpNew, otpCurrent, referrerURL, externalURL); err != nil {
if err := mailer.EmailChangeMail(r, u, otpNew, otpCurrent, referrerURL, externalURL); err != nil {
return err
}

Expand Down Expand Up @@ -457,3 +494,10 @@ func validateEmail(email string) (string, error) {
}
return strings.ToLower(email), nil
}

func validateSentWithinFrequencyLimit(sentAt *time.Time, frequency time.Duration) error {
if sentAt != nil && sentAt.Add(frequency).After(time.Now()) {
return MaxFrequencyLimitError
}
return nil
}
3 changes: 1 addition & 2 deletions internal/api/reauthenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error {
return terr
}
if email != "" {
mailer := a.Mailer(ctx)
return a.sendReauthenticationOtp(tx, user, mailer, config.SMTP.MaxFrequency, config.Mailer.OtpLength)
return a.sendReauthenticationOtp(r, tx, user)
} else if phone != "" {
smsProvider, terr := sms_provider.GetSmsProvider(*config)
if terr != nil {
Expand Down
7 changes: 1 addition & 6 deletions internal/api/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/utilities"
)

// RecoverParams holds the parameters for a password recovery request
Expand Down Expand Up @@ -34,7 +33,6 @@ func (p *RecoverParams) Validate() error {
func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config
params := &RecoverParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
Expand Down Expand Up @@ -66,10 +64,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil {
return terr
}
mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)
return a.sendPasswordRecovery(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType)
return a.sendPasswordRecovery(r, tx, user, flowType)
})
if err != nil {
if errors.Is(err, MaxFrequencyLimitError) {
Expand Down
8 changes: 2 additions & 6 deletions internal/api/resend.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/utilities"
)

// ResendConfirmationParams holds the parameters for a resend request
Expand Down Expand Up @@ -115,17 +114,14 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
}

messageID := ""
mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)
err = db.Transaction(func(tx *storage.Connection) error {
switch params.Type {
case signupVerification:
if terr := models.NewAuditLogEntry(r, tx, user, models.UserConfirmationRequestedAction, "", nil); terr != nil {
return terr
}
// PKCE not implemented yet
return sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow)
return a.sendConfirmation(r, tx, user, models.ImplicitFlow)
case smsVerification:
if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil {
return terr
Expand All @@ -140,7 +136,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
}
messageID = mID
case emailChangeVerification:
return a.sendEmailChange(tx, config, user, mailer, user.EmailChange, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow)
return a.sendEmailChange(r, tx, user, user.EmailChange, models.ImplicitFlow)
case phoneChangeVerification:
smsProvider, terr := sms_provider.GetSmsProvider(*config)
if terr != nil {
Expand Down
Loading

0 comments on commit 285c290

Please sign in to comment.