Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

go/vt/vitessdriver: implement driver.{Connector,DriverContext} #13704

Merged
merged 1 commit into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions go/vt/vitessdriver/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,16 @@ func (cv *converter) bindVarsFromNamedValues(args []driver.NamedValue) (map[stri
return bindVars, nil
}

func newConverter(cfg *Configuration) (c *converter, err error) {
c = &converter{
location: time.UTC,
func newConverter(cfg *Configuration) (*converter, error) {
mdlayher marked this conversation as resolved.
Show resolved Hide resolved
c := &converter{location: time.UTC}
if cfg.DefaultLocation == "" {
return c, nil
}
if cfg.DefaultLocation != "" {
c.location, err = time.LoadLocation(cfg.DefaultLocation)

loc, err := time.LoadLocation(cfg.DefaultLocation)
if err != nil {
return nil, err
}
return
c.location = loc
return c, nil
}
111 changes: 85 additions & 26 deletions go/vt/vitessdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,30 @@ var (

// Type-check interfaces.
var (
_ driver.QueryerContext = &conn{}
_ driver.ExecerContext = &conn{}
_ driver.StmtQueryContext = &stmt{}
_ driver.StmtExecContext = &stmt{}
_ interface {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We now exhaustively enumerate the optional interfaces implemented by these types since it's confusing to tell what exactly is being implemented in the driver package.

driver.Connector
} = &connector{}

_ interface {
driver.Driver
driver.DriverContext
} = drv{}

_ interface {
driver.Conn
driver.ConnBeginTx
driver.ConnPrepareContext
driver.ExecerContext
driver.Pinger
driver.QueryerContext
driver.Tx
} = &conn{}

_ interface {
driver.Stmt
driver.StmtExecContext
driver.StmtQueryContext
} = &stmt{}
)

func init() {
Expand Down Expand Up @@ -94,8 +114,7 @@ func OpenWithConfiguration(c Configuration) (*sql.DB, error) {
return sql.Open(c.DriverName, json)
}

type drv struct {
}
type drv struct{}

// Open implements the database/sql/driver.Driver interface.
//
Expand All @@ -112,25 +131,65 @@ type drv struct {
//
// For a description of the available fields, see the Configuration struct.
func (d drv) Open(name string) (driver.Conn, error) {
c := &conn{}
err := json.Unmarshal([]byte(name), c)
conn, err := d.OpenConnector(name)
if err != nil {
return nil, err
}

c.setDefaults()
return conn.Connect(context.Background())
}

if c.convert, err = newConverter(&c.Configuration); err != nil {
// OpenConnector implements the database/sql/driver.DriverContext interface.
//
// See the documentation of Open for details on the format of name.
func (d drv) OpenConnector(name string) (driver.Connector, error) {
var cfg Configuration
if err := json.Unmarshal([]byte(name), &cfg); err != nil {
return nil, err
}

if err = c.dial(); err != nil {
cfg.setDefaults()
return d.newConnector(cfg)
}

// A connector holds immutable state for the creation of additional conns via
// the Connect method.
type connector struct {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new type is the key; it stores immutable config/convert/driver state and then enables any number of new connections to be dialed from it, along with added context support.

drv drv
cfg Configuration
convert *converter
}

func (d drv) newConnector(cfg Configuration) (driver.Connector, error) {
convert, err := newConverter(&cfg)
if err != nil {
return nil, err
}

return c, nil
return &connector{
drv: d,
cfg: cfg,
convert: convert,
}, nil
}

// Connect implements the database/sql/driver.Connector interface.
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
conn := &conn{
cfg: c.cfg,
convert: c.convert,
}

if err := conn.dial(ctx); err != nil {
return nil, err
}

return conn, nil
}

// Driver implements the database/sql/driver.Connector interface.
func (c *connector) Driver() driver.Driver { return c.drv }

// Configuration holds all Vitess driver settings.
//
// Fields with documented default values do not have to be set explicitly.
Expand Down Expand Up @@ -202,32 +261,32 @@ func (c *Configuration) setDefaults() {
}

type conn struct {
Configuration
cfg Configuration
mdlayher marked this conversation as resolved.
Show resolved Hide resolved
convert *converter
conn *vtgateconn.VTGateConn
session *vtgateconn.VTGateSession
}

func (c *conn) dial() error {
func (c *conn) dial(ctx context.Context) error {
var err error
c.conn, err = vtgateconn.DialProtocol(context.Background(), c.Protocol, c.Address)
c.conn, err = vtgateconn.DialProtocol(ctx, c.cfg.Protocol, c.cfg.Address)
if err != nil {
return err
}
if c.Configuration.SessionToken != "" {
sessionFromToken, err := sessionTokenToSession(c.Configuration.SessionToken)
if c.cfg.SessionToken != "" {
sessionFromToken, err := sessionTokenToSession(c.cfg.SessionToken)
if err != nil {
return err
}
c.session = c.conn.SessionFromPb(sessionFromToken)
} else {
c.session = c.conn.Session(c.Target, nil)
c.session = c.conn.Session(c.cfg.Target, nil)
}
return nil
}

func (c *conn) Ping(ctx context.Context) error {
if c.Streaming {
if c.cfg.Streaming {
return errors.New("Ping not allowed for streaming connections")
}

Expand Down Expand Up @@ -378,7 +437,7 @@ func sessionTokenToSession(sessionToken string) (*vtgatepb.Session, error) {

func (c *conn) Begin() (driver.Tx, error) {
// if we're loading from an existing session, we need to avoid starting a new transaction
if c.Configuration.SessionToken != "" {
if c.cfg.SessionToken != "" {
return c, nil
}

Expand All @@ -401,7 +460,7 @@ func (c *conn) Commit() error {
// if we're loading from an existing session, disallow committing/rolling back the transaction
// this isn't a technical limitation, but is enforced to prevent misuse, so that only
// the original creator of the transaction can commit/rollback
if c.Configuration.SessionToken != "" {
if c.cfg.SessionToken != "" {
return errors.New("calling Commit from a distributed tx is not allowed")
}

Expand All @@ -413,7 +472,7 @@ func (c *conn) Rollback() error {
// if we're loading from an existing session, disallow committing/rolling back the transaction
// this isn't a technical limitation, but is enforced to prevent misuse, so that only
// the original creator of the transaction can commit/rollback
if c.Configuration.SessionToken != "" {
if c.cfg.SessionToken != "" {
return errors.New("calling Rollback from a distributed tx is not allowed")
}

Expand All @@ -424,7 +483,7 @@ func (c *conn) Rollback() error {
func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
ctx := context.TODO()

if c.Streaming {
if c.cfg.Streaming {
return nil, errors.New("Exec not allowed for streaming connections")
}
bindVars, err := c.convert.buildBindVars(args)
Expand All @@ -440,7 +499,7 @@ func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if c.Streaming {
if c.cfg.Streaming {
return nil, errors.New("Exec not allowed for streaming connections")
}

Expand All @@ -462,7 +521,7 @@ func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
return nil, err
}

if c.Streaming {
if c.cfg.Streaming {
stream, err := c.session.StreamExecute(ctx, query, bindVars)
if err != nil {
return nil, err
Expand All @@ -488,7 +547,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return nil, err
}

if c.Streaming {
if c.cfg.Streaming {
stream, err := c.session.StreamExecute(ctx, query, bv)
if err != nil {
return nil, err
Expand Down
Loading
Loading