From f41e35d8d066609ceb5ec38eb1357628aea90b45 Mon Sep 17 00:00:00 2001 From: Chris Stockton Date: Thu, 28 Aug 2025 13:33:07 -0700 Subject: [PATCH 1/3] feat: decompose the mailer package - changes the flow of object construction, not behavior - API builds mailer as: noop|mailme -> validate -> tasks -> templatemailer - remove MailerOptions & mailerClientFunc options - makes it easier to form a mental model around the mailer package - sets the stage for the upcoming template reloading change --- go.mod | 2 +- internal/api/api.go | 17 +- internal/api/mail.go | 34 ++- internal/api/options.go | 7 - internal/api/options_test.go | 17 -- internal/mailer/mailer.go | 206 ++---------------- .../mailmeclient.go} | 74 ++++--- .../{noop.go => noopclient/noopclient.go} | 23 +- internal/mailer/taskclient/taskclient.go | 84 +++++++ .../templatemailer.go} | 121 +++++----- .../templatemailer_test.go} | 2 +- .../templatemailer_url_test.go} | 6 +- .../validateclient.go} | 91 ++++++-- .../validateclient_test.go} | 2 +- 14 files changed, 348 insertions(+), 338 deletions(-) rename internal/mailer/{mailme.go => mailmeclient/mailmeclient.go} (66%) rename internal/mailer/{noop.go => noopclient/noopclient.go} (56%) create mode 100644 internal/mailer/taskclient/taskclient.go rename internal/mailer/{template.go => templatemailer/templatemailer.go} (83%) rename internal/mailer/{template_test.go => templatemailer/templatemailer_test.go} (98%) rename internal/mailer/{mailer_test.go => templatemailer/templatemailer_url_test.go} (96%) rename internal/mailer/{validate.go => validateclient/validateclient.go} (78%) rename internal/mailer/{validate_test.go => validateclient/validateclient_test.go} (99%) diff --git a/go.mod b/go.mod index 6c06d4b84..e3db84713 100644 --- a/go.mod +++ b/go.mod @@ -173,7 +173,7 @@ require ( golang.org/x/net v0.38.0 // indirect golang.org/x/sync v0.12.0 golang.org/x/sys v0.31.0 - golang.org/x/text v0.23.0 // indirect + golang.org/x/text v0.23.0 golang.org/x/time v0.9.0 google.golang.org/appengine v1.6.8 // indirect google.golang.org/grpc v1.63.2 // indirect diff --git a/internal/api/api.go b/internal/api/api.go index bdf11265f..c4fb7b614 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -38,11 +38,10 @@ type API struct { config *conf.GlobalConfiguration version string - hooksMgr *v0hooks.Manager - hibpClient *hibp.PwnedClient - oauthServer *oauthserver.Server - mailerClientFunc func() mailer.MailClient - tokenService *tokens.Service + hooksMgr *v0hooks.Manager + hibpClient *hibp.PwnedClient + oauthServer *oauthserver.Server + tokenService *tokens.Service // overrideTime can be used to override the clock used by handlers. Should only be used in tests! overrideTime func() time.Time @@ -100,11 +99,6 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne if api.limiterOpts == nil { api.limiterOpts = NewLimiterOptions(globalConfig) } - if api.mailerClientFunc == nil { - api.mailerClientFunc = func() mailer.MailClient { - return mailer.NewMailClient(globalConfig) - } - } if api.hooksMgr == nil { httpDr := hookshttp.New() pgfuncDr := hookspgfunc.New(db) @@ -376,8 +370,7 @@ func (a *API) HealthCheck(w http.ResponseWriter, r *http.Request) error { // Mailer returns NewMailer with the current tenant config func (a *API) Mailer() mailer.Mailer { - config := a.config - return mailer.NewMailerWithClient(config, a.mailerClientFunc()) + return newMailer(a.config) } // ServeHTTP implements the http.Handler interface by passing the request along diff --git a/internal/api/mail.go b/internal/api/mail.go index 569ecf726..96ed8b9ba 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -6,8 +6,16 @@ import ( "strings" "time" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/hooks/v0hooks" + "github.com/supabase/auth/internal/mailer" mail "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/mailer/mailmeclient" + "github.com/supabase/auth/internal/mailer/noopclient" + "github.com/supabase/auth/internal/mailer/taskclient" + "github.com/supabase/auth/internal/mailer/templatemailer" + "github.com/supabase/auth/internal/mailer/validateclient" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -23,6 +31,26 @@ import ( "github.com/supabase/auth/internal/utilities" ) +// newMailer returns a new gotrue mailer +func newMailer(globalConfig *conf.GlobalConfiguration) *templatemailer.TemplateMailer { + var mc mailer.Client + if globalConfig.SMTP.Host == "" { + logrus.Infof("Noop mail client being used for %v", globalConfig.SiteURL) + mc = noopclient.New() + } else { + mc = mailmeclient.New(globalConfig) + } + + // Wrap client with validation first + mc = validateclient.New(globalConfig, mc) + + // Then background tasks + mc = taskclient.New(globalConfig, mc) + + // Finally the template mailer + return templatemailer.New(globalConfig, mc) +} + var ( EmailRateLimitExceeded error = errors.New("email rate limit exceeded") ) @@ -705,9 +733,9 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, } switch { - case errors.Is(err, mail.ErrInvalidEmailAddress), - errors.Is(err, mail.ErrInvalidEmailFormat), - errors.Is(err, mail.ErrInvalidEmailDNS): + case errors.Is(err, validateclient.ErrInvalidEmailAddress), + errors.Is(err, validateclient.ErrInvalidEmailFormat), + errors.Is(err, validateclient.ErrInvalidEmailDNS): return apierrors.NewBadRequestError( apierrors.ErrorCodeEmailAddressInvalid, "Email address %q is invalid", diff --git a/internal/api/options.go b/internal/api/options.go index c5ab33fb5..a54efe36e 100644 --- a/internal/api/options.go +++ b/internal/api/options.go @@ -6,7 +6,6 @@ import ( "github.com/didip/tollbooth/v5" "github.com/didip/tollbooth/v5/limiter" "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/mailer" "github.com/supabase/auth/internal/ratelimit" "github.com/supabase/auth/internal/tokens" ) @@ -15,12 +14,6 @@ type Option interface { apply(*API) } -type MailerOptions struct { - MailerClientFunc func() mailer.MailClient -} - -func (mo *MailerOptions) apply(a *API) { a.mailerClientFunc = mo.MailerClientFunc } - type LimiterOptions struct { Email ratelimit.Limiter Phone ratelimit.Limiter diff --git a/internal/api/options_test.go b/internal/api/options_test.go index 3785f6426..c4c1d1623 100644 --- a/internal/api/options_test.go +++ b/internal/api/options_test.go @@ -4,10 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/e2e" - "github.com/supabase/auth/internal/mailer" ) func TestNewLimiterOptions(t *testing.T) { @@ -31,17 +28,3 @@ func TestNewLimiterOptions(t *testing.T) { assert.NotNil(t, rl.SSO) assert.NotNil(t, rl.SAMLAssertion) } - -func TestMailerOptions(t *testing.T) { - globalCfg := e2e.Must(e2e.Config()) - conn := e2e.Must(e2e.Conn(globalCfg)) - - sentinelMailer := mailer.NewMailClient(globalCfg) - mailerOpts := &MailerOptions{MailerClientFunc: func() mailer.MailClient { - return sentinelMailer - }} - a := NewAPIWithVersion(globalCfg, conn, apiTestVersion, mailerOpts) - - got := a.mailerClientFunc() - require.Equal(t, sentinelMailer, got) -} diff --git a/internal/mailer/mailer.go b/internal/mailer/mailer.go index 5ce2dc0ca..0b6ad383a 100644 --- a/internal/mailer/mailer.go +++ b/internal/mailer/mailer.go @@ -2,16 +2,24 @@ package mailer import ( "context" - "fmt" "net/http" "net/url" - "github.com/sirupsen/logrus" - "github.com/supabase/auth/internal/api/apitask" - "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/models" ) +const ( + SignupVerification = "signup" + RecoveryVerification = "recovery" + InviteVerification = "invite" + MagicLinkVerification = "magiclink" + EmailChangeVerification = "email_change" + EmailOTPVerification = "email" + EmailChangeCurrentVerification = "email_change_current" + EmailChangeNewVerification = "email_change_new" + ReauthenticationVerification = "reauthentication" +) + // Mailer defines the interface a mailer must implement. type Mailer interface { InviteMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error @@ -23,10 +31,17 @@ type Mailer interface { GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) } -type EmailParams struct { - Token string - Type string - RedirectTo string +type Client interface { + Mail( + ctx context.Context, + to string, + subjectTemplate string, + templateURL string, + defaultTemplate string, + templateData map[string]any, + headers map[string][]string, + typ string, + ) error } type EmailData struct { @@ -38,178 +53,3 @@ type EmailData struct { TokenNew string `json:"token_new"` TokenHashNew string `json:"token_hash_new"` } - -// NewMailer returns a new gotrue mailer -func NewMailer(globalConfig *conf.GlobalConfiguration) Mailer { - mc := NewMailClient(globalConfig) - return NewMailerWithClient(globalConfig, mc) -} - -type emailValidatorMailClient struct { - ev *EmailValidator - mc MailClient -} - -// Mail implements mailer.MailClient interface by calling validate before -// passing the mail request to the next MailClient. -func (o *emailValidatorMailClient) Mail( - ctx context.Context, - to string, - subjectTemplate string, - templateURL string, - defaultTemplate string, - templateData map[string]any, - headers map[string][]string, - typ string, -) error { - if err := o.ev.Validate(ctx, to); err != nil { - return err - } - return o.mc.Mail( - ctx, - to, - subjectTemplate, - templateURL, - defaultTemplate, - templateData, - headers, - typ, - ) -} - -// NewMailerWithClient returns a new Mailer that will use the given MailClient. -func NewMailerWithClient( - globalConfig *conf.GlobalConfiguration, - mailClient MailClient, -) Mailer { - return &TemplateMailer{ - SiteURL: globalConfig.SiteURL, - Config: globalConfig, - MailClient: mailClient, - } -} - -// NewMailClient returns a new MailClient based on the given configuration. -func NewMailClient(globalConfig *conf.GlobalConfiguration) MailClient { - mc := newMailClient(globalConfig) - - // Check if email validation is enabled - ev := newEmailValidator(globalConfig.Mailer) - if ev.isEnabled() { - mc = &emailValidatorMailClient{ev: ev, mc: mc} - } - - // Check if background emails are enabled - if globalConfig.Mailer.EmailBackgroundSending { - mc = &backgroundMailClient{ - mc: mc, - } - } - return mc -} - -// newMailClient returns a new MailClient based on the given configuration. -func newMailClient(globalConfig *conf.GlobalConfiguration) MailClient { - if globalConfig.SMTP.Host == "" { - logrus.Infof("Noop mail client being used for %v", globalConfig.SiteURL) - return &noopMailClient{ - EmailValidator: newEmailValidator(globalConfig.Mailer), - } - } - - from := globalConfig.SMTP.FromAddress() - u, _ := url.ParseRequestURI(globalConfig.API.ExternalURL) - return &MailmeMailer{ - Host: globalConfig.SMTP.Host, - Port: globalConfig.SMTP.Port, - User: globalConfig.SMTP.User, - Pass: globalConfig.SMTP.Pass, - LocalName: u.Hostname(), - From: from, - BaseURL: globalConfig.SiteURL, - Logger: logrus.StandardLogger(), - MailLogging: globalConfig.SMTP.LoggingEnabled, - } -} - -func withDefault(value, defaultValue string) string { - if value == "" { - return defaultValue - } - return value -} - -func getPath(filepath string, params *EmailParams) (*url.URL, error) { - path := &url.URL{} - if filepath != "" { - if p, err := url.Parse(filepath); err != nil { - return nil, err - } else { - path = p - } - } - if params != nil { - path.RawQuery = fmt.Sprintf("token=%s&type=%s&redirect_to=%s", url.QueryEscape(params.Token), url.QueryEscape(params.Type), encodeRedirectURL(params.RedirectTo)) - } - return path, nil -} - -// Task holds a mail pending delivery by the Handler. -type Task struct { - mc MailClient - - To string `json:"to"` - SubjectTemplate string `json:"subject_template"` - TemplateURL string `json:"template_url"` - DefaultTemplate string `json:"default_template"` - TemplateData map[string]any `json:"template_data"` - Headers map[string][]string `json:"headers"` - Typ string `json:"typ"` -} - -// Run implements the Type method of the apitask.Task interface by returning -// the "mailer." prefix followed by the mail type. -func (o *Task) Type() string { return fmt.Sprintf("mailer.%v", o.Typ) } - -// Run implements the Run method of the apitask.Task interface by attempting -// to send the mail using the underying mail client. -func (o *Task) Run(ctx context.Context) error { - return o.mc.Mail( - ctx, - o.To, - o.SubjectTemplate, - o.TemplateURL, - o.DefaultTemplate, - o.TemplateData, - o.Headers, - o.Typ) -} - -type backgroundMailClient struct { - mc MailClient -} - -// Mail implements mailer.MailClient interface by sending the call to the -// wrapped mail client to the background. -func (o *backgroundMailClient) Mail( - ctx context.Context, - to string, - subjectTemplate string, - templateURL string, - defaultTemplate string, - templateData map[string]any, - headers map[string][]string, - typ string, -) error { - tk := &Task{ - mc: o.mc, - To: to, - SubjectTemplate: subjectTemplate, - TemplateURL: templateURL, - DefaultTemplate: defaultTemplate, - TemplateData: templateData, - Headers: headers, - Typ: typ, - } - return apitask.Run(ctx, tk) -} diff --git a/internal/mailer/mailme.go b/internal/mailer/mailmeclient/mailmeclient.go similarity index 66% rename from internal/mailer/mailme.go rename to internal/mailer/mailmeclient/mailmeclient.go index 7a62cf151..12d48e338 100644 --- a/internal/mailer/mailme.go +++ b/internal/mailer/mailmeclient/mailmeclient.go @@ -1,4 +1,6 @@ -package mailer +// Package mailmeclient provides an implementation of mailer.Client that uses +// gopkg.in/gomail.v2 to send via SMTP. +package mailmeclient import ( "bytes" @@ -8,6 +10,7 @@ import ( "io" "log" "net/http" + "net/url" "strings" "sync" "time" @@ -15,16 +18,17 @@ import ( "gopkg.in/gomail.v2" "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" ) -// TemplateRetries is the amount of time MailMe will try to fetch a URL before giving up -const TemplateRetries = 3 +// templateRetries is the amount of time MailMe will try to fetch a URL before giving up +const templateRetries = 3 -// TemplateExpiration is the time period that the template will be cached for -const TemplateExpiration = 10 * time.Second +// templateExpiration is the time period that the template will be cached for +const templateExpiration = 10 * time.Second -// MailmeMailer lets MailMe send templated mails -type MailmeMailer struct { +// Client lets MailMe send templated mails +type Client struct { From string Host string Port int @@ -33,26 +37,44 @@ type MailmeMailer struct { BaseURL string LocalName string FuncMap template.FuncMap - cache *TemplateCache Logger logrus.FieldLogger MailLogging bool + + cache *templateCache +} + +// New returns a new *Mailer based on the given configuration. +func New(globalConfig *conf.GlobalConfiguration) *Client { + from := globalConfig.SMTP.FromAddress() + u, _ := url.ParseRequestURI(globalConfig.API.ExternalURL) + return &Client{ + Host: globalConfig.SMTP.Host, + Port: globalConfig.SMTP.Port, + User: globalConfig.SMTP.User, + Pass: globalConfig.SMTP.Pass, + LocalName: u.Hostname(), + From: from, + BaseURL: globalConfig.SiteURL, + Logger: logrus.StandardLogger(), + MailLogging: globalConfig.SMTP.LoggingEnabled, + } } // Mail sends a templated mail. It will try to load the template from a URL, and // otherwise fall back to the default -func (m *MailmeMailer) Mail( +func (m *Client) Mail( ctx context.Context, to, subjectTemplate, templateURL, defaultTemplate string, - templateData map[string]interface{}, + templateData map[string]any, headers map[string][]string, typ string, ) error { if m.FuncMap == nil { - m.FuncMap = map[string]interface{}{} + m.FuncMap = map[string]any{} } if m.cache == nil { - m.cache = &TemplateCache{ - templates: map[string]*MailTemplate{}, + m.cache = &templateCache{ + templates: map[string]*mailTemplate{}, funcMap: m.FuncMap, logger: m.Logger, } @@ -69,7 +91,7 @@ func (m *MailmeMailer) Mail( return err } - body, err := m.MailBody(templateURL, defaultTemplate, templateData) + body, err := m.mailBody(templateURL, defaultTemplate, templateData) if err != nil { return err } @@ -109,37 +131,37 @@ func (m *MailmeMailer) Mail( return nil } -type MailTemplate struct { +type mailTemplate struct { tmp *template.Template expiresAt time.Time } -type TemplateCache struct { - templates map[string]*MailTemplate +type templateCache struct { + templates map[string]*mailTemplate mutex sync.Mutex funcMap template.FuncMap logger logrus.FieldLogger } -func (t *TemplateCache) Get(url string) (*template.Template, error) { +func (t *templateCache) Get(url string) (*template.Template, error) { cached, ok := t.templates[url] if ok && (cached.expiresAt.Before(time.Now())) { return cached.tmp, nil } - data, err := t.fetchTemplate(url, TemplateRetries) + data, err := t.fetchTemplate(url, templateRetries) if err != nil { return nil, err } - return t.Set(url, data, TemplateExpiration) + return t.Set(url, data, templateExpiration) } -func (t *TemplateCache) Set(key, value string, expirationTime time.Duration) (*template.Template, error) { +func (t *templateCache) Set(key, value string, expirationTime time.Duration) (*template.Template, error) { parsed, err := template.New(key).Funcs(t.funcMap).Parse(value) if err != nil { return nil, err } - cached := &MailTemplate{ + cached := &mailTemplate{ tmp: parsed, expiresAt: time.Now().Add(expirationTime), } @@ -149,7 +171,7 @@ func (t *TemplateCache) Set(key, value string, expirationTime time.Duration) (*t return parsed, nil } -func (t *TemplateCache) fetchTemplate(url string, triesLeft int) (string, error) { +func (t *templateCache) fetchTemplate(url string, triesLeft int) (string, error) { client := &http.Client{ Timeout: 10 * time.Second, } @@ -178,12 +200,12 @@ func (t *TemplateCache) fetchTemplate(url string, triesLeft int) (string, error) return "", errors.New("mailer: unable to fetch mail template") } -func (m *MailmeMailer) MailBody(url string, defaultTemplate string, data map[string]interface{}) (string, error) { +func (m *Client) mailBody(url string, defaultTemplate string, data map[string]any) (string, error) { if m.FuncMap == nil { - m.FuncMap = map[string]interface{}{} + m.FuncMap = map[string]any{} } if m.cache == nil { - m.cache = &TemplateCache{templates: map[string]*MailTemplate{}, funcMap: m.FuncMap} + m.cache = &templateCache{templates: map[string]*mailTemplate{}, funcMap: m.FuncMap} } var temp *template.Template diff --git a/internal/mailer/noop.go b/internal/mailer/noopclient/noopclient.go similarity index 56% rename from internal/mailer/noop.go rename to internal/mailer/noopclient/noopclient.go index 1179df89b..927cf978d 100644 --- a/internal/mailer/noop.go +++ b/internal/mailer/noopclient/noopclient.go @@ -1,4 +1,6 @@ -package mailer +// Package noopclient provides an implementation of mailer.Client that simply +// does nothing. +package noopclient import ( "context" @@ -6,15 +8,18 @@ import ( "time" ) -type noopMailClient struct { - EmailValidator *EmailValidator - Delay time.Duration +type Client struct { + Delay time.Duration } -func (m *noopMailClient) Mail( +func New() *Client { + return &Client{} +} + +func (m *Client) Mail( ctx context.Context, to, subjectTemplate, templateURL, defaultTemplate string, - templateData map[string]interface{}, + templateData map[string]any, headers map[string][]string, typ string, ) error { @@ -29,11 +34,5 @@ func (m *noopMailClient) Mail( return ctx.Err() } } - - if m.EmailValidator != nil { - if err := m.EmailValidator.Validate(ctx, to); err != nil { - return err - } - } return nil } diff --git a/internal/mailer/taskclient/taskclient.go b/internal/mailer/taskclient/taskclient.go new file mode 100644 index 000000000..b8e63ef39 --- /dev/null +++ b/internal/mailer/taskclient/taskclient.go @@ -0,0 +1,84 @@ +// Package taskclient provides an implementation of mailer.Client that uses +// the apitask package to send mail in the background. +package taskclient + +import ( + "context" + "fmt" + + "github.com/supabase/auth/internal/api/apitask" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer" +) + +// New will return a Client that runs a task in the background that will later +// call the given Client. If the mailer config EmailBackgroundSending is +// disabled it will return the same Client passed in mc. +func New(globalConfig *conf.GlobalConfiguration, mc mailer.Client) mailer.Client { + + // Check if background emails are enabled + if globalConfig.Mailer.EmailBackgroundSending { + mc = &backgroundMailClient{mc: mc} + } + return mc +} + +// Task holds a mail pending delivery by the Handler. +type Task struct { + mc mailer.Client + + To string `json:"to"` + SubjectTemplate string `json:"subject_template"` + TemplateURL string `json:"template_url"` + DefaultTemplate string `json:"default_template"` + TemplateData map[string]any `json:"template_data"` + Headers map[string][]string `json:"headers"` + Typ string `json:"typ"` +} + +// Run implements the Type method of the apitask.Task interface by returning +// the "mailer." prefix followed by the mail type. +func (o *Task) Type() string { return fmt.Sprintf("mailer.%v", o.Typ) } + +// Run implements the Run method of the apitask.Task interface by attempting +// to send the mail using the underying mail client. +func (o *Task) Run(ctx context.Context) error { + return o.mc.Mail( + ctx, + o.To, + o.SubjectTemplate, + o.TemplateURL, + o.DefaultTemplate, + o.TemplateData, + o.Headers, + o.Typ) +} + +type backgroundMailClient struct { + mc mailer.Client +} + +// Mail implements mailer.MailClient interface by sending the call to the +// wrapped mail client to the background. +func (o *backgroundMailClient) Mail( + ctx context.Context, + to string, + subjectTemplate string, + templateURL string, + defaultTemplate string, + templateData map[string]any, + headers map[string][]string, + typ string, +) error { + tk := &Task{ + mc: o.mc, + To: to, + SubjectTemplate: subjectTemplate, + TemplateURL: templateURL, + DefaultTemplate: defaultTemplate, + TemplateData: templateData, + Headers: headers, + Typ: typ, + } + return apitask.Run(ctx, tk) +} diff --git a/internal/mailer/template.go b/internal/mailer/templatemailer/templatemailer.go similarity index 83% rename from internal/mailer/template.go rename to internal/mailer/templatemailer/templatemailer.go index b6b800100..9a9eb78b8 100644 --- a/internal/mailer/template.go +++ b/internal/mailer/templatemailer/templatemailer.go @@ -1,4 +1,4 @@ -package mailer +package templatemailer import ( "context" @@ -8,37 +8,45 @@ import ( "strings" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer" "github.com/supabase/auth/internal/models" ) -type MailRequest struct { - To string - SubjectTemplate string - TemplateURL string - DefaultTemplate string - TemplateData map[string]interface{} - Headers map[string][]string - Type string -} - -type MailClient interface { - Mail( - ctx context.Context, - to string, - subjectTemplate string, - templateURL string, - defaultTemplate string, - templateData map[string]interface{}, - headers map[string][]string, - typ string, - ) error +// New will return a *TemplateMailer backed by the given mailer.Client. +func New(globalConfig *conf.GlobalConfiguration, mc mailer.Client) *TemplateMailer { + return &TemplateMailer{ + SiteURL: globalConfig.SiteURL, + Config: globalConfig, + mc: mc, + } } // TemplateMailer will send mail and use templates from the site for easy mail styling type TemplateMailer struct { - SiteURL string - Config *conf.GlobalConfiguration - MailClient MailClient + SiteURL string + Config *conf.GlobalConfiguration + mc mailer.Client +} + +type emailParams struct { + Token string + Type string + RedirectTo string +} + +func getPath(filepath string, params *emailParams) (*url.URL, error) { + path := &url.URL{} + if filepath != "" { + if p, err := url.Parse(filepath); err != nil { + return nil, err + } else { + path = p + } + } + if params != nil { + path.RawQuery = fmt.Sprintf("token=%s&type=%s&redirect_to=%s", url.QueryEscape(params.Token), url.QueryEscape(params.Type), encodeRedirectURL(params.RedirectTo)) + } + return path, nil } func encodeRedirectURL(referrerURL string) string { @@ -53,17 +61,12 @@ func encodeRedirectURL(referrerURL string) string { return referrerURL } -const ( - SignupVerification = "signup" - RecoveryVerification = "recovery" - InviteVerification = "invite" - MagicLinkVerification = "magiclink" - EmailChangeVerification = "email_change" - EmailOTPVerification = "email" - EmailChangeCurrentVerification = "email_change_current" - EmailChangeNewVerification = "email_change_new" - ReauthenticationVerification = "reauthentication" -) +func withDefault(value, defaultValue string) string { + if value == "" { + return defaultValue + } + return value +} const defaultInviteMail = `

You have been invited

@@ -137,7 +140,7 @@ func (m *TemplateMailer) Headers(messageType string) map[string][]string { // InviteMail sends a invite mail to a new user func (m *TemplateMailer) InviteMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { - path, err := getPath(m.Config.Mailer.URLPaths.Invite, &EmailParams{ + path, err := getPath(m.Config.Mailer.URLPaths.Invite, &emailParams{ Token: user.ConfirmationToken, Type: "invite", RedirectTo: referrerURL, @@ -147,7 +150,7 @@ func (m *TemplateMailer) InviteMail(r *http.Request, user *models.User, otp, ref return err } - data := map[string]interface{}{ + data := map[string]any{ "SiteURL": m.Config.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, @@ -157,7 +160,7 @@ func (m *TemplateMailer) InviteMail(r *http.Request, user *models.User, otp, ref "RedirectTo": referrerURL, } - return m.MailClient.Mail( + return m.mc.Mail( r.Context(), user.GetEmail(), withDefault(m.Config.Mailer.Subjects.Invite, "You have been invited"), @@ -171,7 +174,7 @@ func (m *TemplateMailer) InviteMail(r *http.Request, user *models.User, otp, ref // ConfirmationMail sends a signup confirmation mail to a new user func (m *TemplateMailer) ConfirmationMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { - path, err := getPath(m.Config.Mailer.URLPaths.Confirmation, &EmailParams{ + path, err := getPath(m.Config.Mailer.URLPaths.Confirmation, &emailParams{ Token: user.ConfirmationToken, Type: "signup", RedirectTo: referrerURL, @@ -180,7 +183,7 @@ func (m *TemplateMailer) ConfirmationMail(r *http.Request, user *models.User, ot return err } - data := map[string]interface{}{ + data := map[string]any{ "SiteURL": m.Config.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, @@ -190,7 +193,7 @@ func (m *TemplateMailer) ConfirmationMail(r *http.Request, user *models.User, ot "RedirectTo": referrerURL, } - return m.MailClient.Mail( + return m.mc.Mail( r.Context(), user.GetEmail(), withDefault(m.Config.Mailer.Subjects.Confirmation, "Confirm Your Email"), @@ -204,14 +207,14 @@ func (m *TemplateMailer) ConfirmationMail(r *http.Request, user *models.User, ot // ReauthenticateMail sends a reauthentication mail to an authenticated user func (m *TemplateMailer) ReauthenticateMail(r *http.Request, user *models.User, otp string) error { - data := map[string]interface{}{ + data := map[string]any{ "SiteURL": m.Config.SiteURL, "Email": user.Email, "Token": otp, "Data": user.UserMetaData, } - return m.MailClient.Mail( + return m.mc.Mail( r.Context(), user.GetEmail(), withDefault(m.Config.Mailer.Subjects.Reauthentication, "Confirm reauthentication"), @@ -260,7 +263,7 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp for _, email := range emails { path, err := getPath( m.Config.Mailer.URLPaths.EmailChange, - &EmailParams{ + &emailParams{ Token: email.TokenHash, Type: "email_change", RedirectTo: referrerURL, @@ -270,7 +273,7 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp return err } go func(address, token, tokenHash, template string) { - data := map[string]interface{}{ + data := map[string]any{ "SiteURL": m.Config.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.GetEmail(), @@ -281,7 +284,7 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp "Data": user.UserMetaData, "RedirectTo": referrerURL, } - errors <- m.MailClient.Mail( + errors <- m.mc.Mail( ctx, address, withDefault(m.Config.Mailer.Subjects.EmailChange, "Confirm Email Change"), @@ -305,7 +308,7 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp // RecoveryMail sends a password recovery mail func (m *TemplateMailer) RecoveryMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { - path, err := getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + path, err := getPath(m.Config.Mailer.URLPaths.Recovery, &emailParams{ Token: user.RecoveryToken, Type: "recovery", RedirectTo: referrerURL, @@ -313,7 +316,7 @@ func (m *TemplateMailer) RecoveryMail(r *http.Request, user *models.User, otp, r if err != nil { return err } - data := map[string]interface{}{ + data := map[string]any{ "SiteURL": m.Config.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, @@ -323,7 +326,7 @@ func (m *TemplateMailer) RecoveryMail(r *http.Request, user *models.User, otp, r "RedirectTo": referrerURL, } - return m.MailClient.Mail( + return m.mc.Mail( r.Context(), user.GetEmail(), withDefault(m.Config.Mailer.Subjects.Recovery, "Reset Your Password"), @@ -337,7 +340,7 @@ func (m *TemplateMailer) RecoveryMail(r *http.Request, user *models.User, otp, r // MagicLinkMail sends a login link mail func (m *TemplateMailer) MagicLinkMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { - path, err := getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + path, err := getPath(m.Config.Mailer.URLPaths.Recovery, &emailParams{ Token: user.RecoveryToken, Type: "magiclink", RedirectTo: referrerURL, @@ -346,7 +349,7 @@ func (m *TemplateMailer) MagicLinkMail(r *http.Request, user *models.User, otp, return err } - data := map[string]interface{}{ + data := map[string]any{ "SiteURL": m.Config.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, @@ -356,7 +359,7 @@ func (m *TemplateMailer) MagicLinkMail(r *http.Request, user *models.User, otp, "RedirectTo": referrerURL, } - return m.MailClient.Mail( + return m.mc.Mail( r.Context(), user.GetEmail(), withDefault(m.Config.Mailer.Subjects.MagicLink, "Your Magic Link"), @@ -375,37 +378,37 @@ func (m TemplateMailer) GetEmailActionLink(user *models.User, actionType, referr switch actionType { case "magiclink": - path, err = getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + path, err = getPath(m.Config.Mailer.URLPaths.Recovery, &emailParams{ Token: user.RecoveryToken, Type: "magiclink", RedirectTo: referrerURL, }) case "recovery": - path, err = getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + path, err = getPath(m.Config.Mailer.URLPaths.Recovery, &emailParams{ Token: user.RecoveryToken, Type: "recovery", RedirectTo: referrerURL, }) case "invite": - path, err = getPath(m.Config.Mailer.URLPaths.Invite, &EmailParams{ + path, err = getPath(m.Config.Mailer.URLPaths.Invite, &emailParams{ Token: user.ConfirmationToken, Type: "invite", RedirectTo: referrerURL, }) case "signup": - path, err = getPath(m.Config.Mailer.URLPaths.Confirmation, &EmailParams{ + path, err = getPath(m.Config.Mailer.URLPaths.Confirmation, &emailParams{ Token: user.ConfirmationToken, Type: "signup", RedirectTo: referrerURL, }) case "email_change_current": - path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, &EmailParams{ + path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, &emailParams{ Token: user.EmailChangeTokenCurrent, Type: "email_change", RedirectTo: referrerURL, }) case "email_change_new": - path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, &EmailParams{ + path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, &emailParams{ Token: user.EmailChangeTokenNew, Type: "email_change", RedirectTo: referrerURL, diff --git a/internal/mailer/template_test.go b/internal/mailer/templatemailer/templatemailer_test.go similarity index 98% rename from internal/mailer/template_test.go rename to internal/mailer/templatemailer/templatemailer_test.go index f8fcd7417..22a1c5cf5 100644 --- a/internal/mailer/template_test.go +++ b/internal/mailer/templatemailer/templatemailer_test.go @@ -1,4 +1,4 @@ -package mailer +package templatemailer import ( "testing" diff --git a/internal/mailer/mailer_test.go b/internal/mailer/templatemailer/templatemailer_url_test.go similarity index 96% rename from internal/mailer/mailer_test.go rename to internal/mailer/templatemailer/templatemailer_url_test.go index 290d65dd0..672d37130 100644 --- a/internal/mailer/mailer_test.go +++ b/internal/mailer/templatemailer/templatemailer_url_test.go @@ -1,4 +1,4 @@ -package mailer +package templatemailer import ( "net/url" @@ -15,7 +15,7 @@ func enforceRelativeURL(url string) string { } func TestGetPath(t *testing.T) { - params := EmailParams{ + params := emailParams{ Token: "token", Type: "signup", RedirectTo: "https://example.com", @@ -23,7 +23,7 @@ func TestGetPath(t *testing.T) { cases := []struct { SiteURL string Path string - Params *EmailParams + Params *emailParams Expected string }{ { diff --git a/internal/mailer/validate.go b/internal/mailer/validateclient/validateclient.go similarity index 78% rename from internal/mailer/validate.go rename to internal/mailer/validateclient/validateclient.go index b8231f26e..6a0df04cc 100644 --- a/internal/mailer/validate.go +++ b/internal/mailer/validateclient/validateclient.go @@ -1,4 +1,4 @@ -package mailer +package validateclient import ( "bytes" @@ -13,6 +13,7 @@ import ( "time" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer" "golang.org/x/sync/errgroup" ) @@ -77,15 +78,79 @@ var ( ErrInvalidEmailMX = errors.New("invalid_email_mx") ) -type EmailValidator struct { +// New will return a Client that first calls an email validator before passing +// the mail along to given Client. If email validation is disabled then it +// returns the same Client passed in mc. +func New(globalConfig *conf.GlobalConfiguration, mc mailer.Client) mailer.Client { + + // Check if email validation is enabled + ev := newEmailValidator(globalConfig.Mailer) + if ev.isEnabled() { + mc = &emailValidatorMailClient{ev: ev, mc: mc} + } + return mc +} + +type emailValidatorMailClient struct { + ev *emailValidator + mc mailer.Client +} + +// Mail implements mailer.MailClient interface by calling validate before +// passing the mail request to the next MailClient. +func (o *emailValidatorMailClient) Mail( + ctx context.Context, + to string, + subjectTemplate string, + templateURL string, + defaultTemplate string, + templateData map[string]any, + headers map[string][]string, + typ string, +) error { + if err := o.ev.Validate(ctx, to); err != nil { + return err + } + return o.mc.Mail( + ctx, + to, + subjectTemplate, + templateURL, + defaultTemplate, + templateData, + headers, + typ, + ) +} + +type emailValidator struct { extended bool serviceURL string serviceHeaders map[string][]string blockedMXRecords map[string]bool } -func newEmailValidator(mc conf.MailerConfiguration) *EmailValidator { - return &EmailValidator{ +func (m *emailValidator) MailNew( + ctx context.Context, + to, subject, body string, + headers map[string][]string, + typ string, +) error { + return nil +} + +func (m *emailValidator) Mail( + ctx context.Context, + to, subjectTemplate, templateURL, defaultTemplate string, + templateData map[string]any, + headers map[string][]string, + typ string, +) error { + return nil +} + +func newEmailValidator(mc conf.MailerConfiguration) *emailValidator { + return &emailValidator{ extended: mc.EmailValidationExtended, serviceURL: mc.EmailValidationServiceURL, serviceHeaders: mc.GetEmailValidationServiceHeaders(), @@ -93,12 +158,12 @@ func newEmailValidator(mc conf.MailerConfiguration) *EmailValidator { } } -func (ev *EmailValidator) isEnabled() bool { +func (ev *emailValidator) isEnabled() bool { return ev.isExtendedEnabled() || ev.isServiceEnabled() } -func (ev *EmailValidator) isExtendedEnabled() bool { return ev.extended } -func (ev *EmailValidator) isServiceEnabled() bool { return ev.serviceURL != "" } +func (ev *emailValidator) isExtendedEnabled() bool { return ev.extended } +func (ev *emailValidator) isServiceEnabled() bool { return ev.serviceURL != "" } // Validate performs validation on the given email. // @@ -108,7 +173,7 @@ func (ev *EmailValidator) isServiceEnabled() bool { return ev.serviceURL != "" // // When serviceURL AND serviceKey are non-empty strings it uses the remote // service to determine if the email is valid. -func (ev *EmailValidator) Validate(ctx context.Context, email string) error { +func (ev *emailValidator) Validate(ctx context.Context, email string) error { if !ev.isEnabled() { return nil } @@ -151,7 +216,7 @@ func (ev *EmailValidator) Validate(ctx context.Context, email string) error { // validateStatic will validate the format and do the static checks before // returning the host portion of the email. -func (ev *EmailValidator) validateStatic(email string) (string, error) { +func (ev *emailValidator) validateStatic(email string) (string, error) { if !ev.isExtendedEnabled() { return "", nil } @@ -189,7 +254,7 @@ func (ev *EmailValidator) validateStatic(email string) (string, error) { return host, nil } -func (ev *EmailValidator) validateService(ctx context.Context, email string) error { +func (ev *emailValidator) validateService(ctx context.Context, email string) error { if !ev.isServiceEnabled() { return nil } @@ -244,7 +309,7 @@ func (ev *EmailValidator) validateService(ctx context.Context, email string) err return ErrInvalidEmailAddress } -func (ev *EmailValidator) validateProviders(name, host string) error { +func (ev *emailValidator) validateProviders(name, host string) error { switch host { case "gmail.com": // Based on a sample of internal data, this reduces the number of @@ -259,7 +324,7 @@ func (ev *EmailValidator) validateProviders(name, host string) error { return nil } -func (ev *EmailValidator) validateHost(ctx context.Context, host string) error { +func (ev *emailValidator) validateHost(ctx context.Context, host string) error { mxs, err := validateEmailResolver.LookupMX(ctx, host) if !isHostNotFound(err) { return ev.validateMXRecords(mxs, nil) @@ -274,7 +339,7 @@ func (ev *EmailValidator) validateHost(ctx context.Context, host string) error { return ErrInvalidEmailDNS } -func (ev *EmailValidator) validateMXRecords(mxs []*net.MX, hosts []string) error { +func (ev *emailValidator) validateMXRecords(mxs []*net.MX, hosts []string) error { for _, mx := range mxs { if ev.blockedMXRecords[mx.Host] { return ErrInvalidEmailMX diff --git a/internal/mailer/validate_test.go b/internal/mailer/validateclient/validateclient_test.go similarity index 99% rename from internal/mailer/validate_test.go rename to internal/mailer/validateclient/validateclient_test.go index e9bae8fef..43e2fbc11 100644 --- a/internal/mailer/validate_test.go +++ b/internal/mailer/validateclient/validateclient_test.go @@ -1,4 +1,4 @@ -package mailer +package validateclient import ( "context" From f7c80a76cf70ee22100bc71c5a2faeb565b92cfc Mon Sep 17 00:00:00 2001 From: Chris Stockton Date: Thu, 28 Aug 2025 16:20:09 -0700 Subject: [PATCH 2/3] chore: remove duplicate import due to mailer -> mail alias --- internal/api/mail.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/api/mail.go b/internal/api/mail.go index 96ed8b9ba..963ad9f85 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -9,7 +9,6 @@ import ( "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/hooks/v0hooks" - "github.com/supabase/auth/internal/mailer" mail "github.com/supabase/auth/internal/mailer" "github.com/supabase/auth/internal/mailer/mailmeclient" "github.com/supabase/auth/internal/mailer/noopclient" @@ -33,7 +32,7 @@ import ( // newMailer returns a new gotrue mailer func newMailer(globalConfig *conf.GlobalConfiguration) *templatemailer.TemplateMailer { - var mc mailer.Client + var mc mail.Client if globalConfig.SMTP.Host == "" { logrus.Infof("Noop mail client being used for %v", globalConfig.SiteURL) mc = noopclient.New() From 48d98cec8a71c5b03fe7550d0a6b2ee66661644d Mon Sep 17 00:00:00 2001 From: Chris Stockton <180184+cstockton@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:15:44 -0700 Subject: [PATCH 3/3] feat: background template reloading p2 - mailer template cache & interface refactor (#2150) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Summary** This PR builds on [auth#2148](https://github.com/supabase/auth/pull/2148) to deliver phase 2, a robust template cache plus a simplified mail client interface. Templates are now rendered once per type, cached with controlled refresh, and then delivered via clients that only send pre-rendered subject/body. This reduces latency and removes templating logic from SMTP clients while preserving current behavior. It will also pave the way for alternative Mailer implementations. **Template cache + interface refactor** * Replace `TemplateMailer` with `templatemailer.Mailer` that owns all templating and a shared cache. Entries have a TTL and a shorter “recheck interval”; on fetch/parse errors we serve the last good or fall back to built-in defaults. * Centralize default subjects/bodies and validate them at init. Typed template keys are: `invite`, `confirmation, recovery`, `email_change`, `magic_link`, `reauthentication`. * Move headers logic into the mailer; keep `$messageType` substitution and the legacy “reauthenticate” header name for compatibility. * Change `mailer.Client.Mail` to `(ctx, to, subject, body, headers, typ)` and update all clients accordingly (SMTP/Noop/Validate/Tasks). * Persist a single mailer instance in `API`; add `WithMailer` for DI; remove on-demand construction. --------- Co-authored-by: Chris Stockton --- cmd/serve_cmd.go | 62 +- internal/api/api.go | 12 +- internal/api/apiworker/apiworker.go | 111 ++++ internal/api/mail.go | 26 - internal/api/options.go | 30 +- internal/conf/configuration.go | 20 +- internal/mailer/mailer.go | 7 +- internal/mailer/mailmeclient/mailmeclient.go | 180 +----- internal/mailer/noopclient/noopclient.go | 5 +- internal/mailer/taskclient/taskclient.go | 38 +- internal/mailer/templatemailer/template.go | 601 ++++++++++++++++++ .../mailer/templatemailer/templatemailer.go | 260 ++++---- .../templatemailer/templatemailer_test.go | 15 +- .../mailer/validateclient/validateclient.go | 12 +- 14 files changed, 957 insertions(+), 422 deletions(-) create mode 100644 internal/api/apiworker/apiworker.go create mode 100644 internal/mailer/templatemailer/template.go diff --git a/cmd/serve_cmd.go b/cmd/serve_cmd.go index bc841e849..6b2dee94c 100644 --- a/cmd/serve_cmd.go +++ b/cmd/serve_cmd.go @@ -14,7 +14,9 @@ import ( "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/supabase/auth/internal/api" + "github.com/supabase/auth/internal/api/apiworker" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer/templatemailer" "github.com/supabase/auth/internal/reloader" "github.com/supabase/auth/internal/storage" "github.com/supabase/auth/internal/utilities" @@ -48,22 +50,24 @@ func serve(ctx context.Context) { } defer db.Close() - addr := net.JoinHostPort(config.API.Host, config.API.Port) - - opts := []api.Option{ - api.NewLimiterOptions(config), - } - baseCtx, baseCancel := context.WithCancel(context.Background()) defer baseCancel() var wg sync.WaitGroup defer wg.Wait() // Do not return to caller until this goroutine is done. - a := api.NewAPIWithVersion(config, db, utilities.Version, opts...) - ah := reloader.NewAtomicHandler(a) - logrus.WithField("version", a.Version()).Infof("GoTrue API started on: %s", addr) + mrCache := templatemailer.NewCache() + limiterOpts := api.NewLimiterOptions(config) + initialAPI := api.NewAPIWithVersion( + config, db, utilities.Version, + limiterOpts, + api.WithMailer(templatemailer.FromConfig(config, mrCache)), + ) + addr := net.JoinHostPort(config.API.Host, config.API.Port) + logrus.WithField("version", initialAPI.Version()).Infof("GoTrue API started on: %s", addr) + + ah := reloader.NewAtomicHandler(initialAPI) httpSrv := &http.Server{ Addr: addr, Handler: ah, @@ -74,6 +78,26 @@ func serve(ctx context.Context) { } log := logrus.WithField("component", "api") + wrkLog := logrus.WithField("component", "apiworker") + wrk := apiworker.New(config, mrCache, wrkLog) + wg.Add(1) + go func() { + defer wg.Done() + + var err error + defer func() { + logFn := wrkLog.Info + if err != nil { + logFn = wrkLog.WithError(err).Error + } + logFn("background apiworker is exiting") + }() + + // Work exits when ctx is done as in-flight requests do not depend + // on it. If they do in the future this should be baseCtx instead. + err = wrk.Work(ctx) + }() + if watchDir != "" { wg.Add(1) go func() { @@ -81,8 +105,26 @@ func serve(ctx context.Context) { fn := func(latestCfg *conf.GlobalConfiguration) { log.Info("reloading api with new configuration") + + // When config is updated we notify the apiworker. + wrk.ReloadConfig(latestCfg) + + // Create a new API version with the updated config. latestAPI := api.NewAPIWithVersion( - latestCfg, db, utilities.Version, opts...) + config, db, utilities.Version, + + // Create a new mailer with existing template cache. + api.WithMailer( + templatemailer.FromConfig(config, mrCache), + ), + + // Persist existing rate limiters. + // + // TODO(cstockton): we should consider updating these, if we + // rely on hot config reloads 100% then rate limiter changes + // won't be picked up. + limiterOpts, + ) ah.Store(latestAPI) } diff --git a/internal/api/api.go b/internal/api/api.go index c4fb7b614..19b93eda0 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -16,6 +16,7 @@ import ( "github.com/supabase/auth/internal/hooks/hookspgfunc" "github.com/supabase/auth/internal/hooks/v0hooks" "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/mailer/templatemailer" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/storage" @@ -42,6 +43,7 @@ type API struct { hibpClient *hibp.PwnedClient oauthServer *oauthserver.Server tokenService *tokens.Service + mailer mailer.Mailer // overrideTime can be used to override the clock used by handlers. Should only be used in tests! overrideTime func() time.Time @@ -52,6 +54,7 @@ type API struct { func (a *API) GetConfig() *conf.GlobalConfiguration { return a.config } func (a *API) GetDB() *storage.Connection { return a.db } func (a *API) GetTokenService() *tokens.Service { return a.tokenService } +func (a *API) Mailer() mailer.Mailer { return a.mailer } func (a *API) Version() string { return a.version @@ -109,6 +112,10 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne if api.tokenService == nil { api.tokenService = tokens.NewService(globalConfig, api.hooksMgr) } + if api.mailer == nil { + tc := templatemailer.NewCache() + api.mailer = templatemailer.FromConfig(globalConfig, tc) + } // Connect token service to API's time function (supports test overrides) api.tokenService.SetTimeFunc(api.Now) @@ -368,11 +375,6 @@ func (a *API) HealthCheck(w http.ResponseWriter, r *http.Request) error { }) } -// Mailer returns NewMailer with the current tenant config -func (a *API) Mailer() mailer.Mailer { - return newMailer(a.config) -} - // ServeHTTP implements the http.Handler interface by passing the request along // to its underlying Handler. func (a *API) ServeHTTP(w http.ResponseWriter, r *http.Request) { diff --git a/internal/api/apiworker/apiworker.go b/internal/api/apiworker/apiworker.go new file mode 100644 index 000000000..e6627c9f4 --- /dev/null +++ b/internal/api/apiworker/apiworker.go @@ -0,0 +1,111 @@ +package apiworker + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer/templatemailer" +) + +// Worker is a simple background worker for async tasks. +type Worker struct { + le *logrus.Entry + tc *templatemailer.Cache + + // Notifies worker the cfg has been updated. + cfgCh chan struct{} + + // workMu must be held for calls to Work + workMu sync.Mutex + + // mu must be held for field access below here + mu sync.Mutex + cfg *conf.GlobalConfiguration +} + +// New will return a new *Worker instance. +func New( + cfg *conf.GlobalConfiguration, + tc *templatemailer.Cache, + le *logrus.Entry, +) *Worker { + return &Worker{ + le: le, + cfg: cfg, + tc: tc, + cfgCh: make(chan struct{}, 1), + } +} + +func (o *Worker) putConfig(cfg *conf.GlobalConfiguration) { + o.mu.Lock() + defer o.mu.Unlock() + o.cfg = cfg +} + +func (o *Worker) getConfig() *conf.GlobalConfiguration { + o.mu.Lock() + defer o.mu.Unlock() + return o.cfg +} + +// ReloadConfig notifies the worker a new configuration is available. +func (o *Worker) ReloadConfig(cfg *conf.GlobalConfiguration) { + o.putConfig(cfg) + + select { + case o.cfgCh <- struct{}{}: + default: + } +} + +// Work will periodically reload the templates in the background as long as the +// system remains active. +func (o *Worker) Work(ctx context.Context) error { + if ok := o.workMu.TryLock(); !ok { + return errors.New("apiworker: concurrent calls to Work are invalid") + } + defer o.workMu.Unlock() + + le := o.le.WithFields(logrus.Fields{ + "worker_type": "apiworker_template_cache", + }) + le.Info("apiworker: template cache worker started") + defer le.Info("apiworker: template cache worker exited") + + // Reload templates right away on Work. + o.maybeReloadTemplates(ctx, o.getConfig()) + + ival := func() time.Duration { + return max(time.Second, o.getConfig().Mailer.TemplateRetryInterval/4) + } + + tr := time.NewTicker(ival()) + defer tr.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-o.cfgCh: + tr.Reset(ival()) + case <-tr.C: + } + + // Either ticker fired or we got a config update. + o.maybeReloadTemplates(ctx, o.getConfig()) + } +} + +func (o *Worker) maybeReloadTemplates( + ctx context.Context, + cfg *conf.GlobalConfiguration, +) { + if cfg.Mailer.TemplateReloadingEnabled { + o.tc.Reload(ctx, cfg) + } +} diff --git a/internal/api/mail.go b/internal/api/mail.go index 963ad9f85..cfbe57397 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -6,14 +6,8 @@ import ( "strings" "time" - "github.com/sirupsen/logrus" - "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/hooks/v0hooks" mail "github.com/supabase/auth/internal/mailer" - "github.com/supabase/auth/internal/mailer/mailmeclient" - "github.com/supabase/auth/internal/mailer/noopclient" - "github.com/supabase/auth/internal/mailer/taskclient" - "github.com/supabase/auth/internal/mailer/templatemailer" "github.com/supabase/auth/internal/mailer/validateclient" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -30,26 +24,6 @@ import ( "github.com/supabase/auth/internal/utilities" ) -// newMailer returns a new gotrue mailer -func newMailer(globalConfig *conf.GlobalConfiguration) *templatemailer.TemplateMailer { - var mc mail.Client - if globalConfig.SMTP.Host == "" { - logrus.Infof("Noop mail client being used for %v", globalConfig.SiteURL) - mc = noopclient.New() - } else { - mc = mailmeclient.New(globalConfig) - } - - // Wrap client with validation first - mc = validateclient.New(globalConfig, mc) - - // Then background tasks - mc = taskclient.New(globalConfig, mc) - - // Finally the template mailer - return templatemailer.New(globalConfig, mc) -} - var ( EmailRateLimitExceeded error = errors.New("email rate limit exceeded") ) diff --git a/internal/api/options.go b/internal/api/options.go index a54efe36e..d47aba276 100644 --- a/internal/api/options.go +++ b/internal/api/options.go @@ -6,6 +6,7 @@ import ( "github.com/didip/tollbooth/v5" "github.com/didip/tollbooth/v5/limiter" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer" "github.com/supabase/auth/internal/ratelimit" "github.com/supabase/auth/internal/tokens" ) @@ -14,6 +15,22 @@ type Option interface { apply(*API) } +type optionFunc func(*API) + +func (f optionFunc) apply(a *API) { f(a) } + +func WithMailer(m mailer.Mailer) Option { + return optionFunc(func(a *API) { + a.mailer = m + }) +} + +func WithTokenService(service *tokens.Service) Option { + return optionFunc(func(a *API) { + a.tokenService = service + }) +} + type LimiterOptions struct { Email ratelimit.Limiter Phone ratelimit.Limiter @@ -37,19 +54,6 @@ type LimiterOptions struct { func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo } -// TokenServiceOption allows injecting a custom token service -type TokenServiceOption struct { - service *tokens.Service -} - -func WithTokenService(service *tokens.Service) *TokenServiceOption { - return &TokenServiceOption{service: service} -} - -func (tso *TokenServiceOption) apply(a *API) { - a.tokenService = tso.service -} - func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions { o := &LimiterOptions{} diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 02f5aeaec..e719ef4c6 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -444,13 +444,31 @@ type MailerConfiguration struct { ExternalHosts []string `json:"external_hosts" split_words:"true"` - // EXPERIMENTAL: May be removed in a future release. + // EXPERIMENTAL: All config below here may be removed in a future release. EmailBackgroundSending bool `json:"email_background_sending" split_words:"true" default:"false"` EmailValidationExtended bool `json:"email_validation_extended" split_words:"true" default:"false"` EmailValidationServiceURL string `json:"email_validation_service_url" split_words:"true"` EmailValidationServiceHeaders string `json:"email_validation_service_headers" split_words:"true"` EmailValidationBlockedMX string `json:"email_validation_blocked_mx" split_words:"true"` + // Max size in bytes we will read from a template endpoint + TemplateMaxSize int `json:"template_max_size" split_words:"true" default:"1000000"` + + // The maximum age of a template before we consider it stale. + TemplateMaxAge time.Duration `json:"template_max_age" split_words:"true" default:"10m"` + + // The time between retrying a failed template reload. + TemplateRetryInterval time.Duration `json:"template_retry_interval" split_words:"true" default:"10s"` + + // If true enable background reloading of templates to avoid blocking + // IO in requests. + TemplateReloadingEnabled bool `json:"template_reloading_enabled" split_words:"true" default:"false"` + + // The maximum time a server may be idle before template reloading stops. + // Note that even when the server is idle, a config reload will trigger a + // template reload. + TemplateReloadingMaxIdle time.Duration `json:"template_reloading_max_idle" split_words:"true" default:"20m"` + serviceHeaders map[string][]string `json:"-"` blockedMXRecords map[string]bool `json:"-"` } diff --git a/internal/mailer/mailer.go b/internal/mailer/mailer.go index 0b6ad383a..5c8602578 100644 --- a/internal/mailer/mailer.go +++ b/internal/mailer/mailer.go @@ -31,14 +31,13 @@ type Mailer interface { GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) } +// TODO(cstockton): Mail(...) -> Mail(Email{...}) ? type Client interface { Mail( ctx context.Context, to string, - subjectTemplate string, - templateURL string, - defaultTemplate string, - templateData map[string]any, + subject string, + body string, headers map[string][]string, typ string, ) error diff --git a/internal/mailer/mailmeclient/mailmeclient.go b/internal/mailer/mailmeclient/mailmeclient.go index 12d48e338..0062a04ac 100644 --- a/internal/mailer/mailmeclient/mailmeclient.go +++ b/internal/mailer/mailmeclient/mailmeclient.go @@ -3,17 +3,8 @@ package mailmeclient import ( - "bytes" "context" - "errors" - "html/template" - "io" - "log" - "net/http" "net/url" - "strings" - "sync" - "time" "gopkg.in/gomail.v2" @@ -21,26 +12,17 @@ import ( "github.com/supabase/auth/internal/conf" ) -// templateRetries is the amount of time MailMe will try to fetch a URL before giving up -const templateRetries = 3 - -// templateExpiration is the time period that the template will be cached for -const templateExpiration = 10 * time.Second - // Client lets MailMe send templated mails type Client struct { - From string - Host string - Port int - User string - Pass string - BaseURL string - LocalName string - FuncMap template.FuncMap + From string + Host string + Port int + User string + Pass string + LocalName string + Logger logrus.FieldLogger MailLogging bool - - cache *templateCache } // New returns a new *Mailer based on the given configuration. @@ -54,7 +36,6 @@ func New(globalConfig *conf.GlobalConfiguration) *Client { Pass: globalConfig.SMTP.Pass, LocalName: u.Hostname(), From: from, - BaseURL: globalConfig.SiteURL, Logger: logrus.StandardLogger(), MailLogging: globalConfig.SMTP.LoggingEnabled, } @@ -64,42 +45,16 @@ func New(globalConfig *conf.GlobalConfiguration) *Client { // otherwise fall back to the default func (m *Client) Mail( ctx context.Context, - to, subjectTemplate, templateURL, defaultTemplate string, - templateData map[string]any, + to string, + subject string, + body string, headers map[string][]string, typ string, ) error { - if m.FuncMap == nil { - m.FuncMap = map[string]any{} - } - if m.cache == nil { - m.cache = &templateCache{ - templates: map[string]*mailTemplate{}, - funcMap: m.FuncMap, - logger: m.Logger, - } - } - - tmp, err := template.New("Subject").Funcs(template.FuncMap(m.FuncMap)).Parse(subjectTemplate) - if err != nil { - return err - } - - subject := &bytes.Buffer{} - err = tmp.Execute(subject, templateData) - if err != nil { - return err - } - - body, err := m.mailBody(templateURL, defaultTemplate, templateData) - if err != nil { - return err - } - mail := gomail.NewMessage() mail.SetHeader("From", m.From) mail.SetHeader("To", to) - mail.SetHeader("Subject", subject.String()) + mail.SetHeader("Subject", subject) for k, v := range headers { if v != nil { @@ -130,116 +85,3 @@ func (m *Client) Mail( } return nil } - -type mailTemplate struct { - tmp *template.Template - expiresAt time.Time -} - -type templateCache struct { - templates map[string]*mailTemplate - mutex sync.Mutex - funcMap template.FuncMap - logger logrus.FieldLogger -} - -func (t *templateCache) Get(url string) (*template.Template, error) { - cached, ok := t.templates[url] - if ok && (cached.expiresAt.Before(time.Now())) { - return cached.tmp, nil - } - data, err := t.fetchTemplate(url, templateRetries) - if err != nil { - return nil, err - } - return t.Set(url, data, templateExpiration) -} - -func (t *templateCache) Set(key, value string, expirationTime time.Duration) (*template.Template, error) { - parsed, err := template.New(key).Funcs(t.funcMap).Parse(value) - if err != nil { - return nil, err - } - - cached := &mailTemplate{ - tmp: parsed, - expiresAt: time.Now().Add(expirationTime), - } - t.mutex.Lock() - t.templates[key] = cached - t.mutex.Unlock() - return parsed, nil -} - -func (t *templateCache) fetchTemplate(url string, triesLeft int) (string, error) { - client := &http.Client{ - Timeout: 10 * time.Second, - } - - resp, err := client.Get(url) - if err != nil && triesLeft > 0 { - return t.fetchTemplate(url, triesLeft-1) - } - if err != nil { - return "", err - } - defer resp.Body.Close() - if resp.StatusCode == 200 { // OK - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil && triesLeft > 0 { - return t.fetchTemplate(url, triesLeft-1) - } - if err != nil { - return "", err - } - return string(bodyBytes), err - } - if triesLeft > 0 { - return t.fetchTemplate(url, triesLeft-1) - } - return "", errors.New("mailer: unable to fetch mail template") -} - -func (m *Client) mailBody(url string, defaultTemplate string, data map[string]any) (string, error) { - if m.FuncMap == nil { - m.FuncMap = map[string]any{} - } - if m.cache == nil { - m.cache = &templateCache{templates: map[string]*mailTemplate{}, funcMap: m.FuncMap} - } - - var temp *template.Template - var err error - - if url != "" { - var absoluteURL string - if strings.HasPrefix(url, "http") { - absoluteURL = url - } else { - absoluteURL = m.BaseURL + url - } - temp, err = m.cache.Get(absoluteURL) - if err != nil { - log.Printf("Error loading template from %v: %v\n", url, err) - } - } - - if temp == nil { - cached, ok := m.cache.templates[url] - if ok { - temp = cached.tmp - } else { - temp, err = m.cache.Set(url, defaultTemplate, 0) - if err != nil { - return "", err - } - } - } - - buf := &bytes.Buffer{} - err = temp.Execute(buf, data) - if err != nil { - return "", err - } - return buf.String(), nil -} diff --git a/internal/mailer/noopclient/noopclient.go b/internal/mailer/noopclient/noopclient.go index 927cf978d..e40bab202 100644 --- a/internal/mailer/noopclient/noopclient.go +++ b/internal/mailer/noopclient/noopclient.go @@ -18,8 +18,9 @@ func New() *Client { func (m *Client) Mail( ctx context.Context, - to, subjectTemplate, templateURL, defaultTemplate string, - templateData map[string]any, + to string, + subject string, + body string, headers map[string][]string, typ string, ) error { diff --git a/internal/mailer/taskclient/taskclient.go b/internal/mailer/taskclient/taskclient.go index b8e63ef39..e966d3c16 100644 --- a/internal/mailer/taskclient/taskclient.go +++ b/internal/mailer/taskclient/taskclient.go @@ -27,13 +27,11 @@ func New(globalConfig *conf.GlobalConfiguration, mc mailer.Client) mailer.Client type Task struct { mc mailer.Client - To string `json:"to"` - SubjectTemplate string `json:"subject_template"` - TemplateURL string `json:"template_url"` - DefaultTemplate string `json:"default_template"` - TemplateData map[string]any `json:"template_data"` - Headers map[string][]string `json:"headers"` - Typ string `json:"typ"` + To string `json:"to"` + Subject string `json:"subject"` + Body string `json:"body"` + Headers map[string][]string `json:"headers"` + Typ string `json:"typ"` } // Run implements the Type method of the apitask.Task interface by returning @@ -46,10 +44,8 @@ func (o *Task) Run(ctx context.Context) error { return o.mc.Mail( ctx, o.To, - o.SubjectTemplate, - o.TemplateURL, - o.DefaultTemplate, - o.TemplateData, + o.Subject, + o.Body, o.Headers, o.Typ) } @@ -63,22 +59,18 @@ type backgroundMailClient struct { func (o *backgroundMailClient) Mail( ctx context.Context, to string, - subjectTemplate string, - templateURL string, - defaultTemplate string, - templateData map[string]any, + subject string, + body string, headers map[string][]string, typ string, ) error { tk := &Task{ - mc: o.mc, - To: to, - SubjectTemplate: subjectTemplate, - TemplateURL: templateURL, - DefaultTemplate: defaultTemplate, - TemplateData: templateData, - Headers: headers, - Typ: typ, + mc: o.mc, + To: to, + Subject: subject, + Body: body, + Headers: headers, + Typ: typ, } return apitask.Run(ctx, tk) } diff --git a/internal/mailer/templatemailer/template.go b/internal/mailer/templatemailer/template.go new file mode 100644 index 000000000..0263a7700 --- /dev/null +++ b/internal/mailer/templatemailer/template.go @@ -0,0 +1,601 @@ +package templatemailer + +import ( + "bytes" + "context" + "fmt" + "html/template" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/mailer/mailmeclient" + "github.com/supabase/auth/internal/mailer/noopclient" + "github.com/supabase/auth/internal/mailer/taskclient" + "github.com/supabase/auth/internal/mailer/validateclient" + "github.com/supabase/auth/internal/observability" + "golang.org/x/sync/singleflight" +) + +func init() { + // Ensure every TemplateType has a default subject & body. + if err := checkDefaults(); err != nil { + panic(err) + } +} + +// Mailer will send mail and use templates from the site for easy mail styling +type Mailer struct { + cfg *conf.GlobalConfiguration + mc mailer.Client + tc *Cache +} + +// FromConfig returns a new mailer configured using the global configuration. +func FromConfig(globalConfig *conf.GlobalConfiguration, tc *Cache) *Mailer { + var mc mailer.Client + if globalConfig.SMTP.Host == "" { + logrus.Infof("Noop mail client being used for %v", globalConfig.SiteURL) + mc = noopclient.New() + } else { + mc = mailmeclient.New(globalConfig) + } + + // Wrap client with validation first + mc = validateclient.New(globalConfig, mc) + + // Then background tasks + mc = taskclient.New(globalConfig, mc) + + // Finally the template mailer + return New(globalConfig, mc, tc) +} + +// New will return a *TemplateMailer backed by the given mailer.Client. +func New(globalConfig *conf.GlobalConfiguration, mc mailer.Client, tc *Cache) *Mailer { + return &Mailer{ + cfg: globalConfig, + mc: mc, + tc: tc, + } +} + +func (m *Mailer) mail( + ctx context.Context, + cfg *conf.GlobalConfiguration, + tpl string, + to string, + data map[string]any, +) error { + if _, ok := lookupEmailContentConfig(&cfg.Mailer.Subjects, tpl); !ok { + return fmt.Errorf("templatemailer: template type: %s is invalid", tpl) + } + + // This is to match the previous behavior, which sent a "reauthenticate" + // header instead of the same name as template. + typ := tpl + if typ == ReauthenticationTemplate { + typ = "reauthenticate" + } + headers := m.Headers(cfg, typ) + + ent, err := m.tc.get(ctx, cfg, tpl) + if err != nil { + return err + } + + var buf bytes.Buffer + subject, body, err := ent.execute(&buf, data) + if err != nil { + return err + } + return m.mc.Mail( + ctx, + to, + subject, + body, + headers, + typ, + ) +} + +type tplCacheEntry struct { + createdAt time.Time + checkedAt time.Time + def bool + typ string + subject *template.Template + body *template.Template +} + +func newTplCacheEntry( + at time.Time, + typ string, + subject, body *template.Template, +) *tplCacheEntry { + return &tplCacheEntry{ + createdAt: at, + checkedAt: at, + typ: typ, + subject: subject, + body: body, + } +} + +func (ent *tplCacheEntry) copy() *tplCacheEntry { + cpy := *ent + return &cpy +} + +func (ent *tplCacheEntry) execute( + buf *bytes.Buffer, + data map[string]any, +) (subject string, body string, err error) { + if err = ent.subject.Execute(buf, data); err != nil { + return "", "", err + } + subject = buf.String() + + buf.Reset() + if err = ent.body.Execute(buf, data); err != nil { + return "", "", err + } + body = buf.String() + return subject, body, nil +} + +type Cache struct { + sf singleflight.Group + now func() time.Time + + // Must hold rw for below field access + rw sync.RWMutex + m map[string]*tplCacheEntry // map[TemplateType]*tplCacheEntry + t time.Time // Time of the most recent call to getEntry +} + +func NewCache() *Cache { + return &Cache{ + m: make(map[string]*tplCacheEntry), + now: time.Now, + } +} + +func (o *Cache) Reload( + ctx context.Context, + cfg *conf.GlobalConfiguration, +) { + now := o.now() + touchedAt := o.getTouchedAt() + + // If the touchedAt time is zero we will eagerly reload. Note we must set + // the touch time to prevent a server that has never had a request from + // from reloading indefinitely. + if touchedAt.IsZero() { + defer o.setTouchedAt(now) + + o.reloadAt(ctx, cfg, now) + return + } + + // If the server has been idle for maxIdle time, we stop updating the + // templates until the next mail request comes through. + maxIdle := cfg.Mailer.TemplateReloadingMaxIdle + if now.Sub(touchedAt) >= maxIdle { + return + } + + o.reloadAt(ctx, cfg, now) +} + +func (o *Cache) reloadAt( + ctx context.Context, + cfg *conf.GlobalConfiguration, + now time.Time, +) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + wg := new(sync.WaitGroup) + defer wg.Wait() + + for _, typ := range templateTypes { + ent, ok := o.getEntry(typ) + if !ok { + // Cache miss, straight to load with no current entry. + o.reloadType(ctx, cfg, wg, typ, nil) + continue + } + + // Before we eagerly reload the template we first make sure we are + // approaching it's expiration. The goal is to never have the requests + // block on the singleflight during regular mail requests. + // + // The def flag signals that the template is the default template. We + // skip this check if it's currently set to true, as we want to get a + // new template as soon soon as possible. + maxAge := cfg.Mailer.TemplateMaxAge - (cfg.Mailer.TemplateMaxAge / 10) + if !ent.def && now.Sub(ent.createdAt) < maxAge { + continue + } + + // We are approaching the expiration and need to eagerly reload. Before + // making the request we make sure we haven't recently checked the template + // using our ival configuration knob. This is just a simple way to give + // endpoints some breathing room instead of expo backoff with counters. + retryIval := cfg.Mailer.TemplateRetryInterval + if now.Sub(ent.checkedAt) < retryIval { + continue + } + + // This template type is eligible for reload. + o.reloadType(ctx, cfg, wg, typ, ent) + } +} + +func (o *Cache) reloadType( + ctx context.Context, + cfg *conf.GlobalConfiguration, + wg *sync.WaitGroup, + typ string, + cur *tplCacheEntry, +) { + wg.Add(1) + go func(typ string) { + defer wg.Done() + + ent, err := o.load(ctx, cfg, typ, cur) + if err != nil { + return + } + + if cur == nil || cur.createdAt != ent.createdAt { + le := observability.GetLogEntryFromContext(ctx).Entry + le.WithFields(logrus.Fields{ + "event": "templatemailer_reloader_template_update", + "mail_type": typ, + }).Infof("mailer: reloaded template type: %v", typ) + } + }(typ) +} + +func (o *Cache) getTouchedAt() time.Time { + o.rw.RLock() + defer o.rw.RUnlock() + return o.t +} + +func (o *Cache) setTouchedAt(at time.Time) { + o.rw.Lock() + defer o.rw.Unlock() + o.t = at +} + +func (o *Cache) getEntry(typ string) (*tplCacheEntry, bool) { + o.rw.RLock() + defer o.rw.RUnlock() + v, ok := o.m[typ] + return v, ok +} + +func (o *Cache) getEntryAndTouchAt(typ string, at time.Time) (*tplCacheEntry, bool) { + o.rw.Lock() + defer o.rw.Unlock() + o.t = at + v, ok := o.m[typ] + return v, ok +} + +func (o *Cache) putEntry(typ string, ent *tplCacheEntry) { + o.rw.Lock() + defer o.rw.Unlock() + o.m[typ] = ent +} + +// get is the method called to fetch an entry from the cache. +func (o *Cache) get( + ctx context.Context, + cfg *conf.GlobalConfiguration, + typ string, +) (*tplCacheEntry, error) { + now := o.now() + ent, ok := o.getEntryAndTouchAt(typ, now) + if !ok { + // Cache miss, straight to load with no current entry. + return o.load(ctx, cfg, typ, nil) + } + + maxAge := cfg.Mailer.TemplateMaxAge + if now.Sub(ent.createdAt) < maxAge { + // Cache hit and the entry is not expired, return it. + return ent, nil + } + + // Entry is expired, we check if the entry is ready for reloading. We do + // as much as we can outside of load to prevent synchronization on o.sf. + retryIval := cfg.Mailer.TemplateRetryInterval + if now.Sub(ent.checkedAt) < retryIval { + // Entry was checked within maxIval, return it. + return ent, nil + } + + // Call load with our most recent entry. + return o.load(ctx, cfg, typ, ent) +} + +// load is what happens when "get" has a cache miss, the hit has expired or +// the a previously failed check has elapsed the ival. +func (o *Cache) load( + ctx context.Context, + cfg *conf.GlobalConfiguration, + typ string, + cur *tplCacheEntry, +) (*tplCacheEntry, error) { + + // Before load returns, forget the most recent result of sf. Because we + // write our cache result in Do we guarantee that the next call to SF + // after this function returns will be a cache hit. + defer o.sf.Forget(typ) + + // We prevent a recently restarted auth server from sending multiple + // concurrent requests to the templating endpoint with pkg singleflight. + v, err, _ := o.sf.Do(typ, func() (any, error) { + + // First try to load a fresh entry. + ent, err := o.loadEntry(ctx, cfg, typ) + if err == nil { + // No error fetching fresh entry, put in cache & return it. + o.putEntry(typ, ent) + return ent, nil + } + + // We had an err loading a fresh entry. Check if we had a current entry + // and return a copy of that with a last checked time. + if cur != nil { + cpy := cur.copy() + cpy.checkedAt = o.now() + + o.putEntry(typ, cpy) + return cpy, nil + } + + // We have no previous entry and no fresh entry, we will load the + // default templates so the mailer can continue serving requests. + ent = o.loadEntryDefault(typ) + o.putEntry(typ, ent) + return ent, nil + }) + if err != nil { + // I don't believe SF returns an error unless the fn it calls does, so + // this is mostly a defensive check. + err = wrapError(ctx, typ, "internal_error", err) + return nil, err + } + + // v is always a *tplCacheEntry + return v.(*tplCacheEntry), nil +} + +// loadEntry returns the +func (o *Cache) loadEntry( + ctx context.Context, + cfg *conf.GlobalConfiguration, + typ string, +) (*tplCacheEntry, error) { + subjectTemp, err := o.loadEntrySubject(ctx, cfg, typ) + if err != nil { + return nil, err + } + + bodyTemp, err := o.loadEntryBody(ctx, cfg, typ) + if err != nil { + return nil, err + } + + now := o.now() + ent := newTplCacheEntry(now, typ, subjectTemp, bodyTemp) + return ent, nil +} + +// loadEntryDefault will never fail due to the checkDefaults() in init(). +func (o *Cache) loadEntryDefault( + typ string, +) *tplCacheEntry { + subjectStr := getEmailContentConfig(defaultTemplateSubjects, typ, "") + subjectTemp := template.Must(template.New("").Parse(subjectStr)) + + bodyStr := getEmailContentConfig(defaultTemplateBodies, typ, "") + bodyTemp := template.Must(template.New("").Parse(bodyStr)) + + now := o.now() + ent := newTplCacheEntry(now, typ, subjectTemp, bodyTemp) + ent.def = true + return ent +} + +func (o *Cache) loadEntrySubject( + ctx context.Context, + cfg *conf.GlobalConfiguration, + typ string, +) (*template.Template, error) { + + // This matches the existing behavior, which allow for a potential double + // parse of the default but it's a minor cost for clean control flow. + tempStr := getEmailContentConfig( + &cfg.Mailer.Subjects, + typ, + getEmailContentConfig(defaultTemplateSubjects, typ, "")) + + temp, err := template.New("Subject").Parse(tempStr) + if err != nil { + err = wrapError(ctx, typ, "template_subject_parse_error", err) + return nil, err + } + return temp, nil +} + +func (o *Cache) loadEntryBody( + ctx context.Context, + cfg *conf.GlobalConfiguration, + typ string, +) (*template.Template, error) { + url := getEmailContentConfig(&cfg.Mailer.Templates, typ, "") + if url == "" { + + // We preserve the previous behavior of returning the default. + tempStr := getEmailContentConfig(defaultTemplateBodies, typ, "") + temp := template.Must(template.New("").Parse(tempStr)) + return temp, nil + } + if !strings.HasPrefix(url, "http") { + url = cfg.SiteURL + url + } + + tempStr, err := o.fetch(ctx, cfg, url) + if err != nil { + err = wrapError(ctx, typ, "template_body_http_error", err) + return nil, err + } + + temp, err := template.New(url).Parse(tempStr) + if err != nil { + err = wrapError(ctx, typ, "template_body_parse_error", err) + return nil, err + } + return temp, nil +} + +func (m *Cache) fetch(ctx context.Context, cfg *conf.GlobalConfiguration, url string) (string, error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", err + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer res.Body.Close() + + if code := res.StatusCode; code != http.StatusOK { + return "", fmt.Errorf("http: GET %v: status code %d", url, code) + } + + rdr := io.LimitReader(res.Body, int64(cfg.Mailer.TemplateMaxSize)) + data, err := io.ReadAll(rdr) + if err != nil { + return "", err + } + + body := string(data) + return body, nil +} + +func wrapError(ctx context.Context, typ, label string, err error) error { + if err == nil { + return nil + } + + err = fmt.Errorf( + "templatemailer: template type %q: %w", typ, err) + le := observability.GetLogEntryFromContext(ctx).Entry + le.WithFields(logrus.Fields{ + "event": "templatemailer_" + label, + "mail_type": typ, + }).WithError(err).Error(err) + return err +} + +func lookupEmailContentConfig( + cfg *conf.EmailContentConfiguration, + tpl string, +) (string, bool) { + switch tpl { + default: + return "", false + case InviteTemplate: + return cfg.Invite, true + case ConfirmationTemplate: + return cfg.Confirmation, true + case RecoveryTemplate: + return cfg.Recovery, true + case EmailChangeTemplate: + return cfg.EmailChange, true + case ReauthenticationTemplate: + return cfg.Reauthentication, true + case MagicLinkTemplate: + return cfg.MagicLink, true + } +} + +func getEmailContentConfig( + cfg *conf.EmailContentConfiguration, + tpl string, + def string, +) string { + // This matches behavior of old withDefault ("" != v) + if v, ok := lookupEmailContentConfig(cfg, tpl); ok && v != "" { + return v + } + return def +} + +func checkDefaults() error { + seen := make(map[string]bool) + data := map[string]any{ + "ConfirmationURL": "ConfirmationURL", + "Data": "Data", + "Email": "Email", + "NewEmail": "NewEmail", + "RedirectTo": "RedirectTo", + "SendingTo": "SendingTo", + "SiteURL": "SiteURL", + "Token": "Token", + "TokenHash": "TokenHash", + } + + buf := new(bytes.Buffer) + check := func(cfg *conf.EmailContentConfiguration, typ string) error { + defer buf.Reset() + + tempStr, ok := lookupEmailContentConfig(cfg, typ) + if !ok { + return fmt.Errorf( + "templatemailer: template type %q: missing default body template", typ) + } + + temp, err := template.New(typ).Parse(tempStr) + if err != nil { + return err + } + + if err := temp.Execute(buf, data); err != nil { + return err + } + return nil + } + + for _, typ := range templateTypes { + if seen[typ] { + return fmt.Errorf( + "templatemailer: template type %q: duplicate found", typ) + } + seen[typ] = true + + if err := check(defaultTemplateSubjects, typ); err != nil { + return err + } + if err := check(defaultTemplateBodies, typ); err != nil { + return err + } + } + return nil +} diff --git a/internal/mailer/templatemailer/templatemailer.go b/internal/mailer/templatemailer/templatemailer.go index 9a9eb78b8..c3c96eccd 100644 --- a/internal/mailer/templatemailer/templatemailer.go +++ b/internal/mailer/templatemailer/templatemailer.go @@ -8,65 +8,17 @@ import ( "strings" "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/mailer" "github.com/supabase/auth/internal/models" ) -// New will return a *TemplateMailer backed by the given mailer.Client. -func New(globalConfig *conf.GlobalConfiguration, mc mailer.Client) *TemplateMailer { - return &TemplateMailer{ - SiteURL: globalConfig.SiteURL, - Config: globalConfig, - mc: mc, - } -} - -// TemplateMailer will send mail and use templates from the site for easy mail styling -type TemplateMailer struct { - SiteURL string - Config *conf.GlobalConfiguration - mc mailer.Client -} - -type emailParams struct { - Token string - Type string - RedirectTo string -} - -func getPath(filepath string, params *emailParams) (*url.URL, error) { - path := &url.URL{} - if filepath != "" { - if p, err := url.Parse(filepath); err != nil { - return nil, err - } else { - path = p - } - } - if params != nil { - path.RawQuery = fmt.Sprintf("token=%s&type=%s&redirect_to=%s", url.QueryEscape(params.Token), url.QueryEscape(params.Type), encodeRedirectURL(params.RedirectTo)) - } - return path, nil -} - -func encodeRedirectURL(referrerURL string) string { - if len(referrerURL) > 0 { - if strings.ContainsAny(referrerURL, "&=#") { - // if the string contains &, = or # it has not been URL - // encoded by the caller, which means it should be URL - // encoded by us otherwise, it should be taken as-is - referrerURL = url.QueryEscape(referrerURL) - } - } - return referrerURL -} - -func withDefault(value, defaultValue string) string { - if value == "" { - return defaultValue - } - return value -} +const ( + InviteTemplate = "invite" + ConfirmationTemplate = "confirmation" + RecoveryTemplate = "recovery" + EmailChangeTemplate = "email_change" + MagicLinkTemplate = "magic_link" + ReauthenticationTemplate = "reauthentication" +) const defaultInviteMail = `

You have been invited

@@ -74,7 +26,7 @@ const defaultInviteMail = `

You have been invited

Accept the invite

Alternatively, enter the code: {{ .Token }}

` -const defaultConfirmationMail = `

Confirm your email

+const defaultConfirmationMail = `

Confirm Your Email

Follow this link to confirm your email:

Confirm your email address

@@ -103,8 +55,35 @@ const defaultReauthenticateMail = `

Confirm reauthentication

Enter the code: {{ .Token }}

` -func (m *TemplateMailer) Headers(messageType string) map[string][]string { - originalHeaders := m.Config.SMTP.NormalizedHeaders() +var ( + templateTypes = []string{ + InviteTemplate, + ConfirmationTemplate, + RecoveryTemplate, + EmailChangeTemplate, + MagicLinkTemplate, + ReauthenticationTemplate, + } + defaultTemplateSubjects = &conf.EmailContentConfiguration{ + Invite: "You have been invited", + Confirmation: "Confirm Your Email", + Recovery: "Reset Your Password", + MagicLink: "Your Magic Link", + EmailChange: "Confirm Email Change", + Reauthentication: "Confirm reauthentication", + } + defaultTemplateBodies = &conf.EmailContentConfiguration{ + Invite: defaultInviteMail, + Confirmation: defaultConfirmationMail, + Recovery: defaultRecoveryMail, + MagicLink: defaultMagicLinkMail, + EmailChange: defaultEmailChangeMail, + Reauthentication: defaultReauthenticateMail, + } +) + +func (m *Mailer) Headers(cfg *conf.GlobalConfiguration, messageType string) map[string][]string { + originalHeaders := cfg.SMTP.NormalizedHeaders() if originalHeaders == nil { return nil @@ -139,8 +118,8 @@ func (m *TemplateMailer) Headers(messageType string) map[string][]string { } // InviteMail sends a invite mail to a new user -func (m *TemplateMailer) InviteMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { - path, err := getPath(m.Config.Mailer.URLPaths.Invite, &emailParams{ +func (m *Mailer) InviteMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.cfg.Mailer.URLPaths.Invite, &emailParams{ Token: user.ConfirmationToken, Type: "invite", RedirectTo: referrerURL, @@ -151,7 +130,7 @@ func (m *TemplateMailer) InviteMail(r *http.Request, user *models.User, otp, ref } data := map[string]any{ - "SiteURL": m.Config.SiteURL, + "SiteURL": m.cfg.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, "Token": otp, @@ -159,22 +138,12 @@ func (m *TemplateMailer) InviteMail(r *http.Request, user *models.User, otp, ref "Data": user.UserMetaData, "RedirectTo": referrerURL, } - - return m.mc.Mail( - r.Context(), - user.GetEmail(), - withDefault(m.Config.Mailer.Subjects.Invite, "You have been invited"), - m.Config.Mailer.Templates.Invite, - defaultInviteMail, - data, - m.Headers("invite"), - "invite", - ) + return m.mail(r.Context(), m.cfg, InviteTemplate, user.GetEmail(), data) } // ConfirmationMail sends a signup confirmation mail to a new user -func (m *TemplateMailer) ConfirmationMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { - path, err := getPath(m.Config.Mailer.URLPaths.Confirmation, &emailParams{ +func (m *Mailer) ConfirmationMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.cfg.Mailer.URLPaths.Confirmation, &emailParams{ Token: user.ConfirmationToken, Type: "signup", RedirectTo: referrerURL, @@ -184,7 +153,7 @@ func (m *TemplateMailer) ConfirmationMail(r *http.Request, user *models.User, ot } data := map[string]any{ - "SiteURL": m.Config.SiteURL, + "SiteURL": m.cfg.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, "Token": otp, @@ -192,67 +161,42 @@ func (m *TemplateMailer) ConfirmationMail(r *http.Request, user *models.User, ot "Data": user.UserMetaData, "RedirectTo": referrerURL, } - - return m.mc.Mail( - r.Context(), - user.GetEmail(), - withDefault(m.Config.Mailer.Subjects.Confirmation, "Confirm Your Email"), - m.Config.Mailer.Templates.Confirmation, - defaultConfirmationMail, - data, - m.Headers("confirm"), - "confirm", - ) + return m.mail(r.Context(), m.cfg, ConfirmationTemplate, user.GetEmail(), data) } // ReauthenticateMail sends a reauthentication mail to an authenticated user -func (m *TemplateMailer) ReauthenticateMail(r *http.Request, user *models.User, otp string) error { +func (m *Mailer) ReauthenticateMail(r *http.Request, user *models.User, otp string) error { data := map[string]any{ - "SiteURL": m.Config.SiteURL, + "SiteURL": m.cfg.SiteURL, "Email": user.Email, "Token": otp, "Data": user.UserMetaData, } - - return m.mc.Mail( - r.Context(), - user.GetEmail(), - withDefault(m.Config.Mailer.Subjects.Reauthentication, "Confirm reauthentication"), - m.Config.Mailer.Templates.Reauthentication, - defaultReauthenticateMail, - data, - m.Headers("reauthenticate"), - "reauthenticate", - ) + return m.mail(r.Context(), m.cfg, ReauthenticationTemplate, user.GetEmail(), data) } // EmailChangeMail sends an email change confirmation mail to a user -func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otpNew, otpCurrent, referrerURL string, externalURL *url.URL) error { +func (m *Mailer) EmailChangeMail(r *http.Request, user *models.User, otpNew, otpCurrent, referrerURL string, externalURL *url.URL) error { type Email struct { + Action string Address string Otp string TokenHash string - Subject string - Template string } emails := []Email{ { Address: user.EmailChange, Otp: otpNew, TokenHash: user.EmailChangeTokenNew, - Subject: withDefault(m.Config.Mailer.Subjects.EmailChange, "Confirm Email Change"), - Template: m.Config.Mailer.Templates.EmailChange, }, } currentEmail := user.GetEmail() - if m.Config.Mailer.SecureEmailChangeEnabled && currentEmail != "" { + if m.cfg.Mailer.SecureEmailChangeEnabled && currentEmail != "" { emails = append(emails, Email{ Address: currentEmail, Otp: otpCurrent, TokenHash: user.EmailChangeTokenCurrent, - Subject: withDefault(m.Config.Mailer.Subjects.Confirmation, "Confirm Email Address"), - Template: m.Config.Mailer.Templates.EmailChange, }) } @@ -262,7 +206,7 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp errors := make(chan error, len(emails)) for _, email := range emails { path, err := getPath( - m.Config.Mailer.URLPaths.EmailChange, + m.cfg.Mailer.URLPaths.EmailChange, &emailParams{ Token: email.TokenHash, Type: "email_change", @@ -272,9 +216,9 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp if err != nil { return err } - go func(address, token, tokenHash, template string) { + go func(address, token, tokenHash string) { data := map[string]any{ - "SiteURL": m.Config.SiteURL, + "SiteURL": m.cfg.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.GetEmail(), "NewEmail": user.EmailChange, @@ -284,17 +228,14 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp "Data": user.UserMetaData, "RedirectTo": referrerURL, } - errors <- m.mc.Mail( + errors <- m.mail( ctx, + m.cfg, + EmailChangeTemplate, address, - withDefault(m.Config.Mailer.Subjects.EmailChange, "Confirm Email Change"), - template, - defaultEmailChangeMail, data, - m.Headers("email_change"), - "email_change", ) - }(email.Address, email.Otp, email.TokenHash, email.Template) + }(email.Address, email.Otp, email.TokenHash) } for i := 0; i < len(emails); i++ { @@ -307,8 +248,8 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp } // RecoveryMail sends a password recovery mail -func (m *TemplateMailer) RecoveryMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { - path, err := getPath(m.Config.Mailer.URLPaths.Recovery, &emailParams{ +func (m *Mailer) RecoveryMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.cfg.Mailer.URLPaths.Recovery, &emailParams{ Token: user.RecoveryToken, Type: "recovery", RedirectTo: referrerURL, @@ -317,7 +258,7 @@ func (m *TemplateMailer) RecoveryMail(r *http.Request, user *models.User, otp, r return err } data := map[string]any{ - "SiteURL": m.Config.SiteURL, + "SiteURL": m.cfg.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, "Token": otp, @@ -325,22 +266,12 @@ func (m *TemplateMailer) RecoveryMail(r *http.Request, user *models.User, otp, r "Data": user.UserMetaData, "RedirectTo": referrerURL, } - - return m.mc.Mail( - r.Context(), - user.GetEmail(), - withDefault(m.Config.Mailer.Subjects.Recovery, "Reset Your Password"), - m.Config.Mailer.Templates.Recovery, - defaultRecoveryMail, - data, - m.Headers("recovery"), - "recovery", - ) + return m.mail(r.Context(), m.cfg, RecoveryTemplate, user.GetEmail(), data) } // MagicLinkMail sends a login link mail -func (m *TemplateMailer) MagicLinkMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { - path, err := getPath(m.Config.Mailer.URLPaths.Recovery, &emailParams{ +func (m *Mailer) MagicLinkMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.cfg.Mailer.URLPaths.Recovery, &emailParams{ Token: user.RecoveryToken, Type: "magiclink", RedirectTo: referrerURL, @@ -350,7 +281,7 @@ func (m *TemplateMailer) MagicLinkMail(r *http.Request, user *models.User, otp, } data := map[string]any{ - "SiteURL": m.Config.SiteURL, + "SiteURL": m.cfg.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, "Token": otp, @@ -358,57 +289,47 @@ func (m *TemplateMailer) MagicLinkMail(r *http.Request, user *models.User, otp, "Data": user.UserMetaData, "RedirectTo": referrerURL, } - - return m.mc.Mail( - r.Context(), - user.GetEmail(), - withDefault(m.Config.Mailer.Subjects.MagicLink, "Your Magic Link"), - m.Config.Mailer.Templates.MagicLink, - defaultMagicLinkMail, - data, - m.Headers("magiclink"), - "magiclink", - ) + return m.mail(r.Context(), m.cfg, MagicLinkTemplate, user.GetEmail(), data) } // GetEmailActionLink returns a magiclink, recovery or invite link based on the actionType passed. -func (m TemplateMailer) GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) { +func (m *Mailer) GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) { var err error var path *url.URL switch actionType { case "magiclink": - path, err = getPath(m.Config.Mailer.URLPaths.Recovery, &emailParams{ + path, err = getPath(m.cfg.Mailer.URLPaths.Recovery, &emailParams{ Token: user.RecoveryToken, Type: "magiclink", RedirectTo: referrerURL, }) case "recovery": - path, err = getPath(m.Config.Mailer.URLPaths.Recovery, &emailParams{ + path, err = getPath(m.cfg.Mailer.URLPaths.Recovery, &emailParams{ Token: user.RecoveryToken, Type: "recovery", RedirectTo: referrerURL, }) case "invite": - path, err = getPath(m.Config.Mailer.URLPaths.Invite, &emailParams{ + path, err = getPath(m.cfg.Mailer.URLPaths.Invite, &emailParams{ Token: user.ConfirmationToken, Type: "invite", RedirectTo: referrerURL, }) case "signup": - path, err = getPath(m.Config.Mailer.URLPaths.Confirmation, &emailParams{ + path, err = getPath(m.cfg.Mailer.URLPaths.Confirmation, &emailParams{ Token: user.ConfirmationToken, Type: "signup", RedirectTo: referrerURL, }) case "email_change_current": - path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, &emailParams{ + path, err = getPath(m.cfg.Mailer.URLPaths.EmailChange, &emailParams{ Token: user.EmailChangeTokenCurrent, Type: "email_change", RedirectTo: referrerURL, }) case "email_change_new": - path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, &emailParams{ + path, err = getPath(m.cfg.Mailer.URLPaths.EmailChange, &emailParams{ Token: user.EmailChangeTokenNew, Type: "email_change", RedirectTo: referrerURL, @@ -421,3 +342,36 @@ func (m TemplateMailer) GetEmailActionLink(user *models.User, actionType, referr } return externalURL.ResolveReference(path).String(), nil } + +type emailParams struct { + Token string + Type string + RedirectTo string +} + +func getPath(filepath string, params *emailParams) (*url.URL, error) { + path := &url.URL{} + if filepath != "" { + if p, err := url.Parse(filepath); err != nil { + return nil, err + } else { + path = p + } + } + if params != nil { + path.RawQuery = fmt.Sprintf("token=%s&type=%s&redirect_to=%s", url.QueryEscape(params.Token), url.QueryEscape(params.Type), encodeRedirectURL(params.RedirectTo)) + } + return path, nil +} + +func encodeRedirectURL(referrerURL string) string { + if len(referrerURL) > 0 { + if strings.ContainsAny(referrerURL, "&=#") { + // if the string contains &, = or # it has not been URL + // encoded by the caller, which means it should be URL + // encoded by us otherwise, it should be taken as-is + referrerURL = url.QueryEscape(referrerURL) + } + } + return referrerURL +} diff --git a/internal/mailer/templatemailer/templatemailer_test.go b/internal/mailer/templatemailer/templatemailer_test.go index 22a1c5cf5..22837005f 100644 --- a/internal/mailer/templatemailer/templatemailer_test.go +++ b/internal/mailer/templatemailer/templatemailer_test.go @@ -50,16 +50,15 @@ func TestTemplateHeaders(t *testing.T) { }, } for _, tc := range cases { - mailer := TemplateMailer{ - Config: &conf.GlobalConfiguration{ - SMTP: conf.SMTPConfiguration{ - Headers: tc.from, - }, + mailer := New(&conf.GlobalConfiguration{ + SMTP: conf.SMTPConfiguration{ + Headers: tc.from, }, - } - require.NoError(t, mailer.Config.SMTP.Validate()) + }, nil, nil) - hdrs := mailer.Headers(tc.typ) + require.NoError(t, mailer.cfg.SMTP.Validate()) + + hdrs := mailer.Headers(mailer.cfg, tc.typ) require.Equal(t, hdrs, tc.exp) } } diff --git a/internal/mailer/validateclient/validateclient.go b/internal/mailer/validateclient/validateclient.go index 6a0df04cc..b194bf111 100644 --- a/internal/mailer/validateclient/validateclient.go +++ b/internal/mailer/validateclient/validateclient.go @@ -101,10 +101,8 @@ type emailValidatorMailClient struct { func (o *emailValidatorMailClient) Mail( ctx context.Context, to string, - subjectTemplate string, - templateURL string, - defaultTemplate string, - templateData map[string]any, + subject string, + body string, headers map[string][]string, typ string, ) error { @@ -114,10 +112,8 @@ func (o *emailValidatorMailClient) Mail( return o.mc.Mail( ctx, to, - subjectTemplate, - templateURL, - defaultTemplate, - templateData, + subject, + body, headers, typ, )