Skip to content

Commit 4aa583b

Browse files
author
Edward Muller
committed
Updates based on PR feedback
1 parent 70eddf6 commit 4aa583b

File tree

2 files changed

+21
-26
lines changed

2 files changed

+21
-26
lines changed

options.go

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ type Options struct {
6464
// Enables read only queries on slave nodes.
6565
ReadOnly bool
6666

67-
// Config to use when connecting via TLS
67+
// TLS Config to use. When set TLS will be negotiated.
6868
TLSConfig *tls.Config
6969
}
7070

@@ -74,7 +74,12 @@ func (opt *Options) init() {
7474
}
7575
if opt.Dialer == nil {
7676
opt.Dialer = func() (net.Conn, error) {
77-
return net.DialTimeout(opt.Network, opt.Addr, opt.DialTimeout)
77+
conn, err := net.DialTimeout(opt.Network, opt.Addr, opt.DialTimeout)
78+
if opt.TLSConfig == nil || err != nil {
79+
return conn, err
80+
}
81+
t := tls.Client(conn, opt.TLSConfig)
82+
return t, t.Handshake()
7883
}
7984
}
8085
if opt.PoolSize == 0 {
@@ -142,24 +147,14 @@ func ParseURL(redisURL string) (*Options, error) {
142147
o.DB = 0
143148
case 1:
144149
if o.DB, err = strconv.Atoi(f[0]); err != nil {
145-
return nil, errors.New("Invalid redis database number: " + err.Error())
150+
return nil, errors.New("invalid redis database number: " + err.Error())
146151
}
147152
default:
148153
return nil, errors.New("invalid redis URL path: " + u.Path)
149154
}
150155

151156
if u.Scheme == "rediss" {
152-
o.Dialer = func() (net.Conn, error) {
153-
conn, err := net.DialTimeout(o.Network, o.Addr, o.DialTimeout)
154-
if err != nil {
155-
return nil, err
156-
}
157-
if o.TLSConfig == nil {
158-
o.TLSConfig = &tls.Config{InsecureSkipVerify: true}
159-
}
160-
t := tls.Client(conn, o.TLSConfig)
161-
return t, t.Handshake()
162-
}
157+
o.TLSConfig = &tls.Config{InsecureSkipVerify: true}
163158
}
164159
return o, nil
165160
}

options_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ import (
99

1010
func TestParseURL(t *testing.T) {
1111
cases := []struct {
12-
u string
13-
addr string
14-
db int
15-
dialer bool
16-
err error
12+
u string
13+
addr string
14+
db int
15+
tls bool
16+
err error
1717
}{
1818
{
1919
"redis://localhost:123/1",
@@ -63,31 +63,31 @@ func TestParseURL(t *testing.T) {
6363
{
6464
"redis://localhost/iamadatabase",
6565
"",
66-
0, false, errors.New("Invalid redis database number: strconv.ParseInt: parsing \"iamadatabase\": invalid syntax"),
66+
0, false, errors.New("invalid redis database number: strconv.ParseInt: parsing \"iamadatabase\": invalid syntax"),
6767
},
6868
}
6969

7070
for _, c := range cases {
7171
t.Run(c.u, func(t *testing.T) {
7272
o, err := ParseURL(c.u)
7373
if c.err == nil && err != nil {
74-
t.Fatalf("Expected err to be nil, but got: '%q'", err)
74+
t.Fatalf("unexpected error: '%q'", err)
7575
return
7676
}
7777
if c.err != nil && err != nil {
7878
if c.err.Error() != err.Error() {
79-
t.Fatalf("Expected err to be '%q', but got '%q'", c.err, err)
79+
t.Fatalf("got %q, expected %q", err, c.err)
8080
}
8181
return
8282
}
8383
if o.Addr != c.addr {
84-
t.Errorf("Expected Addr to be '%s', but got '%s'", c.addr, o.Addr)
84+
t.Errorf("got %q, want %q", o.Addr, c.addr)
8585
}
8686
if o.DB != c.db {
87-
t.Errorf("Expecdted DB to be '%d', but got '%d'", c.db, o.DB)
87+
t.Errorf("got %q, expected %q", o.DB, c.db)
8888
}
89-
if c.dialer && o.Dialer == nil {
90-
t.Errorf("Expected a Dialer to be set, but isn't")
89+
if c.tls && o.TLSConfig == nil {
90+
t.Errorf("got nil TLSConfig, expected a TLSConfig")
9191
}
9292
})
9393
}

0 commit comments

Comments
 (0)