Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 52 additions & 10 deletions cmd/serve_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ import (
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/supabase/auth/internal/api"
"github.com/supabase/auth/internal/api/apiworker"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/mailer/templatemailer"
"github.com/supabase/auth/internal/reloader"
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/utilities"
Expand Down Expand Up @@ -48,22 +50,24 @@ func serve(ctx context.Context) {
}
defer db.Close()

addr := net.JoinHostPort(config.API.Host, config.API.Port)

opts := []api.Option{
api.NewLimiterOptions(config),
}

baseCtx, baseCancel := context.WithCancel(context.Background())
defer baseCancel()

var wg sync.WaitGroup
defer wg.Wait() // Do not return to caller until this goroutine is done.

a := api.NewAPIWithVersion(config, db, utilities.Version, opts...)
ah := reloader.NewAtomicHandler(a)
logrus.WithField("version", a.Version()).Infof("GoTrue API started on: %s", addr)
mrCache := templatemailer.NewCache()
limiterOpts := api.NewLimiterOptions(config)
initialAPI := api.NewAPIWithVersion(
config, db, utilities.Version,
limiterOpts,
api.WithMailer(templatemailer.FromConfig(config, mrCache)),
)

addr := net.JoinHostPort(config.API.Host, config.API.Port)
logrus.WithField("version", initialAPI.Version()).Infof("GoTrue API started on: %s", addr)

ah := reloader.NewAtomicHandler(initialAPI)
httpSrv := &http.Server{
Addr: addr,
Handler: ah,
Expand All @@ -74,15 +78,53 @@ func serve(ctx context.Context) {
}
log := logrus.WithField("component", "api")

wrkLog := logrus.WithField("component", "apiworker")
wrk := apiworker.New(config, mrCache, wrkLog)
wg.Add(1)
go func() {
defer wg.Done()

var err error
defer func() {
logFn := wrkLog.Info
if err != nil {
logFn = wrkLog.WithError(err).Error
}
logFn("background apiworker is exiting")
}()

// Work exits when ctx is done as in-flight requests do not depend
// on it. If they do in the future this should be baseCtx instead.
err = wrk.Work(ctx)
}()

if watchDir != "" {
wg.Add(1)
go func() {
defer wg.Done()

fn := func(latestCfg *conf.GlobalConfiguration) {
log.Info("reloading api with new configuration")

// When config is updated we notify the apiworker.
wrk.ReloadConfig(latestCfg)

// Create a new API version with the updated config.
latestAPI := api.NewAPIWithVersion(
latestCfg, db, utilities.Version, opts...)
config, db, utilities.Version,

// Create a new mailer with existing template cache.
api.WithMailer(
templatemailer.FromConfig(config, mrCache),
),

// Persist existing rate limiters.
//
// TODO(cstockton): we should consider updating these, if we
// rely on hot config reloads 100% then rate limiter changes
// won't be picked up.
limiterOpts,
)
ah.Store(latestAPI)
}

Expand Down
27 changes: 11 additions & 16 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/supabase/auth/internal/hooks/hookspgfunc"
"github.com/supabase/auth/internal/hooks/v0hooks"
"github.com/supabase/auth/internal/mailer"
"github.com/supabase/auth/internal/mailer/templatemailer"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/storage"
Expand All @@ -38,11 +39,11 @@ type API struct {
config *conf.GlobalConfiguration
version string

hooksMgr *v0hooks.Manager
hibpClient *hibp.PwnedClient
oauthServer *oauthserver.Server
mailerClientFunc func() mailer.MailClient
tokenService *tokens.Service
hooksMgr *v0hooks.Manager
hibpClient *hibp.PwnedClient
oauthServer *oauthserver.Server
tokenService *tokens.Service
mailer mailer.Mailer

// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
overrideTime func() time.Time
Expand All @@ -53,6 +54,7 @@ type API struct {
func (a *API) GetConfig() *conf.GlobalConfiguration { return a.config }
func (a *API) GetDB() *storage.Connection { return a.db }
func (a *API) GetTokenService() *tokens.Service { return a.tokenService }
func (a *API) Mailer() mailer.Mailer { return a.mailer }

func (a *API) Version() string {
return a.version
Expand Down Expand Up @@ -99,11 +101,6 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
if api.limiterOpts == nil {
api.limiterOpts = NewLimiterOptions(globalConfig)
}
if api.mailerClientFunc == nil {
api.mailerClientFunc = func() mailer.MailClient {
return mailer.NewMailClient(globalConfig)
}
}
if api.hooksMgr == nil {
httpDr := hookshttp.New()
pgfuncDr := hookspgfunc.New(db)
Expand All @@ -114,6 +111,10 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
if api.tokenService == nil {
api.tokenService = tokens.NewService(globalConfig, api.hooksMgr)
}
if api.mailer == nil {
tc := templatemailer.NewCache()
api.mailer = templatemailer.FromConfig(globalConfig, tc)
}

// Connect token service to API's time function (supports test overrides)
api.tokenService.SetTimeFunc(api.Now)
Expand Down Expand Up @@ -397,12 +398,6 @@ func (a *API) HealthCheck(w http.ResponseWriter, r *http.Request) error {
})
}

// Mailer returns NewMailer with the current tenant config
func (a *API) Mailer() mailer.Mailer {
config := a.config
return mailer.NewMailerWithClient(config, a.mailerClientFunc())
}

// ServeHTTP implements the http.Handler interface by passing the request along
// to its underlying Handler.
func (a *API) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand Down
111 changes: 111 additions & 0 deletions internal/api/apiworker/apiworker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package apiworker

import (
"context"
"errors"
"sync"
"time"

"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/mailer/templatemailer"
)

// Worker is a simple background worker for async tasks.
type Worker struct {
le *logrus.Entry
tc *templatemailer.Cache

// Notifies worker the cfg has been updated.
cfgCh chan struct{}

// workMu must be held for calls to Work
workMu sync.Mutex

// mu must be held for field access below here
mu sync.Mutex
cfg *conf.GlobalConfiguration
}

// New will return a new *Worker instance.
func New(
cfg *conf.GlobalConfiguration,
tc *templatemailer.Cache,
le *logrus.Entry,
) *Worker {
return &Worker{
le: le,
cfg: cfg,
tc: tc,
cfgCh: make(chan struct{}, 1),
}
}

func (o *Worker) putConfig(cfg *conf.GlobalConfiguration) {
o.mu.Lock()
defer o.mu.Unlock()
o.cfg = cfg
}

func (o *Worker) getConfig() *conf.GlobalConfiguration {
o.mu.Lock()
defer o.mu.Unlock()
return o.cfg
}

// ReloadConfig notifies the worker a new configuration is available.
func (o *Worker) ReloadConfig(cfg *conf.GlobalConfiguration) {
o.putConfig(cfg)

select {
case o.cfgCh <- struct{}{}:
default:
}
}

// Work will periodically reload the templates in the background as long as the
// system remains active.
func (o *Worker) Work(ctx context.Context) error {
if ok := o.workMu.TryLock(); !ok {
return errors.New("apiworker: concurrent calls to Work are invalid")
}
defer o.workMu.Unlock()

le := o.le.WithFields(logrus.Fields{
"worker_type": "apiworker_template_cache",
})
le.Info("apiworker: template cache worker started")
defer le.Info("apiworker: template cache worker exited")

// Reload templates right away on Work.
o.maybeReloadTemplates(ctx, o.getConfig())

ival := func() time.Duration {
return max(time.Second, o.getConfig().Mailer.TemplateRetryInterval/4)
}

tr := time.NewTicker(ival())
defer tr.Stop()

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-o.cfgCh:
tr.Reset(ival())
case <-tr.C:
}

// Either ticker fired or we got a config update.
o.maybeReloadTemplates(ctx, o.getConfig())
}
}

func (o *Worker) maybeReloadTemplates(
ctx context.Context,
cfg *conf.GlobalConfiguration,
) {
if cfg.Mailer.TemplateReloadingEnabled {
o.tc.Reload(ctx, cfg)
}
}
7 changes: 4 additions & 3 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/supabase/auth/internal/hooks/v0hooks"
mail "github.com/supabase/auth/internal/mailer"
"github.com/supabase/auth/internal/mailer/validateclient"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"

Expand Down Expand Up @@ -705,9 +706,9 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User,
}

switch {
case errors.Is(err, mail.ErrInvalidEmailAddress),
errors.Is(err, mail.ErrInvalidEmailFormat),
errors.Is(err, mail.ErrInvalidEmailDNS):
case errors.Is(err, validateclient.ErrInvalidEmailAddress),
errors.Is(err, validateclient.ErrInvalidEmailFormat),
errors.Is(err, validateclient.ErrInvalidEmailDNS):
return apierrors.NewBadRequestError(
apierrors.ErrorCodeEmailAddressInvalid,
"Email address %q is invalid",
Expand Down
29 changes: 13 additions & 16 deletions internal/api/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,21 @@ type Option interface {
apply(*API)
}

type MailerOptions struct {
MailerClientFunc func() mailer.MailClient
type optionFunc func(*API)

func (f optionFunc) apply(a *API) { f(a) }

func WithMailer(m mailer.Mailer) Option {
return optionFunc(func(a *API) {
a.mailer = m
})
}

func (mo *MailerOptions) apply(a *API) { a.mailerClientFunc = mo.MailerClientFunc }
func WithTokenService(service *tokens.Service) Option {
return optionFunc(func(a *API) {
a.tokenService = service
})
}

type LimiterOptions struct {
Email ratelimit.Limiter
Expand All @@ -44,19 +54,6 @@ type LimiterOptions struct {

func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo }

// TokenServiceOption allows injecting a custom token service
type TokenServiceOption struct {
service *tokens.Service
}

func WithTokenService(service *tokens.Service) *TokenServiceOption {
return &TokenServiceOption{service: service}
}

func (tso *TokenServiceOption) apply(a *API) {
a.tokenService = tso.service
}

func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions {
o := &LimiterOptions{}

Expand Down
17 changes: 0 additions & 17 deletions internal/api/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/e2e"
"github.com/supabase/auth/internal/mailer"
)

func TestNewLimiterOptions(t *testing.T) {
Expand All @@ -31,17 +28,3 @@ func TestNewLimiterOptions(t *testing.T) {
assert.NotNil(t, rl.SSO)
assert.NotNil(t, rl.SAMLAssertion)
}

func TestMailerOptions(t *testing.T) {
globalCfg := e2e.Must(e2e.Config())
conn := e2e.Must(e2e.Conn(globalCfg))

sentinelMailer := mailer.NewMailClient(globalCfg)
mailerOpts := &MailerOptions{MailerClientFunc: func() mailer.MailClient {
return sentinelMailer
}}
a := NewAPIWithVersion(globalCfg, conn, apiTestVersion, mailerOpts)

got := a.mailerClientFunc()
require.Equal(t, sentinelMailer, got)
}
Loading
Loading