diff --git a/driver/registry_sql.go b/driver/registry_sql.go index 16f9b238e..a7163cd8d 100644 --- a/driver/registry_sql.go +++ b/driver/registry_sql.go @@ -7,7 +7,6 @@ import ( "github.com/ory/x/dbal" "github.com/ory/x/sqlcon" - "github.com/ory/x/urlx" "github.com/ory/keto/storage" ) @@ -76,7 +75,7 @@ func (m *RegistrySQL) StorageManager() storage.Manager { } func (m *RegistrySQL) CanHandle(dsn string) bool { - s := dbal.Canonicalize(urlx.ParseOrFatal(m.l, dsn).Scheme) + s := sqlcon.GetDriverName(dsn) return s == dbal.DriverMySQL || s == dbal.DriverPostgreSQL } diff --git a/driver/registry_sql_test.go b/driver/registry_sql_test.go new file mode 100644 index 000000000..662e907ae --- /dev/null +++ b/driver/registry_sql_test.go @@ -0,0 +1,24 @@ +package driver + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRegistrySQL_CanHandle(t *testing.T) { + for k, tc := range []struct { + dsn string + expected bool + }{ + {dsn: "memory"}, + {dsn: "mysql://foo:bar@tcp(baz:1234)/db?foo=bar", expected: true}, + {dsn: "postgres://foo:bar@baz:1234/db?foo=bar", expected: true}, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + r := RegistrySQL{} + assert.Equal(t, tc.expected, r.CanHandle(tc.dsn)) + }) + } +}