Skip to content

Commit

Permalink
feat(pgdriver): add support for unix socket DSN
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Oct 17, 2021
1 parent 5058064 commit f398cec
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 30 deletions.
47 changes: 33 additions & 14 deletions driver/pgdriver/config.go
Expand Up @@ -187,30 +187,49 @@ func parseDSN(dsn string) ([]Option, error) {
return nil, err
}

if u.Scheme != "postgres" && u.Scheme != "postgresql" {
return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme)
}

q := queryOptions{q: u.Query()}
var opts []Option

if u.Host != "" {
addr := u.Host
if !strings.Contains(addr, ":") {
addr += ":5432"
switch u.Scheme {
case "postgres", "postgresql":
if u.Host != "" {
addr := u.Host
if !strings.Contains(addr, ":") {
addr += ":5432"
}
opts = append(opts, WithAddr(addr))
}

if len(u.Path) > 1 {
opts = append(opts, WithDatabase(u.Path[1:]))
}

if host := q.string("host"); host != "" {
opts = append(opts, WithAddr(host))
if host[0] == '/' {
opts = append(opts, WithNetwork("unix"))
}
}
case "unix":
if len(u.Path) == 0 {
return nil, fmt.Errorf("unix socket DSN requires a path: %s", dsn)
}
opts = append(opts, WithAddr(addr))

opts = append(opts, WithNetwork("unix"))
if u.Host != "" {
opts = append(opts, WithDatabase(u.Host))
}
opts = append(opts, WithAddr(u.Path))
default:
return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme)
}

if u.User != nil {
opts = append(opts, WithUser(u.User.Username()))
if password, ok := u.User.Password(); ok {
opts = append(opts, WithPassword(password))
}
}
if len(u.Path) > 1 {
opts = append(opts, WithDatabase(u.Path[1:]))
}

q := queryOptions{q: u.Query()}

if appName := q.string("application_name"); appName != "" {
opts = append(opts, WithApplicationName(appName))
Expand Down
61 changes: 45 additions & 16 deletions driver/pgdriver/config_test.go
@@ -1,6 +1,7 @@
package pgdriver_test

import (
"fmt"
"testing"
"time"

Expand All @@ -16,38 +17,38 @@ func TestParseDSN(t *testing.T) {

tests := []Test{
{
dsn: "postgres://postgres:1@localhost:5432/testDatabase?sslmode=disable",
dsn: "postgres://user:password@localhost:5432/testDatabase?sslmode=disable",
cfg: &pgdriver.Config{
Network: "tcp",
Addr: "localhost:5432",
User: "postgres",
Password: "1",
User: "user",
Password: "password",
Database: "testDatabase",
DialTimeout: 5 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 5 * time.Second,
},
},
{
dsn: "postgres://postgres:1@localhost:5432/testDatabase?sslmode=disable&dial_timeout=1&read_timeout=2s&write_timeout=3",
dsn: "postgres://user:password@localhost:5432/testDatabase?sslmode=disable&dial_timeout=1&read_timeout=2s&write_timeout=3",
cfg: &pgdriver.Config{
Network: "tcp",
Addr: "localhost:5432",
User: "postgres",
Password: "1",
User: "user",
Password: "password",
Database: "testDatabase",
DialTimeout: 1 * time.Second,
ReadTimeout: 2 * time.Second,
WriteTimeout: 3 * time.Second,
},
},
{
dsn: "postgres://postgres:1@localhost:5432/testDatabase?search_path=foo",
dsn: "postgres://user:password@localhost:5432/testDatabase?search_path=foo",
cfg: &pgdriver.Config{
Network: "tcp",
Addr: "localhost:5432",
User: "postgres",
Password: "1",
User: "user",
Password: "password",
Database: "testDatabase",
ConnParams: map[string]interface{}{
"search_path": "foo",
Expand All @@ -58,26 +59,54 @@ func TestParseDSN(t *testing.T) {
},
},
{
dsn: "postgres://postgres:password@app.xxx.us-east-1.rds.amazonaws.com:5432/test?sslmode=disable",
dsn: "postgres://user:password@app.xxx.us-east-1.rds.amazonaws.com:5432/test?sslmode=disable",
cfg: &pgdriver.Config{
Network: "tcp",
Addr: "app.xxx.us-east-1.rds.amazonaws.com:5432",
User: "postgres",
User: "user",
Password: "password",
Database: "test",
DialTimeout: 5 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 5 * time.Second,
},
},
{
dsn: "postgres://user:password@/dbname?host=/var/run/postgresql/.s.PGSQL.5432",
cfg: &pgdriver.Config{
Network: "unix",
Addr: "/var/run/postgresql/.s.PGSQL.5432",
User: "user",
Password: "password",
Database: "dbname",
DialTimeout: 5 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 5 * time.Second,
},
},
{
dsn: "unix://user:pass@dbname/var/run/postgresql/.s.PGSQL.5432",
cfg: &pgdriver.Config{
Network: "unix",
Addr: "/var/run/postgresql/.s.PGSQL.5432",
User: "user",
Password: "pass",
Database: "dbname",
DialTimeout: 5 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 5 * time.Second,
},
},
}

for _, test := range tests {
c := pgdriver.NewConnector(pgdriver.WithDSN(test.dsn))
for i, test := range tests {
t.Run(fmt.Sprint(i), func(t *testing.T) {
c := pgdriver.NewConnector(pgdriver.WithDSN(test.dsn))

cfg := c.Config()
cfg.Dialer = nil
cfg := c.Config()
cfg.Dialer = nil

require.Equal(t, test.cfg, cfg)
require.Equal(t, test.cfg, cfg)
})
}
}

0 comments on commit f398cec

Please sign in to comment.