Skip to content

Commit

Permalink
chore(warehouse): cleanup for test connection (#3226)
Browse files Browse the repository at this point in the history
  • Loading branch information
achettyiitr committed Apr 20, 2023
1 parent cae0093 commit 51c1ac0
Show file tree
Hide file tree
Showing 15 changed files with 66 additions and 143 deletions.
21 changes: 6 additions & 15 deletions warehouse/integrations/azure-synapse/azure-synapse.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"database/sql"
"encoding/csv"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -638,23 +639,13 @@ func (*AzureSynapse) AlterColumn(_, _, _ string) (model.AlterTableResponse, erro
return model.AlterTableResponse{}, nil
}

func (as *AzureSynapse) TestConnection(warehouse model.Warehouse) (err error) {
as.Warehouse = warehouse
as.DB, err = connect(as.getConnectionCredentials())
if err != nil {
return
}
defer as.DB.Close()

ctx, cancel := context.WithTimeout(context.TODO(), as.ConnectTimeout)
defer cancel()

err = as.DB.PingContext(ctx)
if err == context.DeadlineExceeded {
return fmt.Errorf("connection testing timed out after %d sec", as.ConnectTimeout/time.Second)
func (as *AzureSynapse) TestConnection(ctx context.Context, _ model.Warehouse) error {
err := as.DB.PingContext(ctx)
if errors.Is(err, context.DeadlineExceeded) {
return fmt.Errorf("connection timeout: %w", err)
}
if err != nil {
return err
return fmt.Errorf("pinging: %w", err)
}

return nil
Expand Down
13 changes: 2 additions & 11 deletions warehouse/integrations/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -819,17 +819,8 @@ func (bq *BigQuery) Setup(warehouse model.Warehouse, uploader warehouseutils.Upl
return err
}

func (bq *BigQuery) TestConnection(warehouse model.Warehouse) (err error) {
bq.warehouse = warehouse
bq.db, err = bq.connect(BQCredentials{
ProjectID: bq.projectID,
Credentials: warehouseutils.GetConfigValue(GCPCredentials, bq.warehouse),
})
if err != nil {
return
}
defer func() { _ = bq.db.Close() }()
return
func (bq *BigQuery) TestConnection(context.Context, model.Warehouse) (err error) {
return nil
}

func (bq *BigQuery) LoadTable(tableName string) error {
Expand Down
20 changes: 5 additions & 15 deletions warehouse/integrations/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -918,23 +918,13 @@ func (*Clickhouse) AlterColumn(_, _, _ string) (model.AlterTableResponse, error)
}

// TestConnection is used destination connection tester to test the clickhouse connection
func (ch *Clickhouse) TestConnection(warehouse model.Warehouse) (err error) {
ch.Warehouse = warehouse
ch.DB, err = ch.ConnectToClickhouse(ch.getConnectionCredentials(), true)
if err != nil {
return
}
defer ch.DB.Close()

ctx, cancel := context.WithTimeout(context.TODO(), ch.ConnectTimeout)
defer cancel()

err = ch.DB.PingContext(ctx)
if err == context.DeadlineExceeded {
return fmt.Errorf("connection testing timed out after %d sec", ch.ConnectTimeout/time.Second)
func (ch *Clickhouse) TestConnection(ctx context.Context, _ model.Warehouse) error {
err := ch.DB.PingContext(ctx)
if errors.Is(err, context.DeadlineExceeded) {
return fmt.Errorf("connection timeout: %w", err)
}
if err != nil {
return err
return fmt.Errorf("pinging: %w", err)
}

return nil
Expand Down
10 changes: 8 additions & 2 deletions warehouse/integrations/clickhouse/clickhouse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ func TestHandle_TestConnection(t *testing.T) {
}{
{
name: "DeadlineExceeded",
wantError: errors.New("connection testing timed out after 0 sec"),
wantError: errors.New("connection timeout: context deadline exceeded"),
},
{
name: "Success",
Expand Down Expand Up @@ -708,9 +708,15 @@ func TestHandle_TestConnection(t *testing.T) {
},
}

err = ch.Setup(warehouse, &mockUploader{})
require.NoError(t, err)

ch.SetConnectionTimeout(tc.timeout)

err := ch.TestConnection(warehouse)
ctx, cancel := context.WithTimeout(context.TODO(), tc.timeout)
defer cancel()

err := ch.TestConnection(ctx, warehouse)
if tc.wantError != nil {
require.ErrorContains(t, err, tc.wantError.Error())
return
Expand Down
2 changes: 1 addition & 1 deletion warehouse/integrations/datalake/datalake.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (*Datalake) IsEmpty(_ model.Warehouse) (bool, error) {
return false, nil
}

func (*Datalake) TestConnection(_ model.Warehouse) error {
func (*Datalake) TestConnection(context.Context, model.Warehouse) error {
return fmt.Errorf("datalake err :not implemented")
}

Expand Down
5 changes: 1 addition & 4 deletions warehouse/integrations/deltalake-native/deltalake.go
Original file line number Diff line number Diff line change
Expand Up @@ -1162,10 +1162,7 @@ func (*Deltalake) IsEmpty(model.Warehouse) (bool, error) {
}

// TestConnection tests the connection to the warehouse
func (d *Deltalake) TestConnection(model.Warehouse) error {
ctx, cancel := context.WithTimeout(context.Background(), d.ConnectTimeout)
defer cancel()

func (d *Deltalake) TestConnection(ctx context.Context, _ model.Warehouse) error {
err := d.DB.PingContext(ctx)
if errors.Is(err, context.DeadlineExceeded) {
return fmt.Errorf("connection timeout: %w", err)
Expand Down
6 changes: 2 additions & 4 deletions warehouse/integrations/deltalake/deltalake.go
Original file line number Diff line number Diff line change
Expand Up @@ -952,10 +952,8 @@ func (dl *Deltalake) Setup(warehouse model.Warehouse, uploader warehouseutils.Up
}

// TestConnection test the connection for the warehouse
func (dl *Deltalake) TestConnection(warehouse model.Warehouse) (err error) {
dl.Warehouse = warehouse
dl.Client, err = dl.connectToWarehouse()
return
func (dl *Deltalake) TestConnection(context.Context, model.Warehouse) error {
return nil
}

// Cleanup cleanup when upload is done.
Expand Down
2 changes: 1 addition & 1 deletion warehouse/integrations/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type Manager interface {
LoadIdentityMappingsTable() error
Cleanup()
IsEmpty(warehouse model.Warehouse) (bool, error)
TestConnection(warehouse model.Warehouse) error
TestConnection(ctx context.Context, warehouse model.Warehouse) error
DownloadIdentityRules(*misc.GZipWriter) error
GetTotalCountInTable(ctx context.Context, tableName string) (int64, error)
Connect(warehouse model.Warehouse) (client.Client, error)
Expand Down
27 changes: 6 additions & 21 deletions warehouse/integrations/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"database/sql"
"encoding/csv"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -680,29 +681,13 @@ func (*MSSQL) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) {
return model.AlterTableResponse{}, nil
}

func (ms *MSSQL) TestConnection(warehouse model.Warehouse) (err error) {
ms.Warehouse = warehouse
ms.Namespace = warehouse.Namespace
ms.ObjectStorage = warehouseutils.ObjectStorageType(
warehouseutils.MSSQL,
warehouse.Destination.Config,
misc.IsConfiguredToUseRudderObjectStorage(ms.Warehouse.Destination.Config),
)
ms.DB, err = Connect(ms.getConnectionCredentials())
if err != nil {
return
}
defer ms.DB.Close()

ctx, cancel := context.WithTimeout(context.TODO(), ms.ConnectTimeout)
defer cancel()

err = ms.DB.PingContext(ctx)
if err == context.DeadlineExceeded {
return fmt.Errorf("connection testing timed out after %d sec", ms.ConnectTimeout/time.Second)
func (ms *MSSQL) TestConnection(ctx context.Context, _ model.Warehouse) error {
err := ms.DB.PingContext(ctx)
if errors.Is(err, context.DeadlineExceeded) {
return fmt.Errorf("connection timeout: %w", err)
}
if err != nil {
return err
return fmt.Errorf("pinging: %w", err)
}

return nil
Expand Down
26 changes: 8 additions & 18 deletions warehouse/integrations/postgres-legacy/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"database/sql"
"encoding/csv"
"errors"
"fmt"
"io"
"net/url"
Expand Down Expand Up @@ -784,30 +785,19 @@ func (*Postgres) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) {
return model.AlterTableResponse{}, nil
}

func (pg *Postgres) TestConnection(warehouse model.Warehouse) (err error) {
if warehouse.Destination.Config["sslMode"] == "verify-ca" {
func (pg *Postgres) TestConnection(ctx context.Context, warehouse model.Warehouse) error {
if warehouse.Destination.Config["sslMode"] == verifyCA {
if sslKeyError := warehouseutils.WriteSSLKeys(warehouse.Destination); sslKeyError.IsError() {
pg.logger.Error(sslKeyError.Error())
err = fmt.Errorf(sslKeyError.Error())
return
return fmt.Errorf("writing ssl keys: %s", sslKeyError.Error())
}
}
pg.Warehouse = warehouse
pg.DB, err = Connect(pg.getConnectionCredentials())
if err != nil {
return
}
defer pg.DB.Close()

ctx, cancel := context.WithTimeout(context.TODO(), pg.ConnectTimeout)
defer cancel()

err = pg.DB.PingContext(ctx)
if err == context.DeadlineExceeded {
return fmt.Errorf("connection testing timed out after %d sec", pg.ConnectTimeout/time.Second)
err := pg.DB.PingContext(ctx)
if errors.Is(err, context.DeadlineExceeded) {
return fmt.Errorf("connection timeout: %w", err)
}
if err != nil {
return err
return fmt.Errorf("pinging: %w", err)
}

return nil
Expand Down
26 changes: 8 additions & 18 deletions warehouse/integrations/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
"regexp"
Expand Down Expand Up @@ -346,30 +347,19 @@ func (*Postgres) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) {
return model.AlterTableResponse{}, nil
}

func (pg *Postgres) TestConnection(warehouse model.Warehouse) (err error) {
if warehouse.Destination.Config["sslMode"] == "verify-ca" {
func (pg *Postgres) TestConnection(ctx context.Context, warehouse model.Warehouse) error {
if warehouse.Destination.Config["sslMode"] == verifyCA {
if sslKeyError := warehouseutils.WriteSSLKeys(warehouse.Destination); sslKeyError.IsError() {
pg.Logger.Error(sslKeyError.Error())
err = fmt.Errorf(sslKeyError.Error())
return
return fmt.Errorf("writing ssl keys: %s", sslKeyError.Error())
}
}
pg.Warehouse = warehouse
pg.DB, err = Connect(pg.getConnectionCredentials())
if err != nil {
return
}
defer pg.DB.Close()

ctx, cancel := context.WithTimeout(context.TODO(), pg.ConnectTimeout)
defer cancel()

err = pg.DB.PingContext(ctx)
if err == context.DeadlineExceeded {
return fmt.Errorf("connection testing timed out after %d sec", pg.ConnectTimeout/time.Second)
err := pg.DB.PingContext(ctx)
if errors.Is(err, context.DeadlineExceeded) {
return fmt.Errorf("connection timeout: %w", err)
}
if err != nil {
return err
return fmt.Errorf("pinging: %w", err)
}

return nil
Expand Down
23 changes: 7 additions & 16 deletions warehouse/integrations/redshift/redshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/url"
"os"
Expand Down Expand Up @@ -1351,26 +1352,16 @@ func (rs *Redshift) Setup(warehouse model.Warehouse, uploader warehouseutils.Upl
return err
}

func (rs *Redshift) TestConnection(warehouse model.Warehouse) (err error) {
rs.Warehouse = warehouse
rs.DB, err = Connect(rs.getConnectionCredentials())
if err != nil {
return
}
defer func() { _ = rs.DB.Close() }()

ctx, cancel := context.WithTimeout(context.TODO(), rs.ConnectTimeout)
defer cancel()

err = rs.DB.PingContext(ctx)
if err == context.DeadlineExceeded {
return fmt.Errorf("connection testing timed out after %d sec", rs.ConnectTimeout/time.Second)
func (rs *Redshift) TestConnection(ctx context.Context, _ model.Warehouse) error {
err := rs.DB.PingContext(ctx)
if errors.Is(err, context.DeadlineExceeded) {
return fmt.Errorf("connection timeout: %w", err)
}
if err != nil {
return err
return fmt.Errorf("pinging: %w", err)
}

return
return nil
}

func (rs *Redshift) Cleanup() {
Expand Down
21 changes: 6 additions & 15 deletions warehouse/integrations/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"database/sql"
"encoding/csv"
"errors"
"fmt"
"regexp"
"sort"
Expand Down Expand Up @@ -1252,23 +1253,13 @@ func (sf *Snowflake) Setup(warehouse model.Warehouse, uploader warehouseutils.Up
return err
}

func (sf *Snowflake) TestConnection(warehouse model.Warehouse) (err error) {
sf.Warehouse = warehouse
sf.DB, err = Connect(sf.getConnectionCredentials(optionalCreds{}))
if err != nil {
return
}
defer sf.DB.Close()

ctx, cancel := context.WithTimeout(context.TODO(), sf.ConnectTimeout)
defer cancel()

err = sf.DB.PingContext(ctx)
if err == context.DeadlineExceeded {
return fmt.Errorf("connection testing timed out after %d sec", sf.ConnectTimeout/time.Second)
func (sf *Snowflake) TestConnection(ctx context.Context, _ model.Warehouse) error {
err := sf.DB.PingContext(ctx)
if errors.Is(err, context.DeadlineExceeded) {
return fmt.Errorf("connection timeout: %w", err)
}
if err != nil {
return err
return fmt.Errorf("pinging: %w", err)
}

return nil
Expand Down
5 changes: 4 additions & 1 deletion warehouse/validations/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,10 @@ func (os *objectStorage) Validate() error {
func (c *connections) Validate() error {
defer c.manager.Cleanup()

return c.manager.TestConnection(createDummyWarehouse(c.destination))
ctx, cancel := context.WithTimeout(context.TODO(), warehouseutils.TestConnectionTimeout)
defer cancel()

return c.manager.TestConnection(ctx, createDummyWarehouse(c.destination))
}

func (cs *createSchema) Validate() error {
Expand Down
2 changes: 1 addition & 1 deletion warehouse/validations/validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func TestValidator(t *testing.T) {
config: map[string]interface{}{
"database": "invalid_database",
},
wantError: errors.New("pq: database \"invalid_database\" does not exist"),
wantError: errors.New("pinging: pq: database \"invalid_database\" does not exist"),
},
{
name: "valid credentials",
Expand Down

0 comments on commit 51c1ac0

Please sign in to comment.