Skip to content

Commit

Permalink
chore: extend sql middleware for mssql, azure_synapse and clickhouse
Browse files Browse the repository at this point in the history
  • Loading branch information
achettyiitr committed Aug 22, 2023
1 parent e6892a2 commit 3d286f3
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 185 deletions.
92 changes: 63 additions & 29 deletions warehouse/integrations/azure-synapse/azure-synapse.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ import (
"unicode/utf16"
"unicode/utf8"

"github.com/rudderlabs/rudder-go-kit/stats"
sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper"
"github.com/rudderlabs/rudder-server/warehouse/logfield"

"golang.org/x/exp/slices"

"github.com/rudderlabs/rudder-server/warehouse/internal/service/loadfiles/downloader"
Expand Down Expand Up @@ -84,17 +88,20 @@ var mssqlDataTypesMapToRudder = map[string]string{
}

type AzureSynapse struct {
DB *sql.DB
DB *sqlmw.DB
Namespace string
ObjectStorage string
Warehouse model.Warehouse
Uploader warehouseutils.Uploader
connectTimeout time.Duration
logger logger.Logger
LoadFileDownLoader downloader.Downloader

stats stats.Stats
logger logger.Logger

config struct {
numWorkersDownloadLoadFiles int
slowQueryThreshold time.Duration
}
}

Expand All @@ -120,51 +127,64 @@ var partitionKeyMap = map[string]string{
warehouseutils.DiscardsTable: "row_id, column_name, table_name",
}

func New(conf *config.Config, log logger.Logger) *AzureSynapse {
az := &AzureSynapse{}

az.logger = log.Child("integrations").Child("synapse")
func New(conf *config.Config, log logger.Logger, stats stats.Stats) *AzureSynapse {
az := &AzureSynapse{
stats: stats,
logger: log.Child("integrations").Child("synapse"),
}

az.config.numWorkersDownloadLoadFiles = conf.GetInt("Warehouse.azure_synapse.numWorkersDownloadLoadFiles", 1)
az.config.slowQueryThreshold = conf.GetDuration("Warehouse.azure_synapse.slowQueryThreshold", 5, time.Minute)

return az
}

func connect(cred credentials) (*sql.DB, error) {
// Create connection string
// url := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%s;database=%s;encrypt=%s;TrustServerCertificate=true", cred.host, cred.user, cred.password, cred.port, cred.dbName, cred.sslMode)
// Encryption options : disable, false, true. https://github.com/denisenkom/go-mssqldb
// TrustServerCertificate=true ; all options(disable, false, true) work with this
// if rds.forcessl=1; disable option doesn't work. true, false works alongside TrustServerCertificate=true
// https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/SQLServer.Concepts.General.SSL.Using.html
// more combination explanations here: https://docs.microsoft.com/en-us/sql/connect/odbc/linux-mac/connection-string-keywords-and-data-source-names-dsns?view=sql-server-ver15
connect opens a connection to the Azure Synapse database.
If TrustServerCertificate is set to true, all encrypt options (disable, false, true) work.
If forceSSL is set to 1, the disable option doesn't work.
With forceSSL set to 1, the true and false options still work alongside TrustServerCertificate=true.
More about the combinations here: https://docs.microsoft.com/en-us/sql/connect/odbc/linux-mac/connection-string-keywords-and-data-source-names-dsns?view=sql-server-ver15
func (as *AzureSynapse) connect() (*sqlmw.DB, error) {
cred := as.connectionCredentials()

port, err := strconv.Atoi(cred.port)
if err != nil {
return nil, fmt.Errorf("invalid port %q: %w", cred.port, err)
}

Check warning on line 153 in warehouse/integrations/azure-synapse/azure-synapse.go

View check run for this annotation

Codecov / codecov/patch

warehouse/integrations/azure-synapse/azure-synapse.go#L152-L153

Added lines #L152 - L153 were not covered by tests

query := url.Values{}
query.Add("database", cred.dbName)
query.Add("encrypt", cred.sslMode)
query.Add("TrustServerCertificate", "true")
if cred.timeout > 0 {
query.Add("dial timeout", fmt.Sprintf("%d", cred.timeout/time.Second))
}
query.Add("TrustServerCertificate", "true")
port, err := strconv.Atoi(cred.port)
if err != nil {
return nil, fmt.Errorf("invalid port: %w", err)
}

connUrl := &url.URL{
Scheme: "sqlserver",
User: url.UserPassword(cred.user, cred.password),
Host: net.JoinHostPort(cred.host, strconv.Itoa(port)),
RawQuery: query.Encode(),
}

var db *sql.DB
if db, err = sql.Open("sqlserver", connUrl.String()); err != nil {
return nil, fmt.Errorf("synapse connection error : (%v)", err)
db, err := sql.Open("sqlserver", connUrl.String())
if err != nil {
return nil, fmt.Errorf("opening connection: %w", err)

Check warning on line 172 in warehouse/integrations/azure-synapse/azure-synapse.go

View check run for this annotation

Codecov / codecov/patch

warehouse/integrations/azure-synapse/azure-synapse.go#L172

Added line #L172 was not covered by tests
}
return db, nil

middleware := sqlmw.New(
db,
sqlmw.WithStats(as.stats),
sqlmw.WithLogger(as.logger),
sqlmw.WithKeyAndValues(as.defaultLogFields()),
sqlmw.WithQueryTimeout(as.connectTimeout),
sqlmw.WithSlowQueryThreshold(as.config.slowQueryThreshold),
)
return middleware, nil
}

func (as *AzureSynapse) getConnectionCredentials() credentials {
return credentials{
func (as *AzureSynapse) connectionCredentials() *credentials {
return &credentials{
host: warehouseutils.GetConfigValue(host, as.Warehouse),
dbName: warehouseutils.GetConfigValue(dbName, as.Warehouse),
user: warehouseutils.GetConfigValue(user, as.Warehouse),
Expand All @@ -175,6 +195,17 @@ func (as *AzureSynapse) getConnectionCredentials() credentials {
}
}

// defaultLogFields returns a flat list of alternating logfield keys and
// values describing this warehouse: the source ID and type, the destination
// ID and type, the workspace ID, and the namespace.
func (as *AzureSynapse) defaultLogFields() []any {
	wh := &as.Warehouse

	// Six key/value pairs; sized exactly so no growth is needed.
	fields := make([]any, 0, 12)
	fields = append(fields, logfield.SourceID, wh.Source.ID)
	fields = append(fields, logfield.SourceType, wh.Source.SourceDefinition.Name)
	fields = append(fields, logfield.DestinationID, wh.Destination.ID)
	fields = append(fields, logfield.DestinationType, wh.Destination.DestinationDefinition.Name)
	fields = append(fields, logfield.WorkspaceID, wh.WorkspaceID)
	fields = append(fields, logfield.Namespace, as.Namespace)
	return fields
}

func columnsWithDataTypes(columns model.TableSchema, prefix string) string {
var arr []string
for name, dataType := range columns {
Expand Down Expand Up @@ -665,7 +696,9 @@ func (as *AzureSynapse) Setup(_ context.Context, warehouse model.Warehouse, uplo
as.ObjectStorage = warehouseutils.ObjectStorageType(warehouseutils.AzureSynapse, warehouse.Destination.Config, as.Uploader.UseRudderStorage())
as.LoadFileDownLoader = downloader.NewDownloader(&warehouse, uploader, as.config.numWorkersDownloadLoadFiles)

as.DB, err = connect(as.getConnectionCredentials())
if as.DB, err = as.connect(); err != nil {
return fmt.Errorf("connecting to azure synapse: %w", err)
}

Check warning on line 701 in warehouse/integrations/azure-synapse/azure-synapse.go

View check run for this annotation

Codecov / codecov/patch

warehouse/integrations/azure-synapse/azure-synapse.go#L700-L701

Added lines #L700 - L701 were not covered by tests
return err
}

Expand Down Expand Up @@ -823,12 +856,13 @@ func (as *AzureSynapse) GetTotalCountInTable(ctx context.Context, tableName stri
func (as *AzureSynapse) Connect(_ context.Context, warehouse model.Warehouse) (client.Client, error) {
as.Warehouse = warehouse
as.Namespace = warehouse.Namespace
dbHandle, err := connect(as.getConnectionCredentials())

db, err := as.connect()

Check warning on line 860 in warehouse/integrations/azure-synapse/azure-synapse.go

View check run for this annotation

Codecov / codecov/patch

warehouse/integrations/azure-synapse/azure-synapse.go#L859-L860

Added lines #L859 - L860 were not covered by tests
if err != nil {
return client.Client{}, err
return client.Client{}, fmt.Errorf("connecting to azure synapse: %w", err)

Check warning on line 862 in warehouse/integrations/azure-synapse/azure-synapse.go

View check run for this annotation

Codecov / codecov/patch

warehouse/integrations/azure-synapse/azure-synapse.go#L862

Added line #L862 was not covered by tests
}

return client.Client{Type: client.SQLClient, SQL: dbHandle}, err
return client.Client{Type: client.SQLClient, SQL: db.DB}, err

Check warning on line 865 in warehouse/integrations/azure-synapse/azure-synapse.go

View check run for this annotation

Codecov / codecov/patch

warehouse/integrations/azure-synapse/azure-synapse.go#L865

Added line #L865 was not covered by tests
}

func (as *AzureSynapse) LoadTestTable(ctx context.Context, _, tableName string, payloadMap map[string]interface{}, _ string) (err error) {
Expand Down
Loading

0 comments on commit 3d286f3

Please sign in to comment.