From b4e3892e2d1633b251602e96e1be9f83c77a82b3 Mon Sep 17 00:00:00 2001 From: Asdine El Hrychy Date: Mon, 18 Jan 2021 18:29:27 +0400 Subject: [PATCH 1/2] Add support for driver.ConnBeginTx --- sqlhooks.go | 30 +++++++++++++++++++++++------- sqlhooks_interface_test.go | 12 ++++++++++++ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/sqlhooks.go b/sqlhooks.go index 2964249..fda8d04 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -51,22 +51,29 @@ func (drv *Driver) Open(name string) (driver.Conn, error) { wrapped := &Conn{conn, drv.hooks} if isExecer(conn) && isQueryer(conn) && isSessionResetter(conn) { - return &ExecerQueryerContextWithSessionResetter{wrapped, + conn = &ExecerQueryerContextWithSessionResetter{wrapped, &ExecerContext{wrapped}, &QueryerContext{wrapped}, - &SessionResetter{wrapped}}, nil + &SessionResetter{wrapped}} } else if isExecer(conn) && isQueryer(conn) { - return &ExecerQueryerContext{wrapped, &ExecerContext{wrapped}, - &QueryerContext{wrapped}}, nil + conn = &ExecerQueryerContext{wrapped, &ExecerContext{wrapped}, + &QueryerContext{wrapped}} } else if isExecer(conn) { // If conn implements an Execer interface, return a driver.Conn which // also implements Execer - return &ExecerContext{wrapped}, nil + conn = &ExecerContext{wrapped} } else if isQueryer(conn) { // If conn implements an Queryer interface, return a driver.Conn which // also implements Queryer - return &QueryerContext{wrapped}, nil + conn = &QueryerContext{wrapped} } - return wrapped, nil + + // If conn implements a ConnBeginTx interface, return a driver.Conn which + // also implements ConnBeginTx + if _, ok := conn.(driver.ConnBeginTx); ok { + conn = &ConnBeginTx{Conn: conn} + } + + return conn, nil } // Conn implements a database/sql.driver.Conn @@ -235,6 +242,15 @@ type SessionResetter struct { *Conn } +// ConnBeginTx implements a database/sql.driver.ConnBeginTx +type ConnBeginTx struct { + driver.Conn +} + +func (conn *ConnBeginTx) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return conn.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts) +} + // Stmt implements a database/sql/driver.Stmt type Stmt struct { Stmt driver.Stmt diff --git a/sqlhooks_interface_test.go b/sqlhooks_interface_test.go index 462b12b..46b8d52 100644 --- a/sqlhooks_interface_test.go +++ b/sqlhooks_interface_test.go @@ -63,6 +63,12 @@ func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { *FakeConnQueryer *FakeConnSessionResetter }{}, nil + case "ConnBeginTx": + return &struct { + *FakeConnBasic + *FakeConnQueryer + *FakeConnBeginTx + }{}, nil } return nil, errors.New("Fake driver not implemented") @@ -111,6 +117,12 @@ func (*FakeConnSessionResetter) ResetSession(ctx context.Context) error { return errors.New("Not implemented") } +type FakeConnBeginTx struct{} + +func (*FakeConnBeginTx) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return nil, errors.New("Not implemented") +} + func TestInterfaces(t *testing.T) { drv := Wrap(&fakeDriver{}, &testHooks{}) From 480358310a5bffa50d364d209f991941ea83c1f4 Mon Sep 17 00:00:00 2001 From: Asdine El Hrychy Date: Wed, 20 Jan 2021 13:42:09 +0400 Subject: [PATCH 2/2] Drop support for drivers not implementing driver.ConnBeginTx --- sqlhooks.go | 38 +++++++++++++++----------------------- sqlhooks_interface_test.go | 29 +++++++++++++++++++++-------- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/sqlhooks.go b/sqlhooks.go index fda8d04..1da05ba 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -49,31 +49,29 @@ func (drv *Driver) Open(name string) (driver.Conn, error) { return conn, err } + // Drivers that don't implement driver.ConnBeginTx are not supported. + if _, ok := conn.(driver.ConnBeginTx); !ok { + return nil, errors.New("driver must implement driver.ConnBeginTx") + } + wrapped := &Conn{conn, drv.hooks} if isExecer(conn) && isQueryer(conn) && isSessionResetter(conn) { - conn = &ExecerQueryerContextWithSessionResetter{wrapped, + return &ExecerQueryerContextWithSessionResetter{wrapped, &ExecerContext{wrapped}, &QueryerContext{wrapped}, - &SessionResetter{wrapped}} + &SessionResetter{wrapped}}, nil } else if isExecer(conn) && isQueryer(conn) { - conn = &ExecerQueryerContext{wrapped, &ExecerContext{wrapped}, - &QueryerContext{wrapped}} + return &ExecerQueryerContext{wrapped, &ExecerContext{wrapped}, + &QueryerContext{wrapped}}, nil } else if isExecer(conn) { // If conn implements an Execer interface, return a driver.Conn which // also implements Execer - conn = &ExecerContext{wrapped} + return &ExecerContext{wrapped}, nil } else if isQueryer(conn) { // If conn implements an Queryer interface, return a driver.Conn which // also implements Queryer - conn = &QueryerContext{wrapped} - } - - // If conn implements a ConnBeginTx interface, return a driver.Conn which - // also implements ConnBeginTx - if _, ok := conn.(driver.ConnBeginTx); ok { - conn = &ConnBeginTx{Conn: conn} + return &QueryerContext{wrapped}, nil } - - return conn, nil + return wrapped, nil } // Conn implements a database/sql.driver.Conn @@ -104,6 +102,9 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt func (conn *Conn) Prepare(query string) (driver.Stmt, error) { return conn.Conn.Prepare(query) } func (conn *Conn) Close() error { return conn.Conn.Close() } func (conn *Conn) Begin() (driver.Tx, error) { return conn.Conn.Begin() } +func (conn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return conn.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts) +} // ExecerContext implements a database/sql.driver.ExecerContext type ExecerContext struct { @@ -242,15 +243,6 @@ type SessionResetter struct { *Conn } -// ConnBeginTx implements a database/sql.driver.ConnBeginTx -type ConnBeginTx struct { - driver.Conn -} - -func (conn *ConnBeginTx) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - return conn.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts) -} - // Stmt implements a database/sql/driver.Stmt type Stmt struct { Stmt driver.Stmt diff --git a/sqlhooks_interface_test.go b/sqlhooks_interface_test.go index 46b8d52..dd55c60 100644 --- a/sqlhooks_interface_test.go +++ b/sqlhooks_interface_test.go @@ -63,12 +63,8 @@ func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { *FakeConnQueryer *FakeConnSessionResetter }{}, nil - case "ConnBeginTx": - return &struct { - *FakeConnBasic - *FakeConnQueryer - *FakeConnBeginTx - }{}, nil + case "NonConnBeginTx": + return &FakeConnUnsupported{}, nil } return nil, errors.New("Fake driver not implemented") @@ -86,6 +82,9 @@ func (*FakeConnBasic) Close() error { func (*FakeConnBasic) Begin() (driver.Tx, error) { return nil, errors.New("Not implemented") } +func (*FakeConnBasic) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) { + return nil, errors.New("Not implemented") +} type FakeConnExecer struct{} @@ -117,9 +116,17 @@ func (*FakeConnSessionResetter) ResetSession(ctx context.Context) error { return errors.New("Not implemented") } -type FakeConnBeginTx struct{} +// FakeConnUnsupported implements a database/sql.driver.Conn but doesn't implement +// driver.ConnBeginTx. +type FakeConnUnsupported struct{} -func (*FakeConnBeginTx) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { +func (*FakeConnUnsupported) Prepare(query string) (driver.Stmt, error) { + return nil, errors.New("Not implemented") +} +func (*FakeConnUnsupported) Close() error { + return errors.New("Not implemented") +} +func (*FakeConnUnsupported) Begin() (driver.Tx, error) { return nil, errors.New("Not implemented") } @@ -135,3 +142,9 @@ func TestInterfaces(t *testing.T) { } } } + +func TestUnsupportedDrivers(t *testing.T) { + drv := Wrap(&fakeDriver{}, &testHooks{}) + _, err := drv.Open("NonConnBeginTx") + require.EqualError(t, err, "driver must implement driver.ConnBeginTx") +}