Skip to content

Commit

Permalink
feat(pgdriver): allow specifying timeout in DSN
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Sep 25, 2021
1 parent afc02cc commit 7dbc71b
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 38 deletions.
130 changes: 96 additions & 34 deletions driver/pgdriver/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"net"
"net/url"
"os"
"sort"
"strconv"
"strings"
"time"
)
Expand All @@ -33,11 +35,9 @@ type Config struct {
Database string
AppName string

// Timeout for socket reads. If reached, commands will fail
// with a timeout instead of blocking.
// Timeout for socket reads. If reached, commands fail with a timeout instead of blocking.
ReadTimeout time.Duration
// Timeout for socket writes. If reached, commands will fail
// with a timeout instead of blocking.
// Timeout for socket writes. If reached, commands fail with a timeout instead of blocking.
WriteTimeout time.Duration
}

Expand Down Expand Up @@ -153,6 +153,15 @@ func WithDSN(dsn string) DriverOption {
}
}

func env(key, defValue string) string {
if s := os.Getenv(key); s != "" {
return s
}
return defValue
}

//------------------------------------------------------------------------------

func parseDSN(dsn string) ([]DriverOption, error) {
u, err := url.Parse(dsn)
if err != nil {
Expand All @@ -163,11 +172,6 @@ func parseDSN(dsn string) ([]DriverOption, error) {
return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme)
}

query, err := url.ParseQuery(u.RawQuery)
if err != nil {
return nil, err
}

var opts []DriverOption

if u.Host != "" {
Expand All @@ -187,39 +191,45 @@ func parseDSN(dsn string) ([]DriverOption, error) {
opts = append(opts, WithDatabase(u.Path[1:]))
}

if appName := query.Get("application_name"); appName != "" {
q := queryOptions{q: u.Query()}

if appName := q.string("application_name"); appName != "" {
opts = append(opts, WithApplicationName(appName))
}
delete(query, "application_name")

if sslMode := query.Get("sslmode"); sslMode != "" {
switch sslMode {
case "verify-ca", "verify-full":
opts = append(opts, WithTLSConfig(new(tls.Config)))
case "allow", "prefer", "require":
opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
case "disable":
// no TLS config
default:
return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode)
}
} else {

switch sslMode := q.string("sslmode"); sslMode {
case "verify-ca", "verify-full":
opts = append(opts, WithTLSConfig(new(tls.Config)))
case "allow", "prefer", "require":
opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
case "disable", "":
// no TLS config
default:
return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode)
}
delete(query, "sslmode")

for key := range query {
return nil, fmt.Errorf("pgdriver: unsupported option=%q", key)
if d := q.duration("timeout"); d != 0 {
opts = append(opts, WithTimeout(d))
}
if d := q.duration("dial_timeout"); d != 0 {
opts = append(opts, WithDialTimeout(d))
}
if d := q.duration("read_timeout"); d != 0 {
opts = append(opts, WithReadTimeout(d))
}
if d := q.duration("write_timeout"); d != 0 {
opts = append(opts, WithWriteTimeout(d))
}

return opts, nil
}

func env(key, defValue string) string {
if s := os.Getenv(key); s != "" {
return s
rem, err := q.remaining()
if err != nil {
return nil, q.err
}
return defValue
if len(rem) > 0 {
return nil, fmt.Errorf("pgdriver: unexpected option: %s", strings.Join(rem, ", "))
}

return opts, nil
}

// verify is a method to make sure if the config is legitimate
Expand All @@ -231,3 +241,55 @@ func (c *Config) verify() error {
}
return nil
}

type queryOptions struct {
q url.Values
err error
}

func (o *queryOptions) string(name string) string {
vs := o.q[name]
if len(vs) == 0 {
return ""
}
delete(o.q, name) // enable detection of unknown parameters
return vs[len(vs)-1]
}

func (o *queryOptions) duration(name string) time.Duration {
s := o.string(name)
if s == "" {
return 0
}
// try plain number first
if i, err := strconv.Atoi(s); err == nil {
if i <= 0 {
// disable timeouts
return -1
}
return time.Duration(i) * time.Second
}
dur, err := time.ParseDuration(s)
if err == nil {
return dur
}
if o.err == nil {
o.err = fmt.Errorf("pgdriver: invalid %s duration: %w", name, err)
}
return 0
}

func (o *queryOptions) remaining() ([]string, error) {
if o.err != nil {
return nil, o.err
}
if len(o.q) == 0 {
return nil, nil
}
keys := make([]string, 0, len(o.q))
for k := range o.q {
keys = append(keys, k)
}
sort.Strings(keys)
return keys, nil
}
13 changes: 13 additions & 0 deletions driver/pgdriver/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ func TestParseDSN(t *testing.T) {
WriteTimeout: 5 * time.Second,
},
},
{
dsn: "postgres://postgres:1@localhost:5432/testDatabase?sslmode=disable&dial_timeout=1&read_timeout=2s&write_timeout=3",
cfg: &pgdriver.Config{
Network: "tcp",
Addr: "localhost:5432",
User: "postgres",
Password: "1",
Database: "testDatabase",
DialTimeout: 1 * time.Second,
ReadTimeout: 2 * time.Second,
WriteTimeout: 3 * time.Second,
},
},
{
dsn: "postgres://postgres:password@app.xxx.us-east-1.rds.amazonaws.com:5432/test?sslmode=disable",
cfg: &pgdriver.Config{
Expand Down
6 changes: 2 additions & 4 deletions example/pg-faceted-search/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,15 @@ func main() {
panic(err)
}

fmt.Println("\n")
fmt.Println("all facets:\n")
fmt.Printf("\n\nall facets:\n\n")
spew.Dump(facets)

facets, err = selectFacets(ctx, db, "moods:mysterious")
if err != nil {
panic(err)
}

fmt.Println("\n")
fmt.Println("moods:mysterious facets:\n")
fmt.Printf("\n\nmoods:mysterious facets:\n\n")
spew.Dump(facets)
}

Expand Down

0 comments on commit 7dbc71b

Please sign in to comment.