diff --git a/go/vt/vitessdriver/convert.go b/go/vt/vitessdriver/convert.go index abb25beb000..0bf8c2240d5 100644 --- a/go/vt/vitessdriver/convert.go +++ b/go/vt/vitessdriver/convert.go @@ -109,12 +109,16 @@ func (cv *converter) bindVarsFromNamedValues(args []driver.NamedValue) (map[stri return bindVars, nil } -func newConverter(cfg *Configuration) (c *converter, err error) { - c = &converter{ - location: time.UTC, +func newConverter(cfg *Configuration) (*converter, error) { + c := &converter{location: time.UTC} + if cfg.DefaultLocation == "" { + return c, nil } - if cfg.DefaultLocation != "" { - c.location, err = time.LoadLocation(cfg.DefaultLocation) + + loc, err := time.LoadLocation(cfg.DefaultLocation) + if err != nil { + return nil, err } - return + c.location = loc + return c, nil } diff --git a/go/vt/vitessdriver/driver.go b/go/vt/vitessdriver/driver.go index 638e31523f3..4a965399e9c 100644 --- a/go/vt/vitessdriver/driver.go +++ b/go/vt/vitessdriver/driver.go @@ -41,10 +41,30 @@ var ( // Type-check interfaces. var ( - _ driver.QueryerContext = &conn{} - _ driver.ExecerContext = &conn{} - _ driver.StmtQueryContext = &stmt{} - _ driver.StmtExecContext = &stmt{} + _ interface { + driver.Connector + } = &connector{} + + _ interface { + driver.Driver + driver.DriverContext + } = drv{} + + _ interface { + driver.Conn + driver.ConnBeginTx + driver.ConnPrepareContext + driver.ExecerContext + driver.Pinger + driver.QueryerContext + driver.Tx + } = &conn{} + + _ interface { + driver.Stmt + driver.StmtExecContext + driver.StmtQueryContext + } = &stmt{} ) func init() { @@ -94,8 +114,7 @@ func OpenWithConfiguration(c Configuration) (*sql.DB, error) { return sql.Open(c.DriverName, json) } -type drv struct { -} +type drv struct{} // Open implements the database/sql/driver.Driver interface. // @@ -112,25 +131,65 @@ type drv struct { // // For a description of the available fields, see the Configuration struct. func (d drv) Open(name string) (driver.Conn, error) { - c := &conn{} - err := json.Unmarshal([]byte(name), c) + conn, err := d.OpenConnector(name) if err != nil { return nil, err } - c.setDefaults() + return conn.Connect(context.Background()) +} - if c.convert, err = newConverter(&c.Configuration); err != nil { +// OpenConnector implements the database/sql/driver.DriverContext interface. +// +// See the documentation of Open for details on the format of name. +func (d drv) OpenConnector(name string) (driver.Connector, error) { + var cfg Configuration + if err := json.Unmarshal([]byte(name), &cfg); err != nil { return nil, err } - if err = c.dial(); err != nil { + cfg.setDefaults() + return d.newConnector(cfg) +} + +// A connector holds immutable state for the creation of additional conns via +// the Connect method. +type connector struct { + drv drv + cfg Configuration + convert *converter +} + +func (d drv) newConnector(cfg Configuration) (driver.Connector, error) { + convert, err := newConverter(&cfg) + if err != nil { return nil, err } - return c, nil + return &connector{ + drv: d, + cfg: cfg, + convert: convert, + }, nil } +// Connect implements the database/sql/driver.Connector interface. +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + conn := &conn{ + cfg: c.cfg, + convert: c.convert, + } + + if err := conn.dial(ctx); err != nil { + return nil, err + } + + return conn, nil +} + +// Driver implements the database/sql/driver.Connector interface. +func (c *connector) Driver() driver.Driver { return c.drv } + // Configuration holds all Vitess driver settings. // // Fields with documented default values do not have to be set explicitly. @@ -202,32 +261,32 @@ func (c *Configuration) setDefaults() { } type conn struct { - Configuration + cfg Configuration convert *converter conn *vtgateconn.VTGateConn session *vtgateconn.VTGateSession } -func (c *conn) dial() error { +func (c *conn) dial(ctx context.Context) error { var err error - c.conn, err = vtgateconn.DialProtocol(context.Background(), c.Protocol, c.Address) + c.conn, err = vtgateconn.DialProtocol(ctx, c.cfg.Protocol, c.cfg.Address) if err != nil { return err } - if c.Configuration.SessionToken != "" { - sessionFromToken, err := sessionTokenToSession(c.Configuration.SessionToken) + if c.cfg.SessionToken != "" { + sessionFromToken, err := sessionTokenToSession(c.cfg.SessionToken) if err != nil { return err } c.session = c.conn.SessionFromPb(sessionFromToken) } else { - c.session = c.conn.Session(c.Target, nil) + c.session = c.conn.Session(c.cfg.Target, nil) } return nil } func (c *conn) Ping(ctx context.Context) error { - if c.Streaming { + if c.cfg.Streaming { return errors.New("Ping not allowed for streaming connections") } @@ -378,7 +437,7 @@ func sessionTokenToSession(sessionToken string) (*vtgatepb.Session, error) { func (c *conn) Begin() (driver.Tx, error) { // if we're loading from an existing session, we need to avoid starting a new transaction - if c.Configuration.SessionToken != "" { + if c.cfg.SessionToken != "" { return c, nil } @@ -401,7 +460,7 @@ func (c *conn) Commit() error { // if we're loading from an existing session, disallow committing/rolling back the transaction // this isn't a technical limitation, but is enforced to prevent misuse, so that only // the original creator of the transaction can commit/rollback - if c.Configuration.SessionToken != "" { + if c.cfg.SessionToken != "" { return errors.New("calling Commit from a distributed tx is not allowed") } @@ -413,7 +472,7 @@ func (c *conn) Rollback() error { // if we're loading from an existing session, disallow committing/rolling back the transaction // this isn't a technical limitation, but is enforced to prevent misuse, so that only // the original creator of the transaction can commit/rollback - if c.Configuration.SessionToken != "" { + if c.cfg.SessionToken != "" { return errors.New("calling Rollback from a distributed tx is not allowed") } @@ -424,7 +483,7 @@ func (c *conn) Rollback() error { func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { ctx := context.TODO() - if c.Streaming { + if c.cfg.Streaming { return nil, errors.New("Exec not allowed for streaming connections") } bindVars, err := c.convert.buildBindVars(args) @@ -440,7 +499,7 @@ func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { } func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - if c.Streaming { + if c.cfg.Streaming { return nil, errors.New("Exec not allowed for streaming connections") } @@ -462,7 +521,7 @@ func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { return nil, err } - if c.Streaming { + if c.cfg.Streaming { stream, err := c.session.StreamExecute(ctx, query, bindVars) if err != nil { return nil, err @@ -488,7 +547,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam return nil, err } - if c.Streaming { + if c.cfg.Streaming { stream, err := c.session.StreamExecute(ctx, query, bv) if err != nil { return nil, err diff --git a/go/vt/vitessdriver/driver_test.go b/go/vt/vitessdriver/driver_test.go index b1bdd2a833f..bd49a0acd0a 100644 --- a/go/vt/vitessdriver/driver_test.go +++ b/go/vt/vitessdriver/driver_test.go @@ -38,9 +38,7 @@ import ( "vitess.io/vitess/go/vt/vtgate/grpcvtgateservice" ) -var ( - testAddress string -) +var testAddress string // TestMain tests the Vitess Go SQL driver. // @@ -71,7 +69,7 @@ func TestOpen(t *testing.T) { panic(err) } - var testcases = []struct { + testcases := []struct { desc string connStr string conn *conn @@ -80,7 +78,7 @@ func TestOpen(t *testing.T) { desc: "Open()", connStr: fmt.Sprintf(`{"address": "%s", "target": "@replica", "timeout": %d}`, testAddress, int64(30*time.Second)), conn: &conn{ - Configuration: Configuration{ + cfg: Configuration{ Protocol: "grpc", DriverName: "vitess", Target: "@replica", @@ -94,7 +92,7 @@ func TestOpen(t *testing.T) { desc: "Open() (defaults omitted)", connStr: fmt.Sprintf(`{"address": "%s", "timeout": %d}`, testAddress, int64(30*time.Second)), conn: &conn{ - Configuration: Configuration{ + cfg: Configuration{ Protocol: "grpc", DriverName: "vitess", }, @@ -107,7 +105,7 @@ func TestOpen(t *testing.T) { desc: "Open() with keyspace", connStr: fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "target": "ks:0@replica", "timeout": %d}`, testAddress, int64(30*time.Second)), conn: &conn{ - Configuration: Configuration{ + cfg: Configuration{ Protocol: "grpc", DriverName: "vitess", Target: "ks:0@replica", @@ -123,7 +121,7 @@ func TestOpen(t *testing.T) { `{"address": "%s", "timeout": %d, "defaultlocation": "America/Los_Angeles"}`, testAddress, int64(30*time.Second)), conn: &conn{ - Configuration: Configuration{ + cfg: Configuration{ Protocol: "grpc", DriverName: "vitess", DefaultLocation: "America/Los_Angeles", @@ -144,7 +142,7 @@ func TestOpen(t *testing.T) { wantc := tc.conn newc := *(c.(*conn)) - newc.Address = "" + newc.cfg.Address = "" newc.conn = nil newc.session = nil if !reflect.DeepEqual(&newc, wantc) { @@ -255,7 +253,7 @@ func TestExecStreamingNotAllowed(t *testing.T) { } func TestQuery(t *testing.T) { - var testcases = []struct { + testcases := []struct { desc string config Configuration requestName string @@ -357,7 +355,7 @@ func TestQuery(t *testing.T) { } func TestBindVars(t *testing.T) { - var testcases = []struct { + testcases := []struct { desc string in []driver.NamedValue out map[string]*querypb.BindVariable @@ -440,7 +438,7 @@ func TestBindVars(t *testing.T) { } func TestDatetimeQuery(t *testing.T) { - var testcases = []struct { + testcases := []struct { desc string config Configuration requestName string @@ -763,3 +761,103 @@ func colList(fields []*querypb.Field) []string { } return cols } + +func TestConnSeparateSessions(t *testing.T) { + c := Configuration{ + Protocol: "grpc", + Address: testAddress, + Target: "@primary", + } + + db, err := OpenWithConfiguration(c) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Each new connection starts a fresh session pointed at @primary. When the + // USE statement is executed, we simulate a change to that individual + // connection's target string. + // + // No connections are returned to the pool during this test and therefore + // the connection state should not be shared. + var conns []*sql.Conn + for i := 0; i < 3; i++ { + sconn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + conns = append(conns, sconn) + + targets := []string{targetString(t, sconn)} + + _, err = sconn.ExecContext(ctx, "use @rdonly") + require.NoError(t, err) + + targets = append(targets, targetString(t, sconn)) + + require.Equal(t, []string{"@primary", "@rdonly"}, targets) + } + + for _, c := range conns { + require.NoError(t, c.Close()) + } +} + +func TestConnReuseSessions(t *testing.T) { + c := Configuration{ + Protocol: "grpc", + Address: testAddress, + Target: "@primary", + } + + db, err := OpenWithConfiguration(c) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Pull an individual connection from the pool and execute a USE, resulting + // in changing the target string. We return the connection to the pool + // continuously in this test and verify that we keep pulling the same + // connection with its target string altered. + sconn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + _, err = sconn.ExecContext(ctx, "use @rdonly") + require.NoError(t, err) + require.NoError(t, sconn.Close()) + + var targets []string + for i := 0; i < 3; i++ { + sconn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + targets = append(targets, targetString(t, sconn)) + require.NoError(t, sconn.Close()) + } + + require.Equal(t, []string{"@rdonly", "@rdonly", "@rdonly"}, targets) +} + +func targetString(t *testing.T, c *sql.Conn) string { + t.Helper() + + var target string + require.NoError(t, c.Raw(func(driverConn any) error { + target = driverConn.(*conn).session.SessionPb().TargetString + return nil + })) + + return target +} diff --git a/go/vt/vitessdriver/fakeserver_test.go b/go/vt/vitessdriver/fakeserver_test.go index c420067f61f..a74e44e682c 100644 --- a/go/vt/vitessdriver/fakeserver_test.go +++ b/go/vt/vitessdriver/fakeserver_test.go @@ -33,8 +33,7 @@ import ( ) // fakeVTGateService has the server side of this fake -type fakeVTGateService struct { -} +type fakeVTGateService struct{} // queryExecute contains all the fields we use to test Execute type queryExecute struct { @@ -282,6 +281,20 @@ var execMap = map[string]struct { TargetString: "@primary", }, }, + "use @rdonly": { + execQuery: &queryExecute{ + SQL: "use @rdonly", + Session: &vtgatepb.Session{ + TargetString: "@primary", + Autocommit: true, + }, + }, + result: &sqltypes.Result{}, + session: &vtgatepb.Session{ + TargetString: "@rdonly", + SessionUUID: "1111", + }, + }, } var result1 = sqltypes.Result{