Skip to content

Commit

Permalink
Making high level API work properly with azuresql URLs
Browse files Browse the repository at this point in the history
  • Loading branch information
kenshaw committed Aug 16, 2023
1 parent d794b30 commit 34fab66
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 85 deletions.
20 changes: 14 additions & 6 deletions dburl.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ func Open(urlstr string) (*sql.DB, error) {
if err != nil {
return nil, err
}
return sql.Open(u.Driver, u.DSN)
driver := u.Driver
if u.GoDriver != "" {
driver = u.GoDriver
}
return sql.Open(driver, u.DSN)
}

// URL wraps the standard [net/url.URL] type, adding OriginalScheme, Transport,
Expand All @@ -40,8 +44,12 @@ type URL struct {
// Driver is the non-aliased SQL driver name that should be used in a call
// to sql/Open.
Driver string
// Unaliased is the unaliased driver name.
Unaliased string
// GoDriver is the Go SQL driver name to use when opening a connection to
// the database. Used by Microsoft SQL Server's azuresql URLs, as the
// wire-compatible alias style uses a different syntax style.
GoDriver string
// UnaliasedDriver is the unaliased driver name.
UnaliasedDriver string
// DSN is the built connection "data source name" that can be used in a
// call to sql/Open.
DSN string
Expand Down Expand Up @@ -123,12 +131,12 @@ func Parse(urlstr string) (*URL, error) {
}
}
// set driver
u.Driver, u.Unaliased = scheme.Driver, scheme.Driver
u.Driver, u.UnaliasedDriver = scheme.Driver, scheme.Driver
if scheme.Override != "" {
u.Driver = scheme.Override
}
// generate dsn
if u.DSN, err = scheme.Generator(u); err != nil {
if u.DSN, u.GoDriver, err = scheme.Generator(u); err != nil {
return nil, err
}
return u, nil
Expand Down Expand Up @@ -185,7 +193,7 @@ func (u *URL) Short() string {
// Normalize returns the driver, host, port, database, and user name of a URL,
// joined with sep, populating blank fields with empty.
func (u *URL) Normalize(sep, empty string, cut int) string {
s := []string{u.Unaliased, "", "", "", ""}
s := []string{u.UnaliasedDriver, "", "", "", ""}
if u.Transport != "tcp" && u.Transport != "unix" {
s[0] += "+" + u.Transport
}
Expand Down
12 changes: 7 additions & 5 deletions dburl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ func TestParse(t *testing.T) {
{`mssql://user:pass@localhost/dbname`, `sqlserver`, `sqlserver://user:pass@localhost/?database=dbname`, ``},
{`mssql://user@localhost/service/dbname`, `sqlserver`, `sqlserver://user@localhost/service?database=dbname`, ``},
{`mssql://user:!234%23$@localhost:1580/dbname`, `sqlserver`, `sqlserver://user:%21234%23$@localhost:1580/?database=dbname`, ``},
{`mssql://user:!234%23$@localhost:1580/service/dbname?fedauth=true`, `sqlserver`, `azuresql://user:%21234%23$@localhost:1580/service?database=dbname&fedauth=true`, ``},
{`azuresql://user:pass@localhost:100/dbname`, `sqlserver`, `azuresql://user:pass@localhost:100/?database=dbname`, ``},
{`sqlserver://xxx.database.windows.net?database=xxx&fedauth=ActiveDirectoryMSI`, `sqlserver`, `azuresql://xxx.database.windows.net?database=xxx&fedauth=ActiveDirectoryMSI`, ``},
{`azuresql://xxx.database.windows.net/dbname?fedauth=ActiveDirectoryMSI`, `sqlserver`, `azuresql://xxx.database.windows.net/?database=dbname&fedauth=ActiveDirectoryMSI`, ``},
{`mssql://user:!234%23$@localhost:1580/service/dbname?fedauth=true`, `azuresql`, `sqlserver://user:%21234%23$@localhost:1580/service?database=dbname&fedauth=true`, ``},
{`azuresql://user:pass@localhost:100/dbname`, `azuresql`, `sqlserver://user:pass@localhost:100/?database=dbname`, ``},
{`sqlserver://xxx.database.windows.net?database=xxx&fedauth=ActiveDirectoryMSI`, `azuresql`, `sqlserver://xxx.database.windows.net?database=xxx&fedauth=ActiveDirectoryMSI`, ``},
{`azuresql://xxx.database.windows.net/dbname?fedauth=ActiveDirectoryMSI`, `azuresql`, `sqlserver://xxx.database.windows.net/?database=dbname&fedauth=ActiveDirectoryMSI`, ``},
{
`adodb://Microsoft.ACE.OLEDB.12.0?Extended+Properties=%22Text%3BHDR%3DNO%3BFMT%3DDelimited%22`, `adodb`, // 30
`Data Source=.;Extended Properties="Text;HDR=NO;FMT=Delimited";Provider=Microsoft.ACE.OLEDB.12.0`, ``,
Expand Down Expand Up @@ -214,7 +214,9 @@ func TestParse(t *testing.T) {
switch {
case err != nil:
t.Fatalf("test %d expected no error, got: %v", i, err)
case u.Driver != test.d:
case u.GoDriver != "" && u.GoDriver != test.d:
t.Errorf("test %d expected go driver %q, got: %q", i, test.d, u.GoDriver)
case u.GoDriver == "" && u.Driver != test.d:
t.Errorf("test %d expected driver %q, got: %q", i, test.d, u.Driver)
case u.DSN != test.exp:
_, err := os.Stat(test.path)
Expand Down
Loading

0 comments on commit 34fab66

Please sign in to comment.