diff --git a/internal/driver/config/provider.go b/internal/driver/config/provider.go index d0676fab8..66cb5a105 100644 --- a/internal/driver/config/provider.go +++ b/internal/driver/config/provider.go @@ -146,8 +146,14 @@ func (k *Config) WriteAPIListenOn() string { ) } -func (k *Config) CORS() (cors.Options, bool) { - return k.p.CORS("serve", cors.Options{ +func (k *Config) CORS(iface string) (cors.Options, bool) { + switch iface { + case "read", "write": + default: + panic("expected interface 'read' or 'write', but got unknown interface " + iface) + } + + return k.p.CORS("serve."+iface, cors.Options{ AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE"}, AllowedHeaders: []string{"Authorization", "Content-Type"}, ExposedHeaders: []string{"Content-Type"}, diff --git a/internal/driver/registry_default.go b/internal/driver/registry_default.go index cb4dea58f..4a19f5e6a 100644 --- a/internal/driver/registry_default.go +++ b/internal/driver/registry_default.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/ory/x/networkx" + "github.com/rs/cors" "github.com/gobuffalo/pop/v5" "github.com/ory/x/popx" @@ -293,7 +294,13 @@ func (r *RegistryDefault) ReadRouter() http.Handler { n.Use(r.sqaService) } - return n + var handler http.Handler = n + options, enabled := r.Config().CORS("read") + if enabled { + handler = cors.New(options).Handler(handler) + } + + return handler } func (r *RegistryDefault) WriteRouter() http.Handler { @@ -318,7 +325,13 @@ func (r *RegistryDefault) WriteRouter() http.Handler { n.Use(r.sqaService) } - return n + var handler http.Handler = n + options, enabled := r.Config().CORS("write") + if enabled { + handler = cors.New(options).Handler(handler) + } + + return handler } func (r *RegistryDefault) unaryInterceptors() []grpc.UnaryServerInterceptor { diff --git a/internal/e2e/full_suit_test.go b/internal/e2e/full_suit_test.go index 94a4272e3..89c9c40bc 100644 --- a/internal/e2e/full_suit_test.go +++ b/internal/e2e/full_suit_test.go @@ -2,7 +2,11 @@ package e2e import ( "fmt" + "net/http" "testing" + "time" + + "github.com/stretchr/testify/assert" "github.com/ory/keto/internal/x/dbx" @@ -40,7 +44,7 @@ type ( func Test(t *testing.T) { for _, dsn := range dbx.GetDSNs(t, false) { t.Run(fmt.Sprintf("dsn=%s", dsn.Name), func(t *testing.T) { - ctx, reg, addNamespace := newInitializedReg(t, dsn) + ctx, reg, addNamespace := newInitializedReg(t, dsn, nil) closeServer := startServer(ctx, t, reg) defer closeServer() @@ -76,3 +80,28 @@ func Test(t *testing.T) { }) } } + +func TestServeConfig(t *testing.T) { + ctx, reg, _ := newInitializedReg(t, dbx.GetSqlite(t, dbx.SQLiteMemory), map[string]interface{}{ + "serve.read.cors.enabled": true, + "serve.read.cors.debug": true, + "serve.read.cors.allowed_methods": []string{http.MethodGet}, + "serve.read.cors.allowed_origins": []string{"https://ory.sh"}, + }) + + closeServer := startServer(ctx, t, reg) + defer closeServer() + + for !healthReady(t, "http://"+reg.Config().ReadAPIListenOn()) { + t.Log("Waiting for health check to be ready") + time.Sleep(10 * time.Millisecond) + } + + req, err := http.NewRequest(http.MethodOptions, "http://"+reg.Config().ReadAPIListenOn()+relationtuple.RouteBase, nil) + require.NoError(t, err) + req.Header.Set("Origin", "https://ory.sh") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "https://ory.sh", resp.Header.Get("Access-Control-Allow-Origin"), "%+v", resp.Header) +} diff --git a/internal/e2e/helpers.go b/internal/e2e/helpers.go index 0b0ff2ed1..70c01063b 100644 --- a/internal/e2e/helpers.go +++ b/internal/e2e/helpers.go @@ -26,7 +26,7 @@ import ( "github.com/ory/keto/internal/driver" ) -func newInitializedReg(t testing.TB, dsn *dbx.DsnT) (context.Context, driver.Registry, func(*testing.T, ...*namespace.Namespace)) { +func newInitializedReg(t testing.TB, dsn *dbx.DsnT, cfgOverwrites map[string]interface{}) (context.Context, driver.Registry, func(*testing.T, ...*namespace.Namespace)) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(func() { cancel() @@ -38,7 +38,7 @@ func newInitializedReg(t testing.TB, dsn *dbx.DsnT) (context.Context, driver.Reg flags := pflag.NewFlagSet("", pflag.ContinueOnError) configx.RegisterConfigFlag(flags, nil) - cf := dbx.ConfigFile(t, map[string]interface{}{ + cfgValues := map[string]interface{}{ config.KeyDSN: dsn.Conn, "log.level": "debug", "log.leak_sensitive_values": true, @@ -46,7 +46,12 @@ func newInitializedReg(t testing.TB, dsn *dbx.DsnT) (context.Context, driver.Reg config.KeyReadAPIPort: ports[0], config.KeyWriteAPIHost: "127.0.0.1", config.KeyWriteAPIPort: ports[1], - }) + } + for k, v := range cfgOverwrites { + cfgValues[k] = v + } + + cf := dbx.ConfigFile(t, cfgValues) require.NoError(t, flags.Parse([]string{"--" + configx.FlagConfig, cf})) reg, err := driver.NewDefaultRegistry(ctx, flags, true) diff --git a/internal/e2e/rest_client_test.go b/internal/e2e/rest_client_test.go index e16115ae4..720f3fcf5 100644 --- a/internal/e2e/rest_client_test.go +++ b/internal/e2e/rest_client_test.go @@ -149,18 +149,19 @@ func (rc *restClient) expand(t require.TestingT, r *relationtuple.SubjectSet, de return tree } -func (rc *restClient) waitUntilLive(t require.TestingT) { - var healthReady = func() bool { - req, err := http.NewRequest("GET", rc.readURL+healthx.ReadyCheckPath, nil) - require.NoError(t, err) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false - } - return resp.StatusCode == http.StatusOK +func healthReady(t require.TestingT, readURL string) bool { + req, err := http.NewRequest("GET", readURL+healthx.ReadyCheckPath, nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return false } + return resp.StatusCode == http.StatusOK +} + +func (rc *restClient) waitUntilLive(t require.TestingT) { // wait for /health/ready - for !healthReady() { + for !healthReady(t, rc.readURL) { time.Sleep(10 * time.Millisecond) } } diff --git a/internal/x/dbx/dsn_testutils.go b/internal/x/dbx/dsn_testutils.go index f413e52d3..6c146a5dc 100644 --- a/internal/x/dbx/dsn_testutils.go +++ b/internal/x/dbx/dsn_testutils.go @@ -91,6 +91,9 @@ func GetSqlite(t testing.TB, mode sqliteMode) *DsnT { case SQLiteMemory: dsn.Name = "memory" dsn.Conn = fmt.Sprintf("sqlite://file:%s?_fk=true&cache=shared&mode=memory", t.Name()) + t.Cleanup(func() { + _ = os.Remove(t.Name()) + }) case SQLiteFile: t.Cleanup(func() { _ = os.Remove("TestDB.sqlite")