Skip to content

Commit

Permalink
dial: add DialContext function
Browse files Browse the repository at this point in the history
In order to replace timeouts with contexts in `Connect` instance
creation (go-tarantool), I need a `DialContext` function.
It accepts context, and cancels, if context is canceled by user.

Part of tarantool/go-tarantool#136
  • Loading branch information
DerekBum committed Sep 28, 2023
1 parent b452431 commit bdadb81
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package openssl

import (
"context"
"errors"
"net"
"time"
Expand Down Expand Up @@ -90,7 +91,19 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
flags DialFlags) (*Conn, error) {
d := net.Dialer{Timeout: timeout}
return dialSession(d, network, addr, ctx, flags, nil)
return dialSession(d, context.Background(), network, addr, ctx, flags, nil)
}

// DialContext acts like Dial but takes a context for network dial.
//
// The context includes only network dial. It does not include OpenSSL calls.
//
// See func Dial for a description of the network, addr, ctx and flags
// parameters.
func DialContext(context context.Context, network, addr string,
ctx *Ctx, flags DialFlags) (*Conn, error) {
d := net.Dialer{}
return dialSession(d, context, network, addr, ctx, flags, nil)
}

// DialSession will connect to network/address and then wrap the corresponding
Expand All @@ -109,11 +122,11 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
func DialSession(network, addr string, ctx *Ctx, flags DialFlags,
session []byte) (*Conn, error) {
var d net.Dialer
return dialSession(d, network, addr, ctx, flags, session)
return dialSession(d, context.Background(), network, addr, ctx, flags, session)
}

func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
session []byte) (*Conn, error) {
func dialSession(d net.Dialer, context context.Context, network, addr string,
ctx *Ctx, flags DialFlags, session []byte) (*Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
Expand All @@ -127,7 +140,7 @@ func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
// TODO: use operating system default certificate chain?
}

c, err := d.Dial(network, addr)
c, err := d.DialContext(context, network, addr)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit bdadb81

Please sign in to comment.