diff --git a/cmd/serve_cmd.go b/cmd/serve_cmd.go index bc841e849..6b2dee94c 100644 --- a/cmd/serve_cmd.go +++ b/cmd/serve_cmd.go @@ -14,7 +14,9 @@ import ( "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/supabase/auth/internal/api" + "github.com/supabase/auth/internal/api/apiworker" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer/templatemailer" "github.com/supabase/auth/internal/reloader" "github.com/supabase/auth/internal/storage" "github.com/supabase/auth/internal/utilities" @@ -48,22 +50,24 @@ func serve(ctx context.Context) { } defer db.Close() - addr := net.JoinHostPort(config.API.Host, config.API.Port) - - opts := []api.Option{ - api.NewLimiterOptions(config), - } - baseCtx, baseCancel := context.WithCancel(context.Background()) defer baseCancel() var wg sync.WaitGroup defer wg.Wait() // Do not return to caller until this goroutine is done. - a := api.NewAPIWithVersion(config, db, utilities.Version, opts...) - ah := reloader.NewAtomicHandler(a) - logrus.WithField("version", a.Version()).Infof("GoTrue API started on: %s", addr) + mrCache := templatemailer.NewCache() + limiterOpts := api.NewLimiterOptions(config) + initialAPI := api.NewAPIWithVersion( + config, db, utilities.Version, + limiterOpts, + api.WithMailer(templatemailer.FromConfig(config, mrCache)), + ) + addr := net.JoinHostPort(config.API.Host, config.API.Port) + logrus.WithField("version", initialAPI.Version()).Infof("GoTrue API started on: %s", addr) + + ah := reloader.NewAtomicHandler(initialAPI) httpSrv := &http.Server{ Addr: addr, Handler: ah, @@ -74,6 +78,26 @@ func serve(ctx context.Context) { } log := logrus.WithField("component", "api") + wrkLog := logrus.WithField("component", "apiworker") + wrk := apiworker.New(config, mrCache, wrkLog) + wg.Add(1) + go func() { + defer wg.Done() + + var err error + defer func() { + logFn := wrkLog.Info + if err != nil { + logFn = wrkLog.WithError(err).Error + } + logFn("background apiworker is exiting") + }() + + // Work exits when ctx is done as in-flight requests do not depend + // on it. If they do in the future this should be baseCtx instead. + err = wrk.Work(ctx) + }() + if watchDir != "" { wg.Add(1) go func() { @@ -81,8 +105,26 @@ func serve(ctx context.Context) { fn := func(latestCfg *conf.GlobalConfiguration) { log.Info("reloading api with new configuration") + + // When config is updated we notify the apiworker. + wrk.ReloadConfig(latestCfg) + + // Create a new API version with the updated config. latestAPI := api.NewAPIWithVersion( - latestCfg, db, utilities.Version, opts...) + config, db, utilities.Version, + + // Create a new mailer with existing template cache. + api.WithMailer( + templatemailer.FromConfig(config, mrCache), + ), + + // Persist existing rate limiters. + // + // TODO(cstockton): we should consider updating these, if we + // rely on hot config reloads 100% then rate limiter changes + // won't be picked up. + limiterOpts, + ) ah.Store(latestAPI) } diff --git a/internal/api/api.go b/internal/api/api.go index 1d7af7772..63e7ee1f3 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -16,6 +16,7 @@ import ( "github.com/supabase/auth/internal/hooks/hookspgfunc" "github.com/supabase/auth/internal/hooks/v0hooks" "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/mailer/templatemailer" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/storage" @@ -38,11 +39,11 @@ type API struct { config *conf.GlobalConfiguration version string - hooksMgr *v0hooks.Manager - hibpClient *hibp.PwnedClient - oauthServer *oauthserver.Server - mailerClientFunc func() mailer.MailClient - tokenService *tokens.Service + hooksMgr *v0hooks.Manager + hibpClient *hibp.PwnedClient + oauthServer *oauthserver.Server + tokenService *tokens.Service + mailer mailer.Mailer // overrideTime can be used to override the clock used by handlers. Should only be used in tests! overrideTime func() time.Time @@ -53,6 +54,7 @@ type API struct { func (a *API) GetConfig() *conf.GlobalConfiguration { return a.config } func (a *API) GetDB() *storage.Connection { return a.db } func (a *API) GetTokenService() *tokens.Service { return a.tokenService } +func (a *API) Mailer() mailer.Mailer { return a.mailer } func (a *API) Version() string { return a.version @@ -99,11 +101,6 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne if api.limiterOpts == nil { api.limiterOpts = NewLimiterOptions(globalConfig) } - if api.mailerClientFunc == nil { - api.mailerClientFunc = func() mailer.MailClient { - return mailer.NewMailClient(globalConfig) - } - } if api.hooksMgr == nil { httpDr := hookshttp.New() pgfuncDr := hookspgfunc.New(db) @@ -114,6 +111,10 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne if api.tokenService == nil { api.tokenService = tokens.NewService(globalConfig, api.hooksMgr) } + if api.mailer == nil { + tc := templatemailer.NewCache() + api.mailer = templatemailer.FromConfig(globalConfig, tc) + } // Connect token service to API's time function (supports test overrides) api.tokenService.SetTimeFunc(api.Now) @@ -397,12 +398,6 @@ func (a *API) HealthCheck(w http.ResponseWriter, r *http.Request) error { }) } -// Mailer returns NewMailer with the current tenant config -func (a *API) Mailer() mailer.Mailer { - config := a.config - return mailer.NewMailerWithClient(config, a.mailerClientFunc()) -} - // ServeHTTP implements the http.Handler interface by passing the request along // to its underlying Handler. func (a *API) ServeHTTP(w http.ResponseWriter, r *http.Request) { diff --git a/internal/api/apiworker/apiworker.go b/internal/api/apiworker/apiworker.go new file mode 100644 index 000000000..e6627c9f4 --- /dev/null +++ b/internal/api/apiworker/apiworker.go @@ -0,0 +1,111 @@ +package apiworker + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer/templatemailer" +) + +// Worker is a simple background worker for async tasks. +type Worker struct { + le *logrus.Entry + tc *templatemailer.Cache + + // Notifies worker the cfg has been updated. + cfgCh chan struct{} + + // workMu must be held for calls to Work + workMu sync.Mutex + + // mu must be held for field access below here + mu sync.Mutex + cfg *conf.GlobalConfiguration +} + +// New will return a new *Worker instance. +func New( + cfg *conf.GlobalConfiguration, + tc *templatemailer.Cache, + le *logrus.Entry, +) *Worker { + return &Worker{ + le: le, + cfg: cfg, + tc: tc, + cfgCh: make(chan struct{}, 1), + } +} + +func (o *Worker) putConfig(cfg *conf.GlobalConfiguration) { + o.mu.Lock() + defer o.mu.Unlock() + o.cfg = cfg +} + +func (o *Worker) getConfig() *conf.GlobalConfiguration { + o.mu.Lock() + defer o.mu.Unlock() + return o.cfg +} + +// ReloadConfig notifies the worker a new configuration is available. +func (o *Worker) ReloadConfig(cfg *conf.GlobalConfiguration) { + o.putConfig(cfg) + + select { + case o.cfgCh <- struct{}{}: + default: + } +} + +// Work will periodically reload the templates in the background as long as the +// system remains active. +func (o *Worker) Work(ctx context.Context) error { + if ok := o.workMu.TryLock(); !ok { + return errors.New("apiworker: concurrent calls to Work are invalid") + } + defer o.workMu.Unlock() + + le := o.le.WithFields(logrus.Fields{ + "worker_type": "apiworker_template_cache", + }) + le.Info("apiworker: template cache worker started") + defer le.Info("apiworker: template cache worker exited") + + // Reload templates right away on Work. + o.maybeReloadTemplates(ctx, o.getConfig()) + + ival := func() time.Duration { + return max(time.Second, o.getConfig().Mailer.TemplateRetryInterval/4) + } + + tr := time.NewTicker(ival()) + defer tr.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-o.cfgCh: + tr.Reset(ival()) + case <-tr.C: + } + + // Either ticker fired or we got a config update. + o.maybeReloadTemplates(ctx, o.getConfig()) + } +} + +func (o *Worker) maybeReloadTemplates( + ctx context.Context, + cfg *conf.GlobalConfiguration, +) { + if cfg.Mailer.TemplateReloadingEnabled { + o.tc.Reload(ctx, cfg) + } +} diff --git a/internal/api/mail.go b/internal/api/mail.go index 569ecf726..cfbe57397 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -8,6 +8,7 @@ import ( "github.com/supabase/auth/internal/hooks/v0hooks" mail "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/mailer/validateclient" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -705,9 +706,9 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, } switch { - case errors.Is(err, mail.ErrInvalidEmailAddress), - errors.Is(err, mail.ErrInvalidEmailFormat), - errors.Is(err, mail.ErrInvalidEmailDNS): + case errors.Is(err, validateclient.ErrInvalidEmailAddress), + errors.Is(err, validateclient.ErrInvalidEmailFormat), + errors.Is(err, validateclient.ErrInvalidEmailDNS): return apierrors.NewBadRequestError( apierrors.ErrorCodeEmailAddressInvalid, "Email address %q is invalid", diff --git a/internal/api/options.go b/internal/api/options.go index c5ab33fb5..d47aba276 100644 --- a/internal/api/options.go +++ b/internal/api/options.go @@ -15,11 +15,21 @@ type Option interface { apply(*API) } -type MailerOptions struct { - MailerClientFunc func() mailer.MailClient +type optionFunc func(*API) + +func (f optionFunc) apply(a *API) { f(a) } + +func WithMailer(m mailer.Mailer) Option { + return optionFunc(func(a *API) { + a.mailer = m + }) } -func (mo *MailerOptions) apply(a *API) { a.mailerClientFunc = mo.MailerClientFunc } +func WithTokenService(service *tokens.Service) Option { + return optionFunc(func(a *API) { + a.tokenService = service + }) +} type LimiterOptions struct { Email ratelimit.Limiter @@ -44,19 +54,6 @@ type LimiterOptions struct { func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo } -// TokenServiceOption allows injecting a custom token service -type TokenServiceOption struct { - service *tokens.Service -} - -func WithTokenService(service *tokens.Service) *TokenServiceOption { - return &TokenServiceOption{service: service} -} - -func (tso *TokenServiceOption) apply(a *API) { - a.tokenService = tso.service -} - func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions { o := &LimiterOptions{} diff --git a/internal/api/options_test.go b/internal/api/options_test.go index 3785f6426..c4c1d1623 100644 --- a/internal/api/options_test.go +++ b/internal/api/options_test.go @@ -4,10 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/supabase/auth/internal/conf" - "github.com/supabase/auth/internal/e2e" - "github.com/supabase/auth/internal/mailer" ) func TestNewLimiterOptions(t *testing.T) { @@ -31,17 +28,3 @@ func TestNewLimiterOptions(t *testing.T) { assert.NotNil(t, rl.SSO) assert.NotNil(t, rl.SAMLAssertion) } - -func TestMailerOptions(t *testing.T) { - globalCfg := e2e.Must(e2e.Config()) - conn := e2e.Must(e2e.Conn(globalCfg)) - - sentinelMailer := mailer.NewMailClient(globalCfg) - mailerOpts := &MailerOptions{MailerClientFunc: func() mailer.MailClient { - return sentinelMailer - }} - a := NewAPIWithVersion(globalCfg, conn, apiTestVersion, mailerOpts) - - got := a.mailerClientFunc() - require.Equal(t, sentinelMailer, got) -} diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index e59460c0e..6912a53e6 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -449,13 +449,31 @@ type MailerConfiguration struct { ExternalHosts []string `json:"external_hosts" split_words:"true"` - // EXPERIMENTAL: May be removed in a future release. + // EXPERIMENTAL: All config below here may be removed in a future release. EmailBackgroundSending bool `json:"email_background_sending" split_words:"true" default:"false"` EmailValidationExtended bool `json:"email_validation_extended" split_words:"true" default:"false"` EmailValidationServiceURL string `json:"email_validation_service_url" split_words:"true"` EmailValidationServiceHeaders string `json:"email_validation_service_headers" split_words:"true"` EmailValidationBlockedMX string `json:"email_validation_blocked_mx" split_words:"true"` + // Max size in bytes we will read from a template endpoint + TemplateMaxSize int `json:"template_max_size" split_words:"true" default:"1000000"` + + // The maximum age of a template before we consider it stale. + TemplateMaxAge time.Duration `json:"template_max_age" split_words:"true" default:"10m"` + + // The time between retrying a failed template reload. + TemplateRetryInterval time.Duration `json:"template_retry_interval" split_words:"true" default:"10s"` + + // If true enable background reloading of templates to avoid blocking + // IO in requests. + TemplateReloadingEnabled bool `json:"template_reloading_enabled" split_words:"true" default:"false"` + + // The maximum time a server may be idle before template reloading stops. + // Note that even when the server is idle, a config reload will trigger a + // template reload. + TemplateReloadingMaxIdle time.Duration `json:"template_reloading_max_idle" split_words:"true" default:"20m"` + serviceHeaders map[string][]string `json:"-"` blockedMXRecords map[string]bool `json:"-"` } diff --git a/internal/mailer/mailer.go b/internal/mailer/mailer.go index 5ce2dc0ca..5c8602578 100644 --- a/internal/mailer/mailer.go +++ b/internal/mailer/mailer.go @@ -2,16 +2,24 @@ package mailer import ( "context" - "fmt" "net/http" "net/url" - "github.com/sirupsen/logrus" - "github.com/supabase/auth/internal/api/apitask" - "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/models" ) +const ( + SignupVerification = "signup" + RecoveryVerification = "recovery" + InviteVerification = "invite" + MagicLinkVerification = "magiclink" + EmailChangeVerification = "email_change" + EmailOTPVerification = "email" + EmailChangeCurrentVerification = "email_change_current" + EmailChangeNewVerification = "email_change_new" + ReauthenticationVerification = "reauthentication" +) + // Mailer defines the interface a mailer must implement. type Mailer interface { InviteMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error @@ -23,10 +31,16 @@ type Mailer interface { GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) } -type EmailParams struct { - Token string - Type string - RedirectTo string +// TODO(cstockton): Mail(...) -> Mail(Email{...}) ? +type Client interface { + Mail( + ctx context.Context, + to string, + subject string, + body string, + headers map[string][]string, + typ string, + ) error } type EmailData struct { @@ -38,178 +52,3 @@ type EmailData struct { TokenNew string `json:"token_new"` TokenHashNew string `json:"token_hash_new"` } - -// NewMailer returns a new gotrue mailer -func NewMailer(globalConfig *conf.GlobalConfiguration) Mailer { - mc := NewMailClient(globalConfig) - return NewMailerWithClient(globalConfig, mc) -} - -type emailValidatorMailClient struct { - ev *EmailValidator - mc MailClient -} - -// Mail implements mailer.MailClient interface by calling validate before -// passing the mail request to the next MailClient. -func (o *emailValidatorMailClient) Mail( - ctx context.Context, - to string, - subjectTemplate string, - templateURL string, - defaultTemplate string, - templateData map[string]any, - headers map[string][]string, - typ string, -) error { - if err := o.ev.Validate(ctx, to); err != nil { - return err - } - return o.mc.Mail( - ctx, - to, - subjectTemplate, - templateURL, - defaultTemplate, - templateData, - headers, - typ, - ) -} - -// NewMailerWithClient returns a new Mailer that will use the given MailClient. -func NewMailerWithClient( - globalConfig *conf.GlobalConfiguration, - mailClient MailClient, -) Mailer { - return &TemplateMailer{ - SiteURL: globalConfig.SiteURL, - Config: globalConfig, - MailClient: mailClient, - } -} - -// NewMailClient returns a new MailClient based on the given configuration. -func NewMailClient(globalConfig *conf.GlobalConfiguration) MailClient { - mc := newMailClient(globalConfig) - - // Check if email validation is enabled - ev := newEmailValidator(globalConfig.Mailer) - if ev.isEnabled() { - mc = &emailValidatorMailClient{ev: ev, mc: mc} - } - - // Check if background emails are enabled - if globalConfig.Mailer.EmailBackgroundSending { - mc = &backgroundMailClient{ - mc: mc, - } - } - return mc -} - -// newMailClient returns a new MailClient based on the given configuration. -func newMailClient(globalConfig *conf.GlobalConfiguration) MailClient { - if globalConfig.SMTP.Host == "" { - logrus.Infof("Noop mail client being used for %v", globalConfig.SiteURL) - return &noopMailClient{ - EmailValidator: newEmailValidator(globalConfig.Mailer), - } - } - - from := globalConfig.SMTP.FromAddress() - u, _ := url.ParseRequestURI(globalConfig.API.ExternalURL) - return &MailmeMailer{ - Host: globalConfig.SMTP.Host, - Port: globalConfig.SMTP.Port, - User: globalConfig.SMTP.User, - Pass: globalConfig.SMTP.Pass, - LocalName: u.Hostname(), - From: from, - BaseURL: globalConfig.SiteURL, - Logger: logrus.StandardLogger(), - MailLogging: globalConfig.SMTP.LoggingEnabled, - } -} - -func withDefault(value, defaultValue string) string { - if value == "" { - return defaultValue - } - return value -} - -func getPath(filepath string, params *EmailParams) (*url.URL, error) { - path := &url.URL{} - if filepath != "" { - if p, err := url.Parse(filepath); err != nil { - return nil, err - } else { - path = p - } - } - if params != nil { - path.RawQuery = fmt.Sprintf("token=%s&type=%s&redirect_to=%s", url.QueryEscape(params.Token), url.QueryEscape(params.Type), encodeRedirectURL(params.RedirectTo)) - } - return path, nil -} - -// Task holds a mail pending delivery by the Handler. -type Task struct { - mc MailClient - - To string `json:"to"` - SubjectTemplate string `json:"subject_template"` - TemplateURL string `json:"template_url"` - DefaultTemplate string `json:"default_template"` - TemplateData map[string]any `json:"template_data"` - Headers map[string][]string `json:"headers"` - Typ string `json:"typ"` -} - -// Run implements the Type method of the apitask.Task interface by returning -// the "mailer." prefix followed by the mail type. -func (o *Task) Type() string { return fmt.Sprintf("mailer.%v", o.Typ) } - -// Run implements the Run method of the apitask.Task interface by attempting -// to send the mail using the underying mail client. -func (o *Task) Run(ctx context.Context) error { - return o.mc.Mail( - ctx, - o.To, - o.SubjectTemplate, - o.TemplateURL, - o.DefaultTemplate, - o.TemplateData, - o.Headers, - o.Typ) -} - -type backgroundMailClient struct { - mc MailClient -} - -// Mail implements mailer.MailClient interface by sending the call to the -// wrapped mail client to the background. -func (o *backgroundMailClient) Mail( - ctx context.Context, - to string, - subjectTemplate string, - templateURL string, - defaultTemplate string, - templateData map[string]any, - headers map[string][]string, - typ string, -) error { - tk := &Task{ - mc: o.mc, - To: to, - SubjectTemplate: subjectTemplate, - TemplateURL: templateURL, - DefaultTemplate: defaultTemplate, - TemplateData: templateData, - Headers: headers, - Typ: typ, - } - return apitask.Run(ctx, tk) -} diff --git a/internal/mailer/mailme.go b/internal/mailer/mailme.go deleted file mode 100644 index 7a62cf151..000000000 --- a/internal/mailer/mailme.go +++ /dev/null @@ -1,223 +0,0 @@ -package mailer - -import ( - "bytes" - "context" - "errors" - "html/template" - "io" - "log" - "net/http" - "strings" - "sync" - "time" - - "gopkg.in/gomail.v2" - - "github.com/sirupsen/logrus" -) - -// TemplateRetries is the amount of time MailMe will try to fetch a URL before giving up -const TemplateRetries = 3 - -// TemplateExpiration is the time period that the template will be cached for -const TemplateExpiration = 10 * time.Second - -// MailmeMailer lets MailMe send templated mails -type MailmeMailer struct { - From string - Host string - Port int - User string - Pass string - BaseURL string - LocalName string - FuncMap template.FuncMap - cache *TemplateCache - Logger logrus.FieldLogger - MailLogging bool -} - -// Mail sends a templated mail. It will try to load the template from a URL, and -// otherwise fall back to the default -func (m *MailmeMailer) Mail( - ctx context.Context, - to, subjectTemplate, templateURL, defaultTemplate string, - templateData map[string]interface{}, - headers map[string][]string, - typ string, -) error { - if m.FuncMap == nil { - m.FuncMap = map[string]interface{}{} - } - if m.cache == nil { - m.cache = &TemplateCache{ - templates: map[string]*MailTemplate{}, - funcMap: m.FuncMap, - logger: m.Logger, - } - } - - tmp, err := template.New("Subject").Funcs(template.FuncMap(m.FuncMap)).Parse(subjectTemplate) - if err != nil { - return err - } - - subject := &bytes.Buffer{} - err = tmp.Execute(subject, templateData) - if err != nil { - return err - } - - body, err := m.MailBody(templateURL, defaultTemplate, templateData) - if err != nil { - return err - } - - mail := gomail.NewMessage() - mail.SetHeader("From", m.From) - mail.SetHeader("To", to) - mail.SetHeader("Subject", subject.String()) - - for k, v := range headers { - if v != nil { - mail.SetHeader(k, v...) - } - } - - mail.SetBody("text/html", body) - - dial := gomail.NewDialer(m.Host, m.Port, m.User, m.Pass) - if m.LocalName != "" { - dial.LocalName = m.LocalName - } - - if m.MailLogging { - defer func() { - fields := logrus.Fields{ - "event": "mail.send", - "mail_type": typ, - "mail_from": m.From, - "mail_to": to, - } - m.Logger.WithFields(fields).Info("mail.send") - }() - } - if err := dial.DialAndSend(mail); err != nil { - return err - } - return nil -} - -type MailTemplate struct { - tmp *template.Template - expiresAt time.Time -} - -type TemplateCache struct { - templates map[string]*MailTemplate - mutex sync.Mutex - funcMap template.FuncMap - logger logrus.FieldLogger -} - -func (t *TemplateCache) Get(url string) (*template.Template, error) { - cached, ok := t.templates[url] - if ok && (cached.expiresAt.Before(time.Now())) { - return cached.tmp, nil - } - data, err := t.fetchTemplate(url, TemplateRetries) - if err != nil { - return nil, err - } - return t.Set(url, data, TemplateExpiration) -} - -func (t *TemplateCache) Set(key, value string, expirationTime time.Duration) (*template.Template, error) { - parsed, err := template.New(key).Funcs(t.funcMap).Parse(value) - if err != nil { - return nil, err - } - - cached := &MailTemplate{ - tmp: parsed, - expiresAt: time.Now().Add(expirationTime), - } - t.mutex.Lock() - t.templates[key] = cached - t.mutex.Unlock() - return parsed, nil -} - -func (t *TemplateCache) fetchTemplate(url string, triesLeft int) (string, error) { - client := &http.Client{ - Timeout: 10 * time.Second, - } - - resp, err := client.Get(url) - if err != nil && triesLeft > 0 { - return t.fetchTemplate(url, triesLeft-1) - } - if err != nil { - return "", err - } - defer resp.Body.Close() - if resp.StatusCode == 200 { // OK - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil && triesLeft > 0 { - return t.fetchTemplate(url, triesLeft-1) - } - if err != nil { - return "", err - } - return string(bodyBytes), err - } - if triesLeft > 0 { - return t.fetchTemplate(url, triesLeft-1) - } - return "", errors.New("mailer: unable to fetch mail template") -} - -func (m *MailmeMailer) MailBody(url string, defaultTemplate string, data map[string]interface{}) (string, error) { - if m.FuncMap == nil { - m.FuncMap = map[string]interface{}{} - } - if m.cache == nil { - m.cache = &TemplateCache{templates: map[string]*MailTemplate{}, funcMap: m.FuncMap} - } - - var temp *template.Template - var err error - - if url != "" { - var absoluteURL string - if strings.HasPrefix(url, "http") { - absoluteURL = url - } else { - absoluteURL = m.BaseURL + url - } - temp, err = m.cache.Get(absoluteURL) - if err != nil { - log.Printf("Error loading template from %v: %v\n", url, err) - } - } - - if temp == nil { - cached, ok := m.cache.templates[url] - if ok { - temp = cached.tmp - } else { - temp, err = m.cache.Set(url, defaultTemplate, 0) - if err != nil { - return "", err - } - } - } - - buf := &bytes.Buffer{} - err = temp.Execute(buf, data) - if err != nil { - return "", err - } - return buf.String(), nil -} diff --git a/internal/mailer/mailmeclient/mailmeclient.go b/internal/mailer/mailmeclient/mailmeclient.go new file mode 100644 index 000000000..0062a04ac --- /dev/null +++ b/internal/mailer/mailmeclient/mailmeclient.go @@ -0,0 +1,87 @@ +// Package mailmeclient provides an implementation of mailer.Client that uses +// gopkg.in/gomail.v2 to send via SMTP. +package mailmeclient + +import ( + "context" + "net/url" + + "gopkg.in/gomail.v2" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" +) + +// Client lets MailMe send templated mails +type Client struct { + From string + Host string + Port int + User string + Pass string + LocalName string + + Logger logrus.FieldLogger + MailLogging bool +} + +// New returns a new *Mailer based on the given configuration. +func New(globalConfig *conf.GlobalConfiguration) *Client { + from := globalConfig.SMTP.FromAddress() + u, _ := url.ParseRequestURI(globalConfig.API.ExternalURL) + return &Client{ + Host: globalConfig.SMTP.Host, + Port: globalConfig.SMTP.Port, + User: globalConfig.SMTP.User, + Pass: globalConfig.SMTP.Pass, + LocalName: u.Hostname(), + From: from, + Logger: logrus.StandardLogger(), + MailLogging: globalConfig.SMTP.LoggingEnabled, + } +} + +// Mail sends a templated mail. It will try to load the template from a URL, and +// otherwise fall back to the default +func (m *Client) Mail( + ctx context.Context, + to string, + subject string, + body string, + headers map[string][]string, + typ string, +) error { + mail := gomail.NewMessage() + mail.SetHeader("From", m.From) + mail.SetHeader("To", to) + mail.SetHeader("Subject", subject) + + for k, v := range headers { + if v != nil { + mail.SetHeader(k, v...) + } + } + + mail.SetBody("text/html", body) + + dial := gomail.NewDialer(m.Host, m.Port, m.User, m.Pass) + if m.LocalName != "" { + dial.LocalName = m.LocalName + } + + if m.MailLogging { + defer func() { + fields := logrus.Fields{ + "event": "mail.send", + "mail_type": typ, + "mail_from": m.From, + "mail_to": to, + } + m.Logger.WithFields(fields).Info("mail.send") + }() + } + if err := dial.DialAndSend(mail); err != nil { + return err + } + return nil +} diff --git a/internal/mailer/noop.go b/internal/mailer/noop.go deleted file mode 100644 index 1179df89b..000000000 --- a/internal/mailer/noop.go +++ /dev/null @@ -1,39 +0,0 @@ -package mailer - -import ( - "context" - "errors" - "time" -) - -type noopMailClient struct { - EmailValidator *EmailValidator - Delay time.Duration -} - -func (m *noopMailClient) Mail( - ctx context.Context, - to, subjectTemplate, templateURL, defaultTemplate string, - templateData map[string]interface{}, - headers map[string][]string, - typ string, -) error { - if to == "" { - return errors.New("to field cannot be empty") - } - - if m.Delay > 0 { - select { - case <-time.After(m.Delay): - case <-ctx.Done(): - return ctx.Err() - } - } - - if m.EmailValidator != nil { - if err := m.EmailValidator.Validate(ctx, to); err != nil { - return err - } - } - return nil -} diff --git a/internal/mailer/noopclient/noopclient.go b/internal/mailer/noopclient/noopclient.go new file mode 100644 index 000000000..e40bab202 --- /dev/null +++ b/internal/mailer/noopclient/noopclient.go @@ -0,0 +1,39 @@ +// Package noopclient provides an implementation of mailer.Client that simply +// does nothing. +package noopclient + +import ( + "context" + "errors" + "time" +) + +type Client struct { + Delay time.Duration +} + +func New() *Client { + return &Client{} +} + +func (m *Client) Mail( + ctx context.Context, + to string, + subject string, + body string, + headers map[string][]string, + typ string, +) error { + if to == "" { + return errors.New("to field cannot be empty") + } + + if m.Delay > 0 { + select { + case <-time.After(m.Delay): + case <-ctx.Done(): + return ctx.Err() + } + } + return nil +} diff --git a/internal/mailer/taskclient/taskclient.go b/internal/mailer/taskclient/taskclient.go new file mode 100644 index 000000000..e966d3c16 --- /dev/null +++ b/internal/mailer/taskclient/taskclient.go @@ -0,0 +1,76 @@ +// Package taskclient provides an implementation of mailer.Client that uses +// the apitask package to send mail in the background. +package taskclient + +import ( + "context" + "fmt" + + "github.com/supabase/auth/internal/api/apitask" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer" +) + +// New will return a Client that runs a task in the background that will later +// call the given Client. If the mailer config EmailBackgroundSending is +// disabled it will return the same Client passed in mc. +func New(globalConfig *conf.GlobalConfiguration, mc mailer.Client) mailer.Client { + + // Check if background emails are enabled + if globalConfig.Mailer.EmailBackgroundSending { + mc = &backgroundMailClient{mc: mc} + } + return mc +} + +// Task holds a mail pending delivery by the Handler. +type Task struct { + mc mailer.Client + + To string `json:"to"` + Subject string `json:"subject"` + Body string `json:"body"` + Headers map[string][]string `json:"headers"` + Typ string `json:"typ"` +} + +// Run implements the Type method of the apitask.Task interface by returning +// the "mailer." prefix followed by the mail type. +func (o *Task) Type() string { return fmt.Sprintf("mailer.%v", o.Typ) } + +// Run implements the Run method of the apitask.Task interface by attempting +// to send the mail using the underying mail client. +func (o *Task) Run(ctx context.Context) error { + return o.mc.Mail( + ctx, + o.To, + o.Subject, + o.Body, + o.Headers, + o.Typ) +} + +type backgroundMailClient struct { + mc mailer.Client +} + +// Mail implements mailer.MailClient interface by sending the call to the +// wrapped mail client to the background. +func (o *backgroundMailClient) Mail( + ctx context.Context, + to string, + subject string, + body string, + headers map[string][]string, + typ string, +) error { + tk := &Task{ + mc: o.mc, + To: to, + Subject: subject, + Body: body, + Headers: headers, + Typ: typ, + } + return apitask.Run(ctx, tk) +} diff --git a/internal/mailer/templatemailer/template.go b/internal/mailer/templatemailer/template.go new file mode 100644 index 000000000..0263a7700 --- /dev/null +++ b/internal/mailer/templatemailer/template.go @@ -0,0 +1,601 @@ +package templatemailer + +import ( + "bytes" + "context" + "fmt" + "html/template" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/mailer/mailmeclient" + "github.com/supabase/auth/internal/mailer/noopclient" + "github.com/supabase/auth/internal/mailer/taskclient" + "github.com/supabase/auth/internal/mailer/validateclient" + "github.com/supabase/auth/internal/observability" + "golang.org/x/sync/singleflight" +) + +func init() { + // Ensure every TemplateType has a default subject & body. + if err := checkDefaults(); err != nil { + panic(err) + } +} + +// Mailer will send mail and use templates from the site for easy mail styling +type Mailer struct { + cfg *conf.GlobalConfiguration + mc mailer.Client + tc *Cache +} + +// FromConfig returns a new mailer configured using the global configuration. +func FromConfig(globalConfig *conf.GlobalConfiguration, tc *Cache) *Mailer { + var mc mailer.Client + if globalConfig.SMTP.Host == "" { + logrus.Infof("Noop mail client being used for %v", globalConfig.SiteURL) + mc = noopclient.New() + } else { + mc = mailmeclient.New(globalConfig) + } + + // Wrap client with validation first + mc = validateclient.New(globalConfig, mc) + + // Then background tasks + mc = taskclient.New(globalConfig, mc) + + // Finally the template mailer + return New(globalConfig, mc, tc) +} + +// New will return a *TemplateMailer backed by the given mailer.Client. +func New(globalConfig *conf.GlobalConfiguration, mc mailer.Client, tc *Cache) *Mailer { + return &Mailer{ + cfg: globalConfig, + mc: mc, + tc: tc, + } +} + +func (m *Mailer) mail( + ctx context.Context, + cfg *conf.GlobalConfiguration, + tpl string, + to string, + data map[string]any, +) error { + if _, ok := lookupEmailContentConfig(&cfg.Mailer.Subjects, tpl); !ok { + return fmt.Errorf("templatemailer: template type: %s is invalid", tpl) + } + + // This is to match the previous behavior, which sent a "reauthenticate" + // header instead of the same name as template. + typ := tpl + if typ == ReauthenticationTemplate { + typ = "reauthenticate" + } + headers := m.Headers(cfg, typ) + + ent, err := m.tc.get(ctx, cfg, tpl) + if err != nil { + return err + } + + var buf bytes.Buffer + subject, body, err := ent.execute(&buf, data) + if err != nil { + return err + } + return m.mc.Mail( + ctx, + to, + subject, + body, + headers, + typ, + ) +} + +type tplCacheEntry struct { + createdAt time.Time + checkedAt time.Time + def bool + typ string + subject *template.Template + body *template.Template +} + +func newTplCacheEntry( + at time.Time, + typ string, + subject, body *template.Template, +) *tplCacheEntry { + return &tplCacheEntry{ + createdAt: at, + checkedAt: at, + typ: typ, + subject: subject, + body: body, + } +} + +func (ent *tplCacheEntry) copy() *tplCacheEntry { + cpy := *ent + return &cpy +} + +func (ent *tplCacheEntry) execute( + buf *bytes.Buffer, + data map[string]any, +) (subject string, body string, err error) { + if err = ent.subject.Execute(buf, data); err != nil { + return "", "", err + } + subject = buf.String() + + buf.Reset() + if err = ent.body.Execute(buf, data); err != nil { + return "", "", err + } + body = buf.String() + return subject, body, nil +} + +type Cache struct { + sf singleflight.Group + now func() time.Time + + // Must hold rw for below field access + rw sync.RWMutex + m map[string]*tplCacheEntry // map[TemplateType]*tplCacheEntry + t time.Time // Time of the most recent call to getEntry +} + +func NewCache() *Cache { + return &Cache{ + m: make(map[string]*tplCacheEntry), + now: time.Now, + } +} + +func (o *Cache) Reload( + ctx context.Context, + cfg *conf.GlobalConfiguration, +) { + now := o.now() + touchedAt := o.getTouchedAt() + + // If the touchedAt time is zero we will eagerly reload. Note we must set + // the touch time to prevent a server that has never had a request from + // from reloading indefinitely. + if touchedAt.IsZero() { + defer o.setTouchedAt(now) + + o.reloadAt(ctx, cfg, now) + return + } + + // If the server has been idle for maxIdle time, we stop updating the + // templates until the next mail request comes through. + maxIdle := cfg.Mailer.TemplateReloadingMaxIdle + if now.Sub(touchedAt) >= maxIdle { + return + } + + o.reloadAt(ctx, cfg, now) +} + +func (o *Cache) reloadAt( + ctx context.Context, + cfg *conf.GlobalConfiguration, + now time.Time, +) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + wg := new(sync.WaitGroup) + defer wg.Wait() + + for _, typ := range templateTypes { + ent, ok := o.getEntry(typ) + if !ok { + // Cache miss, straight to load with no current entry. + o.reloadType(ctx, cfg, wg, typ, nil) + continue + } + + // Before we eagerly reload the template we first make sure we are + // approaching it's expiration. The goal is to never have the requests + // block on the singleflight during regular mail requests. + // + // The def flag signals that the template is the default template. We + // skip this check if it's currently set to true, as we want to get a + // new template as soon soon as possible. + maxAge := cfg.Mailer.TemplateMaxAge - (cfg.Mailer.TemplateMaxAge / 10) + if !ent.def && now.Sub(ent.createdAt) < maxAge { + continue + } + + // We are approaching the expiration and need to eagerly reload. Before + // making the request we make sure we haven't recently checked the template + // using our ival configuration knob. This is just a simple way to give + // endpoints some breathing room instead of expo backoff with counters. + retryIval := cfg.Mailer.TemplateRetryInterval + if now.Sub(ent.checkedAt) < retryIval { + continue + } + + // This template type is eligible for reload. + o.reloadType(ctx, cfg, wg, typ, ent) + } +} + +func (o *Cache) reloadType( + ctx context.Context, + cfg *conf.GlobalConfiguration, + wg *sync.WaitGroup, + typ string, + cur *tplCacheEntry, +) { + wg.Add(1) + go func(typ string) { + defer wg.Done() + + ent, err := o.load(ctx, cfg, typ, cur) + if err != nil { + return + } + + if cur == nil || cur.createdAt != ent.createdAt { + le := observability.GetLogEntryFromContext(ctx).Entry + le.WithFields(logrus.Fields{ + "event": "templatemailer_reloader_template_update", + "mail_type": typ, + }).Infof("mailer: reloaded template type: %v", typ) + } + }(typ) +} + +func (o *Cache) getTouchedAt() time.Time { + o.rw.RLock() + defer o.rw.RUnlock() + return o.t +} + +func (o *Cache) setTouchedAt(at time.Time) { + o.rw.Lock() + defer o.rw.Unlock() + o.t = at +} + +func (o *Cache) getEntry(typ string) (*tplCacheEntry, bool) { + o.rw.RLock() + defer o.rw.RUnlock() + v, ok := o.m[typ] + return v, ok +} + +func (o *Cache) getEntryAndTouchAt(typ string, at time.Time) (*tplCacheEntry, bool) { + o.rw.Lock() + defer o.rw.Unlock() + o.t = at + v, ok := o.m[typ] + return v, ok +} + +func (o *Cache) putEntry(typ string, ent *tplCacheEntry) { + o.rw.Lock() + defer o.rw.Unlock() + o.m[typ] = ent +} + +// get is the method called to fetch an entry from the cache. +func (o *Cache) get( + ctx context.Context, + cfg *conf.GlobalConfiguration, + typ string, +) (*tplCacheEntry, error) { + now := o.now() + ent, ok := o.getEntryAndTouchAt(typ, now) + if !ok { + // Cache miss, straight to load with no current entry. + return o.load(ctx, cfg, typ, nil) + } + + maxAge := cfg.Mailer.TemplateMaxAge + if now.Sub(ent.createdAt) < maxAge { + // Cache hit and the entry is not expired, return it. + return ent, nil + } + + // Entry is expired, we check if the entry is ready for reloading. We do + // as much as we can outside of load to prevent synchronization on o.sf. + retryIval := cfg.Mailer.TemplateRetryInterval + if now.Sub(ent.checkedAt) < retryIval { + // Entry was checked within maxIval, return it. + return ent, nil + } + + // Call load with our most recent entry. + return o.load(ctx, cfg, typ, ent) +} + +// load is what happens when "get" has a cache miss, the hit has expired or +// the a previously failed check has elapsed the ival. +func (o *Cache) load( + ctx context.Context, + cfg *conf.GlobalConfiguration, + typ string, + cur *tplCacheEntry, +) (*tplCacheEntry, error) { + + // Before load returns, forget the most recent result of sf. Because we + // write our cache result in Do we guarantee that the next call to SF + // after this function returns will be a cache hit. + defer o.sf.Forget(typ) + + // We prevent a recently restarted auth server from sending multiple + // concurrent requests to the templating endpoint with pkg singleflight. + v, err, _ := o.sf.Do(typ, func() (any, error) { + + // First try to load a fresh entry. + ent, err := o.loadEntry(ctx, cfg, typ) + if err == nil { + // No error fetching fresh entry, put in cache & return it. + o.putEntry(typ, ent) + return ent, nil + } + + // We had an err loading a fresh entry. Check if we had a current entry + // and return a copy of that with a last checked time. + if cur != nil { + cpy := cur.copy() + cpy.checkedAt = o.now() + + o.putEntry(typ, cpy) + return cpy, nil + } + + // We have no previous entry and no fresh entry, we will load the + // default templates so the mailer can continue serving requests. + ent = o.loadEntryDefault(typ) + o.putEntry(typ, ent) + return ent, nil + }) + if err != nil { + // I don't believe SF returns an error unless the fn it calls does, so + // this is mostly a defensive check. + err = wrapError(ctx, typ, "internal_error", err) + return nil, err + } + + // v is always a *tplCacheEntry + return v.(*tplCacheEntry), nil +} + +// loadEntry returns the +func (o *Cache) loadEntry( + ctx context.Context, + cfg *conf.GlobalConfiguration, + typ string, +) (*tplCacheEntry, error) { + subjectTemp, err := o.loadEntrySubject(ctx, cfg, typ) + if err != nil { + return nil, err + } + + bodyTemp, err := o.loadEntryBody(ctx, cfg, typ) + if err != nil { + return nil, err + } + + now := o.now() + ent := newTplCacheEntry(now, typ, subjectTemp, bodyTemp) + return ent, nil +} + +// loadEntryDefault will never fail due to the checkDefaults() in init(). +func (o *Cache) loadEntryDefault( + typ string, +) *tplCacheEntry { + subjectStr := getEmailContentConfig(defaultTemplateSubjects, typ, "") + subjectTemp := template.Must(template.New("").Parse(subjectStr)) + + bodyStr := getEmailContentConfig(defaultTemplateBodies, typ, "") + bodyTemp := template.Must(template.New("").Parse(bodyStr)) + + now := o.now() + ent := newTplCacheEntry(now, typ, subjectTemp, bodyTemp) + ent.def = true + return ent +} + +func (o *Cache) loadEntrySubject( + ctx context.Context, + cfg *conf.GlobalConfiguration, + typ string, +) (*template.Template, error) { + + // This matches the existing behavior, which allow for a potential double + // parse of the default but it's a minor cost for clean control flow. + tempStr := getEmailContentConfig( + &cfg.Mailer.Subjects, + typ, + getEmailContentConfig(defaultTemplateSubjects, typ, "")) + + temp, err := template.New("Subject").Parse(tempStr) + if err != nil { + err = wrapError(ctx, typ, "template_subject_parse_error", err) + return nil, err + } + return temp, nil +} + +func (o *Cache) loadEntryBody( + ctx context.Context, + cfg *conf.GlobalConfiguration, + typ string, +) (*template.Template, error) { + url := getEmailContentConfig(&cfg.Mailer.Templates, typ, "") + if url == "" { + + // We preserve the previous behavior of returning the default. + tempStr := getEmailContentConfig(defaultTemplateBodies, typ, "") + temp := template.Must(template.New("").Parse(tempStr)) + return temp, nil + } + if !strings.HasPrefix(url, "http") { + url = cfg.SiteURL + url + } + + tempStr, err := o.fetch(ctx, cfg, url) + if err != nil { + err = wrapError(ctx, typ, "template_body_http_error", err) + return nil, err + } + + temp, err := template.New(url).Parse(tempStr) + if err != nil { + err = wrapError(ctx, typ, "template_body_parse_error", err) + return nil, err + } + return temp, nil +} + +func (m *Cache) fetch(ctx context.Context, cfg *conf.GlobalConfiguration, url string) (string, error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", err + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer res.Body.Close() + + if code := res.StatusCode; code != http.StatusOK { + return "", fmt.Errorf("http: GET %v: status code %d", url, code) + } + + rdr := io.LimitReader(res.Body, int64(cfg.Mailer.TemplateMaxSize)) + data, err := io.ReadAll(rdr) + if err != nil { + return "", err + } + + body := string(data) + return body, nil +} + +func wrapError(ctx context.Context, typ, label string, err error) error { + if err == nil { + return nil + } + + err = fmt.Errorf( + "templatemailer: template type %q: %w", typ, err) + le := observability.GetLogEntryFromContext(ctx).Entry + le.WithFields(logrus.Fields{ + "event": "templatemailer_" + label, + "mail_type": typ, + }).WithError(err).Error(err) + return err +} + +func lookupEmailContentConfig( + cfg *conf.EmailContentConfiguration, + tpl string, +) (string, bool) { + switch tpl { + default: + return "", false + case InviteTemplate: + return cfg.Invite, true + case ConfirmationTemplate: + return cfg.Confirmation, true + case RecoveryTemplate: + return cfg.Recovery, true + case EmailChangeTemplate: + return cfg.EmailChange, true + case ReauthenticationTemplate: + return cfg.Reauthentication, true + case MagicLinkTemplate: + return cfg.MagicLink, true + } +} + +func getEmailContentConfig( + cfg *conf.EmailContentConfiguration, + tpl string, + def string, +) string { + // This matches behavior of old withDefault ("" != v) + if v, ok := lookupEmailContentConfig(cfg, tpl); ok && v != "" { + return v + } + return def +} + +func checkDefaults() error { + seen := make(map[string]bool) + data := map[string]any{ + "ConfirmationURL": "ConfirmationURL", + "Data": "Data", + "Email": "Email", + "NewEmail": "NewEmail", + "RedirectTo": "RedirectTo", + "SendingTo": "SendingTo", + "SiteURL": "SiteURL", + "Token": "Token", + "TokenHash": "TokenHash", + } + + buf := new(bytes.Buffer) + check := func(cfg *conf.EmailContentConfiguration, typ string) error { + defer buf.Reset() + + tempStr, ok := lookupEmailContentConfig(cfg, typ) + if !ok { + return fmt.Errorf( + "templatemailer: template type %q: missing default body template", typ) + } + + temp, err := template.New(typ).Parse(tempStr) + if err != nil { + return err + } + + if err := temp.Execute(buf, data); err != nil { + return err + } + return nil + } + + for _, typ := range templateTypes { + if seen[typ] { + return fmt.Errorf( + "templatemailer: template type %q: duplicate found", typ) + } + seen[typ] = true + + if err := check(defaultTemplateSubjects, typ); err != nil { + return err + } + if err := check(defaultTemplateBodies, typ); err != nil { + return err + } + } + return nil +} diff --git a/internal/mailer/template.go b/internal/mailer/templatemailer/templatemailer.go similarity index 55% rename from internal/mailer/template.go rename to internal/mailer/templatemailer/templatemailer.go index b6b800100..c3c96eccd 100644 --- a/internal/mailer/template.go +++ b/internal/mailer/templatemailer/templatemailer.go @@ -1,4 +1,4 @@ -package mailer +package templatemailer import ( "context" @@ -11,58 +11,13 @@ import ( "github.com/supabase/auth/internal/models" ) -type MailRequest struct { - To string - SubjectTemplate string - TemplateURL string - DefaultTemplate string - TemplateData map[string]interface{} - Headers map[string][]string - Type string -} - -type MailClient interface { - Mail( - ctx context.Context, - to string, - subjectTemplate string, - templateURL string, - defaultTemplate string, - templateData map[string]interface{}, - headers map[string][]string, - typ string, - ) error -} - -// TemplateMailer will send mail and use templates from the site for easy mail styling -type TemplateMailer struct { - SiteURL string - Config *conf.GlobalConfiguration - MailClient MailClient -} - -func encodeRedirectURL(referrerURL string) string { - if len(referrerURL) > 0 { - if strings.ContainsAny(referrerURL, "&=#") { - // if the string contains &, = or # it has not been URL - // encoded by the caller, which means it should be URL - // encoded by us otherwise, it should be taken as-is - referrerURL = url.QueryEscape(referrerURL) - } - } - return referrerURL -} - const ( - SignupVerification = "signup" - RecoveryVerification = "recovery" - InviteVerification = "invite" - MagicLinkVerification = "magiclink" - EmailChangeVerification = "email_change" - EmailOTPVerification = "email" - EmailChangeCurrentVerification = "email_change_current" - EmailChangeNewVerification = "email_change_new" - ReauthenticationVerification = "reauthentication" + InviteTemplate = "invite" + ConfirmationTemplate = "confirmation" + RecoveryTemplate = "recovery" + EmailChangeTemplate = "email_change" + MagicLinkTemplate = "magic_link" + ReauthenticationTemplate = "reauthentication" ) const defaultInviteMail = `

You have been invited

@@ -71,7 +26,7 @@ const defaultInviteMail = `

You have been invited

Accept the invite

Alternatively, enter the code: {{ .Token }}

` -const defaultConfirmationMail = `

Confirm your email

+const defaultConfirmationMail = `

Confirm Your Email

Follow this link to confirm your email:

Confirm your email address

@@ -100,8 +55,35 @@ const defaultReauthenticateMail = `

Confirm reauthentication

Enter the code: {{ .Token }}

` -func (m *TemplateMailer) Headers(messageType string) map[string][]string { - originalHeaders := m.Config.SMTP.NormalizedHeaders() +var ( + templateTypes = []string{ + InviteTemplate, + ConfirmationTemplate, + RecoveryTemplate, + EmailChangeTemplate, + MagicLinkTemplate, + ReauthenticationTemplate, + } + defaultTemplateSubjects = &conf.EmailContentConfiguration{ + Invite: "You have been invited", + Confirmation: "Confirm Your Email", + Recovery: "Reset Your Password", + MagicLink: "Your Magic Link", + EmailChange: "Confirm Email Change", + Reauthentication: "Confirm reauthentication", + } + defaultTemplateBodies = &conf.EmailContentConfiguration{ + Invite: defaultInviteMail, + Confirmation: defaultConfirmationMail, + Recovery: defaultRecoveryMail, + MagicLink: defaultMagicLinkMail, + EmailChange: defaultEmailChangeMail, + Reauthentication: defaultReauthenticateMail, + } +) + +func (m *Mailer) Headers(cfg *conf.GlobalConfiguration, messageType string) map[string][]string { + originalHeaders := cfg.SMTP.NormalizedHeaders() if originalHeaders == nil { return nil @@ -136,8 +118,8 @@ func (m *TemplateMailer) Headers(messageType string) map[string][]string { } // InviteMail sends a invite mail to a new user -func (m *TemplateMailer) InviteMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { - path, err := getPath(m.Config.Mailer.URLPaths.Invite, &EmailParams{ +func (m *Mailer) InviteMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.cfg.Mailer.URLPaths.Invite, &emailParams{ Token: user.ConfirmationToken, Type: "invite", RedirectTo: referrerURL, @@ -147,8 +129,8 @@ func (m *TemplateMailer) InviteMail(r *http.Request, user *models.User, otp, ref return err } - data := map[string]interface{}{ - "SiteURL": m.Config.SiteURL, + data := map[string]any{ + "SiteURL": m.cfg.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, "Token": otp, @@ -156,22 +138,12 @@ func (m *TemplateMailer) InviteMail(r *http.Request, user *models.User, otp, ref "Data": user.UserMetaData, "RedirectTo": referrerURL, } - - return m.MailClient.Mail( - r.Context(), - user.GetEmail(), - withDefault(m.Config.Mailer.Subjects.Invite, "You have been invited"), - m.Config.Mailer.Templates.Invite, - defaultInviteMail, - data, - m.Headers("invite"), - "invite", - ) + return m.mail(r.Context(), m.cfg, InviteTemplate, user.GetEmail(), data) } // ConfirmationMail sends a signup confirmation mail to a new user -func (m *TemplateMailer) ConfirmationMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { - path, err := getPath(m.Config.Mailer.URLPaths.Confirmation, &EmailParams{ +func (m *Mailer) ConfirmationMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.cfg.Mailer.URLPaths.Confirmation, &emailParams{ Token: user.ConfirmationToken, Type: "signup", RedirectTo: referrerURL, @@ -180,8 +152,8 @@ func (m *TemplateMailer) ConfirmationMail(r *http.Request, user *models.User, ot return err } - data := map[string]interface{}{ - "SiteURL": m.Config.SiteURL, + data := map[string]any{ + "SiteURL": m.cfg.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, "Token": otp, @@ -189,67 +161,42 @@ func (m *TemplateMailer) ConfirmationMail(r *http.Request, user *models.User, ot "Data": user.UserMetaData, "RedirectTo": referrerURL, } - - return m.MailClient.Mail( - r.Context(), - user.GetEmail(), - withDefault(m.Config.Mailer.Subjects.Confirmation, "Confirm Your Email"), - m.Config.Mailer.Templates.Confirmation, - defaultConfirmationMail, - data, - m.Headers("confirm"), - "confirm", - ) + return m.mail(r.Context(), m.cfg, ConfirmationTemplate, user.GetEmail(), data) } // ReauthenticateMail sends a reauthentication mail to an authenticated user -func (m *TemplateMailer) ReauthenticateMail(r *http.Request, user *models.User, otp string) error { - data := map[string]interface{}{ - "SiteURL": m.Config.SiteURL, +func (m *Mailer) ReauthenticateMail(r *http.Request, user *models.User, otp string) error { + data := map[string]any{ + "SiteURL": m.cfg.SiteURL, "Email": user.Email, "Token": otp, "Data": user.UserMetaData, } - - return m.MailClient.Mail( - r.Context(), - user.GetEmail(), - withDefault(m.Config.Mailer.Subjects.Reauthentication, "Confirm reauthentication"), - m.Config.Mailer.Templates.Reauthentication, - defaultReauthenticateMail, - data, - m.Headers("reauthenticate"), - "reauthenticate", - ) + return m.mail(r.Context(), m.cfg, ReauthenticationTemplate, user.GetEmail(), data) } // EmailChangeMail sends an email change confirmation mail to a user -func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otpNew, otpCurrent, referrerURL string, externalURL *url.URL) error { +func (m *Mailer) EmailChangeMail(r *http.Request, user *models.User, otpNew, otpCurrent, referrerURL string, externalURL *url.URL) error { type Email struct { + Action string Address string Otp string TokenHash string - Subject string - Template string } emails := []Email{ { Address: user.EmailChange, Otp: otpNew, TokenHash: user.EmailChangeTokenNew, - Subject: withDefault(m.Config.Mailer.Subjects.EmailChange, "Confirm Email Change"), - Template: m.Config.Mailer.Templates.EmailChange, }, } currentEmail := user.GetEmail() - if m.Config.Mailer.SecureEmailChangeEnabled && currentEmail != "" { + if m.cfg.Mailer.SecureEmailChangeEnabled && currentEmail != "" { emails = append(emails, Email{ Address: currentEmail, Otp: otpCurrent, TokenHash: user.EmailChangeTokenCurrent, - Subject: withDefault(m.Config.Mailer.Subjects.Confirmation, "Confirm Email Address"), - Template: m.Config.Mailer.Templates.EmailChange, }) } @@ -259,8 +206,8 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp errors := make(chan error, len(emails)) for _, email := range emails { path, err := getPath( - m.Config.Mailer.URLPaths.EmailChange, - &EmailParams{ + m.cfg.Mailer.URLPaths.EmailChange, + &emailParams{ Token: email.TokenHash, Type: "email_change", RedirectTo: referrerURL, @@ -269,9 +216,9 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp if err != nil { return err } - go func(address, token, tokenHash, template string) { - data := map[string]interface{}{ - "SiteURL": m.Config.SiteURL, + go func(address, token, tokenHash string) { + data := map[string]any{ + "SiteURL": m.cfg.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.GetEmail(), "NewEmail": user.EmailChange, @@ -281,17 +228,14 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp "Data": user.UserMetaData, "RedirectTo": referrerURL, } - errors <- m.MailClient.Mail( + errors <- m.mail( ctx, + m.cfg, + EmailChangeTemplate, address, - withDefault(m.Config.Mailer.Subjects.EmailChange, "Confirm Email Change"), - template, - defaultEmailChangeMail, data, - m.Headers("email_change"), - "email_change", ) - }(email.Address, email.Otp, email.TokenHash, email.Template) + }(email.Address, email.Otp, email.TokenHash) } for i := 0; i < len(emails); i++ { @@ -304,8 +248,8 @@ func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otp } // RecoveryMail sends a password recovery mail -func (m *TemplateMailer) RecoveryMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { - path, err := getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ +func (m *Mailer) RecoveryMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.cfg.Mailer.URLPaths.Recovery, &emailParams{ Token: user.RecoveryToken, Type: "recovery", RedirectTo: referrerURL, @@ -313,8 +257,8 @@ func (m *TemplateMailer) RecoveryMail(r *http.Request, user *models.User, otp, r if err != nil { return err } - data := map[string]interface{}{ - "SiteURL": m.Config.SiteURL, + data := map[string]any{ + "SiteURL": m.cfg.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, "Token": otp, @@ -322,22 +266,12 @@ func (m *TemplateMailer) RecoveryMail(r *http.Request, user *models.User, otp, r "Data": user.UserMetaData, "RedirectTo": referrerURL, } - - return m.MailClient.Mail( - r.Context(), - user.GetEmail(), - withDefault(m.Config.Mailer.Subjects.Recovery, "Reset Your Password"), - m.Config.Mailer.Templates.Recovery, - defaultRecoveryMail, - data, - m.Headers("recovery"), - "recovery", - ) + return m.mail(r.Context(), m.cfg, RecoveryTemplate, user.GetEmail(), data) } // MagicLinkMail sends a login link mail -func (m *TemplateMailer) MagicLinkMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { - path, err := getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ +func (m *Mailer) MagicLinkMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.cfg.Mailer.URLPaths.Recovery, &emailParams{ Token: user.RecoveryToken, Type: "magiclink", RedirectTo: referrerURL, @@ -346,8 +280,8 @@ func (m *TemplateMailer) MagicLinkMail(r *http.Request, user *models.User, otp, return err } - data := map[string]interface{}{ - "SiteURL": m.Config.SiteURL, + data := map[string]any{ + "SiteURL": m.cfg.SiteURL, "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, "Token": otp, @@ -355,57 +289,47 @@ func (m *TemplateMailer) MagicLinkMail(r *http.Request, user *models.User, otp, "Data": user.UserMetaData, "RedirectTo": referrerURL, } - - return m.MailClient.Mail( - r.Context(), - user.GetEmail(), - withDefault(m.Config.Mailer.Subjects.MagicLink, "Your Magic Link"), - m.Config.Mailer.Templates.MagicLink, - defaultMagicLinkMail, - data, - m.Headers("magiclink"), - "magiclink", - ) + return m.mail(r.Context(), m.cfg, MagicLinkTemplate, user.GetEmail(), data) } // GetEmailActionLink returns a magiclink, recovery or invite link based on the actionType passed. -func (m TemplateMailer) GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) { +func (m *Mailer) GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) { var err error var path *url.URL switch actionType { case "magiclink": - path, err = getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + path, err = getPath(m.cfg.Mailer.URLPaths.Recovery, &emailParams{ Token: user.RecoveryToken, Type: "magiclink", RedirectTo: referrerURL, }) case "recovery": - path, err = getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + path, err = getPath(m.cfg.Mailer.URLPaths.Recovery, &emailParams{ Token: user.RecoveryToken, Type: "recovery", RedirectTo: referrerURL, }) case "invite": - path, err = getPath(m.Config.Mailer.URLPaths.Invite, &EmailParams{ + path, err = getPath(m.cfg.Mailer.URLPaths.Invite, &emailParams{ Token: user.ConfirmationToken, Type: "invite", RedirectTo: referrerURL, }) case "signup": - path, err = getPath(m.Config.Mailer.URLPaths.Confirmation, &EmailParams{ + path, err = getPath(m.cfg.Mailer.URLPaths.Confirmation, &emailParams{ Token: user.ConfirmationToken, Type: "signup", RedirectTo: referrerURL, }) case "email_change_current": - path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, &EmailParams{ + path, err = getPath(m.cfg.Mailer.URLPaths.EmailChange, &emailParams{ Token: user.EmailChangeTokenCurrent, Type: "email_change", RedirectTo: referrerURL, }) case "email_change_new": - path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, &EmailParams{ + path, err = getPath(m.cfg.Mailer.URLPaths.EmailChange, &emailParams{ Token: user.EmailChangeTokenNew, Type: "email_change", RedirectTo: referrerURL, @@ -418,3 +342,36 @@ func (m TemplateMailer) GetEmailActionLink(user *models.User, actionType, referr } return externalURL.ResolveReference(path).String(), nil } + +type emailParams struct { + Token string + Type string + RedirectTo string +} + +func getPath(filepath string, params *emailParams) (*url.URL, error) { + path := &url.URL{} + if filepath != "" { + if p, err := url.Parse(filepath); err != nil { + return nil, err + } else { + path = p + } + } + if params != nil { + path.RawQuery = fmt.Sprintf("token=%s&type=%s&redirect_to=%s", url.QueryEscape(params.Token), url.QueryEscape(params.Type), encodeRedirectURL(params.RedirectTo)) + } + return path, nil +} + +func encodeRedirectURL(referrerURL string) string { + if len(referrerURL) > 0 { + if strings.ContainsAny(referrerURL, "&=#") { + // if the string contains &, = or # it has not been URL + // encoded by the caller, which means it should be URL + // encoded by us otherwise, it should be taken as-is + referrerURL = url.QueryEscape(referrerURL) + } + } + return referrerURL +} diff --git a/internal/mailer/template_test.go b/internal/mailer/templatemailer/templatemailer_test.go similarity index 85% rename from internal/mailer/template_test.go rename to internal/mailer/templatemailer/templatemailer_test.go index f8fcd7417..22837005f 100644 --- a/internal/mailer/template_test.go +++ b/internal/mailer/templatemailer/templatemailer_test.go @@ -1,4 +1,4 @@ -package mailer +package templatemailer import ( "testing" @@ -50,16 +50,15 @@ func TestTemplateHeaders(t *testing.T) { }, } for _, tc := range cases { - mailer := TemplateMailer{ - Config: &conf.GlobalConfiguration{ - SMTP: conf.SMTPConfiguration{ - Headers: tc.from, - }, + mailer := New(&conf.GlobalConfiguration{ + SMTP: conf.SMTPConfiguration{ + Headers: tc.from, }, - } - require.NoError(t, mailer.Config.SMTP.Validate()) + }, nil, nil) - hdrs := mailer.Headers(tc.typ) + require.NoError(t, mailer.cfg.SMTP.Validate()) + + hdrs := mailer.Headers(mailer.cfg, tc.typ) require.Equal(t, hdrs, tc.exp) } } diff --git a/internal/mailer/mailer_test.go b/internal/mailer/templatemailer/templatemailer_url_test.go similarity index 96% rename from internal/mailer/mailer_test.go rename to internal/mailer/templatemailer/templatemailer_url_test.go index 290d65dd0..672d37130 100644 --- a/internal/mailer/mailer_test.go +++ b/internal/mailer/templatemailer/templatemailer_url_test.go @@ -1,4 +1,4 @@ -package mailer +package templatemailer import ( "net/url" @@ -15,7 +15,7 @@ func enforceRelativeURL(url string) string { } func TestGetPath(t *testing.T) { - params := EmailParams{ + params := emailParams{ Token: "token", Type: "signup", RedirectTo: "https://example.com", @@ -23,7 +23,7 @@ func TestGetPath(t *testing.T) { cases := []struct { SiteURL string Path string - Params *EmailParams + Params *emailParams Expected string }{ { diff --git a/internal/mailer/validate.go b/internal/mailer/validateclient/validateclient.go similarity index 79% rename from internal/mailer/validate.go rename to internal/mailer/validateclient/validateclient.go index b8231f26e..b194bf111 100644 --- a/internal/mailer/validate.go +++ b/internal/mailer/validateclient/validateclient.go @@ -1,4 +1,4 @@ -package mailer +package validateclient import ( "bytes" @@ -13,6 +13,7 @@ import ( "time" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer" "golang.org/x/sync/errgroup" ) @@ -77,15 +78,75 @@ var ( ErrInvalidEmailMX = errors.New("invalid_email_mx") ) -type EmailValidator struct { +// New will return a Client that first calls an email validator before passing +// the mail along to given Client. If email validation is disabled then it +// returns the same Client passed in mc. +func New(globalConfig *conf.GlobalConfiguration, mc mailer.Client) mailer.Client { + + // Check if email validation is enabled + ev := newEmailValidator(globalConfig.Mailer) + if ev.isEnabled() { + mc = &emailValidatorMailClient{ev: ev, mc: mc} + } + return mc +} + +type emailValidatorMailClient struct { + ev *emailValidator + mc mailer.Client +} + +// Mail implements mailer.MailClient interface by calling validate before +// passing the mail request to the next MailClient. +func (o *emailValidatorMailClient) Mail( + ctx context.Context, + to string, + subject string, + body string, + headers map[string][]string, + typ string, +) error { + if err := o.ev.Validate(ctx, to); err != nil { + return err + } + return o.mc.Mail( + ctx, + to, + subject, + body, + headers, + typ, + ) +} + +type emailValidator struct { extended bool serviceURL string serviceHeaders map[string][]string blockedMXRecords map[string]bool } -func newEmailValidator(mc conf.MailerConfiguration) *EmailValidator { - return &EmailValidator{ +func (m *emailValidator) MailNew( + ctx context.Context, + to, subject, body string, + headers map[string][]string, + typ string, +) error { + return nil +} + +func (m *emailValidator) Mail( + ctx context.Context, + to, subjectTemplate, templateURL, defaultTemplate string, + templateData map[string]any, + headers map[string][]string, + typ string, +) error { + return nil +} + +func newEmailValidator(mc conf.MailerConfiguration) *emailValidator { + return &emailValidator{ extended: mc.EmailValidationExtended, serviceURL: mc.EmailValidationServiceURL, serviceHeaders: mc.GetEmailValidationServiceHeaders(), @@ -93,12 +154,12 @@ func newEmailValidator(mc conf.MailerConfiguration) *EmailValidator { } } -func (ev *EmailValidator) isEnabled() bool { +func (ev *emailValidator) isEnabled() bool { return ev.isExtendedEnabled() || ev.isServiceEnabled() } -func (ev *EmailValidator) isExtendedEnabled() bool { return ev.extended } -func (ev *EmailValidator) isServiceEnabled() bool { return ev.serviceURL != "" } +func (ev *emailValidator) isExtendedEnabled() bool { return ev.extended } +func (ev *emailValidator) isServiceEnabled() bool { return ev.serviceURL != "" } // Validate performs validation on the given email. // @@ -108,7 +169,7 @@ func (ev *EmailValidator) isServiceEnabled() bool { return ev.serviceURL != "" // // When serviceURL AND serviceKey are non-empty strings it uses the remote // service to determine if the email is valid. -func (ev *EmailValidator) Validate(ctx context.Context, email string) error { +func (ev *emailValidator) Validate(ctx context.Context, email string) error { if !ev.isEnabled() { return nil } @@ -151,7 +212,7 @@ func (ev *EmailValidator) Validate(ctx context.Context, email string) error { // validateStatic will validate the format and do the static checks before // returning the host portion of the email. -func (ev *EmailValidator) validateStatic(email string) (string, error) { +func (ev *emailValidator) validateStatic(email string) (string, error) { if !ev.isExtendedEnabled() { return "", nil } @@ -189,7 +250,7 @@ func (ev *EmailValidator) validateStatic(email string) (string, error) { return host, nil } -func (ev *EmailValidator) validateService(ctx context.Context, email string) error { +func (ev *emailValidator) validateService(ctx context.Context, email string) error { if !ev.isServiceEnabled() { return nil } @@ -244,7 +305,7 @@ func (ev *EmailValidator) validateService(ctx context.Context, email string) err return ErrInvalidEmailAddress } -func (ev *EmailValidator) validateProviders(name, host string) error { +func (ev *emailValidator) validateProviders(name, host string) error { switch host { case "gmail.com": // Based on a sample of internal data, this reduces the number of @@ -259,7 +320,7 @@ func (ev *EmailValidator) validateProviders(name, host string) error { return nil } -func (ev *EmailValidator) validateHost(ctx context.Context, host string) error { +func (ev *emailValidator) validateHost(ctx context.Context, host string) error { mxs, err := validateEmailResolver.LookupMX(ctx, host) if !isHostNotFound(err) { return ev.validateMXRecords(mxs, nil) @@ -274,7 +335,7 @@ func (ev *EmailValidator) validateHost(ctx context.Context, host string) error { return ErrInvalidEmailDNS } -func (ev *EmailValidator) validateMXRecords(mxs []*net.MX, hosts []string) error { +func (ev *emailValidator) validateMXRecords(mxs []*net.MX, hosts []string) error { for _, mx := range mxs { if ev.blockedMXRecords[mx.Host] { return ErrInvalidEmailMX diff --git a/internal/mailer/validate_test.go b/internal/mailer/validateclient/validateclient_test.go similarity index 99% rename from internal/mailer/validate_test.go rename to internal/mailer/validateclient/validateclient_test.go index e9bae8fef..43e2fbc11 100644 --- a/internal/mailer/validate_test.go +++ b/internal/mailer/validateclient/validateclient_test.go @@ -1,4 +1,4 @@ -package mailer +package validateclient import ( "context"