diff --git a/warehouse/integrations/azure-synapse/azure-synapse.go b/warehouse/integrations/azure-synapse/azure-synapse.go index 26db9f8eaf0..8f4936626f0 100644 --- a/warehouse/integrations/azure-synapse/azure-synapse.go +++ b/warehouse/integrations/azure-synapse/azure-synapse.go @@ -17,6 +17,10 @@ import ( "unicode/utf16" "unicode/utf8" + "github.com/rudderlabs/rudder-go-kit/stats" + sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + "github.com/rudderlabs/rudder-server/warehouse/logfield" + "golang.org/x/exp/slices" "github.com/rudderlabs/rudder-server/warehouse/internal/service/loadfiles/downloader" @@ -84,17 +88,20 @@ var mssqlDataTypesMapToRudder = map[string]string{ } type AzureSynapse struct { - DB *sql.DB + DB *sqlmw.DB Namespace string ObjectStorage string Warehouse model.Warehouse Uploader warehouseutils.Uploader connectTimeout time.Duration - logger logger.Logger LoadFileDownLoader downloader.Downloader + stats stats.Stats + logger logger.Logger + config struct { numWorkersDownloadLoadFiles int + slowQueryThreshold time.Duration } } @@ -120,35 +127,39 @@ var partitionKeyMap = map[string]string{ warehouseutils.DiscardsTable: "row_id, column_name, table_name", } -func New(conf *config.Config, log logger.Logger) *AzureSynapse { - az := &AzureSynapse{} - - az.logger = log.Child("integrations").Child("synapse") +func New(conf *config.Config, log logger.Logger, stats stats.Stats) *AzureSynapse { + az := &AzureSynapse{ + stats: stats, + logger: log.Child("integrations").Child("synapse"), + } az.config.numWorkersDownloadLoadFiles = conf.GetInt("Warehouse.azure_synapse.numWorkersDownloadLoadFiles", 1) + az.config.slowQueryThreshold = conf.GetDuration("Warehouse.azure_synapse.slowQueryThreshold", 5, time.Minute) return az } -func connect(cred credentials) (*sql.DB, error) { - // Create connection string - // url := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%s;database=%s;encrypt=%s;TrustServerCertificate=true", cred.host, cred.user, cred.password, cred.port, cred.dbName, cred.sslMode) - // Encryption options : disable, false, true. https://github.com/denisenkom/go-mssqldb - // TrustServerCertificate=true ; all options(disable, false, true) work with this - // if rds.forcessl=1; disable option doesn't work. true, false works alongside TrustServerCertificate=true - // https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/SQLServer.Concepts.General.SSL.Using.html - // more combination explanations here: https://docs.microsoft.com/en-us/sql/connect/odbc/linux-mac/connection-string-keywords-and-data-source-names-dsns?view=sql-server-ver15 +// connect to the azure synapse database +// if TrustServerCertificate is set to true, all options(disable, false, true) works. +// if forceSSL is set to 1, disable option doesn't work. 
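To make the connection options described in this comment concrete, here is a minimal, self-contained sketch of the sqlserver DSN that the rewritten connect() assembles. Only the query parameters that appear in this diff (database, encrypt, TrustServerCertificate, dial timeout) are used; the host, port, and credential values are placeholders, and the host:port formatting is the obvious shape rather than a line copied from the hunk.

```go
package main

import (
	"fmt"
	"net/url"
	"time"
)

// buildSynapseDSN mirrors the DSN assembly in the new connect(): encrypt comes
// from the destination's sslMode, and TrustServerCertificate=true is always set
// so that the disable/false/true encrypt values all remain usable.
func buildSynapseDSN(host string, port int, user, password, dbName, sslMode string, timeout time.Duration) string {
	query := url.Values{}
	query.Add("database", dbName)
	query.Add("encrypt", sslMode)
	query.Add("TrustServerCertificate", "true")
	if timeout > 0 {
		query.Add("dial timeout", fmt.Sprintf("%d", timeout/time.Second))
	}

	connURL := &url.URL{
		Scheme:   "sqlserver",
		User:     url.UserPassword(user, password),
		Host:     fmt.Sprintf("%s:%d", host, port), // placeholder formatting; the hunk parses the port with strconv.Atoi first
		RawQuery: query.Encode(),
	}
	return connURL.String()
}

func main() {
	// Placeholder values only; sslMode "disable" still works because TrustServerCertificate is true.
	fmt.Println(buildSynapseDSN("localhost", 1433, "sa", "secret", "master", "disable", 10*time.Second))
}
```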
+// If forceSSL is set to true or false, it works alongside with TrustServerCertificate=true +// more about combinations in here: https://docs.microsoft.com/en-us/sql/connect/odbc/linux-mac/connection-string-keywords-and-data-source-names-dsns?view=sql-server-ver15 +func (as *AzureSynapse) connect() (*sqlmw.DB, error) { + cred := as.connectionCredentials() + + port, err := strconv.Atoi(cred.port) + if err != nil { + return nil, fmt.Errorf("invalid port: %w", err) + } + query := url.Values{} query.Add("database", cred.dbName) query.Add("encrypt", cred.sslMode) + query.Add("TrustServerCertificate", "true") if cred.timeout > 0 { query.Add("dial timeout", fmt.Sprintf("%d", cred.timeout/time.Second)) } - query.Add("TrustServerCertificate", "true") - port, err := strconv.Atoi(cred.port) - if err != nil { - return nil, fmt.Errorf("invalid port: %w", err) - } + connUrl := &url.URL{ Scheme: "sqlserver", User: url.UserPassword(cred.user, cred.password), @@ -156,15 +167,30 @@ func connect(cred credentials) (*sql.DB, error) { RawQuery: query.Encode(), } - var db *sql.DB - if db, err = sql.Open("sqlserver", connUrl.String()); err != nil { - return nil, fmt.Errorf("synapse connection error : (%v)", err) - } - return db, nil + db, err := sql.Open("sqlserver", connUrl.String()) + if err != nil { + return nil, fmt.Errorf("opening connection: %w", err) + } + + middleware := sqlmw.New( + db, + sqlmw.WithStats(as.stats), + sqlmw.WithLogger(as.logger), + sqlmw.WithKeyAndValues( + logfield.SourceID, as.Warehouse.Source.ID, + logfield.SourceType, as.Warehouse.Source.SourceDefinition.Name, + logfield.DestinationID, as.Warehouse.Destination.ID, + logfield.DestinationType, as.Warehouse.Destination.DestinationDefinition.Name, + logfield.WorkspaceID, as.Warehouse.WorkspaceID, + logfield.Schema, as.Namespace, + ), + sqlmw.WithSlowQueryThreshold(as.config.slowQueryThreshold), + ) + return middleware, nil } -func (as *AzureSynapse) getConnectionCredentials() credentials { - return credentials{ +func (as *AzureSynapse) connectionCredentials() *credentials { + return &credentials{ host: warehouseutils.GetConfigValue(host, as.Warehouse), dbName: warehouseutils.GetConfigValue(dbName, as.Warehouse), user: warehouseutils.GetConfigValue(user, as.Warehouse), @@ -665,7 +691,9 @@ func (as *AzureSynapse) Setup(_ context.Context, warehouse model.Warehouse, uplo as.ObjectStorage = warehouseutils.ObjectStorageType(warehouseutils.AzureSynapse, warehouse.Destination.Config, as.Uploader.UseRudderStorage()) as.LoadFileDownLoader = downloader.NewDownloader(&warehouse, uploader, as.config.numWorkersDownloadLoadFiles) - as.DB, err = connect(as.getConnectionCredentials()) + if as.DB, err = as.connect(); err != nil { + return fmt.Errorf("connecting to azure synapse: %w", err) + } return err } @@ -823,12 +851,13 @@ func (as *AzureSynapse) GetTotalCountInTable(ctx context.Context, tableName stri func (as *AzureSynapse) Connect(_ context.Context, warehouse model.Warehouse) (client.Client, error) { as.Warehouse = warehouse as.Namespace = warehouse.Namespace - dbHandle, err := connect(as.getConnectionCredentials()) + + db, err := as.connect() if err != nil { - return client.Client{}, err + return client.Client{}, fmt.Errorf("connecting to azure synapse: %w", err) } - return client.Client{Type: client.SQLClient, SQL: dbHandle}, err + return client.Client{Type: client.SQLClient, SQL: db.DB}, err } func (as *AzureSynapse) LoadTestTable(ctx context.Context, _, tableName string, payloadMap map[string]interface{}, _ string) (err error) { diff --git 
a/warehouse/integrations/clickhouse/clickhouse.go b/warehouse/integrations/clickhouse/clickhouse.go index 657cc6a1e25..c10a55a455b 100644 --- a/warehouse/integrations/clickhouse/clickhouse.go +++ b/warehouse/integrations/clickhouse/clickhouse.go @@ -20,6 +20,9 @@ import ( "strings" "time" + sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + "github.com/rudderlabs/rudder-server/warehouse/logfield" + "github.com/rudderlabs/rudder-server/warehouse/internal/service/loadfiles/downloader" "github.com/rudderlabs/rudder-server/warehouse/internal/model" @@ -50,7 +53,7 @@ const ( secure = "secure" skipVerify = "skipVerify" caCertificate = "caCertificate" - Cluster = "cluster" + cluster = "cluster" partitionField = "received_at" ) @@ -136,16 +139,17 @@ var errorsMappings = []model.JobError{ } type Clickhouse struct { - DB *sql.DB + DB *sqlmw.DB Namespace string ObjectStorage string Warehouse model.Warehouse Uploader warehouseutils.Uploader connectTimeout time.Duration - logger logger.Logger - stats stats.Stats LoadFileDownloader downloader.Downloader + logger logger.Logger + stats stats.Stats + config struct { queryDebugLogs string blockSize string @@ -159,6 +163,7 @@ type Clickhouse struct { loadTableFailureRetries int numWorkersDownloadLoadFiles int s3EngineEnabledWorkspaceIDs []string + slowQueryThreshold time.Duration } } @@ -214,11 +219,33 @@ func (ch *Clickhouse) newClickHouseStat(tableName string) *clickHouseStat { } } -// connectToClickhouse connects to clickhouse with provided credentials -func (ch *Clickhouse) connectToClickhouse(cred credentials, includeDBInConn bool) (*sql.DB, error) { - dsn := url.URL{ - Scheme: "tcp", - Host: fmt.Sprintf("%s:%s", cred.host, cred.port), +func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Clickhouse { + ch := &Clickhouse{} + + ch.logger = log.Child("integrations").Child("clickhouse") + ch.stats = stat + + ch.config.queryDebugLogs = conf.GetString("Warehouse.clickhouse.queryDebugLogs", "false") + ch.config.blockSize = conf.GetString("Warehouse.clickhouse.blockSize", "1000000") + ch.config.poolSize = conf.GetString("Warehouse.clickhouse.poolSize", "100") + ch.config.readTimeout = conf.GetString("Warehouse.clickhouse.readTimeout", "300") + ch.config.writeTimeout = conf.GetString("Warehouse.clickhouse.writeTimeout", "1800") + ch.config.compress = conf.GetBool("Warehouse.clickhouse.compress", false) + ch.config.disableNullable = conf.GetBool("Warehouse.clickhouse.disableNullable", false) + ch.config.execTimeOutInSeconds = conf.GetDuration("Warehouse.clickhouse.execTimeOutInSeconds", 600, time.Second) + ch.config.commitTimeOutInSeconds = conf.GetDuration("Warehouse.clickhouse.commitTimeOutInSeconds", 600, time.Second) + ch.config.loadTableFailureRetries = conf.GetInt("Warehouse.clickhouse.loadTableFailureRetries", 3) + ch.config.numWorkersDownloadLoadFiles = conf.GetInt("Warehouse.clickhouse.numWorkersDownloadLoadFiles", 8) + ch.config.s3EngineEnabledWorkspaceIDs = conf.GetStringSlice("Warehouse.clickhouse.s3EngineEnabledWorkspaceIDs", nil) + ch.config.slowQueryThreshold = conf.GetDuration("Warehouse.clickhouse.slowQueryThreshold", 5, time.Minute) + + return ch +} + +func (ch *Clickhouse) connectToClickhouse(includeDBInConn bool) (*sqlmw.DB, error) { + cred, err := ch.connectionCredentials() + if err != nil { + return nil, fmt.Errorf("could not get connection credentials: %w", err) } values := url.Values{ @@ -234,7 +261,6 @@ func (ch *Clickhouse) connectToClickhouse(cred credentials, includeDBInConn 
bool "write_timeout": []string{ch.config.writeTimeout}, "compress": []string{strconv.FormatBool(ch.config.compress)}, } - if includeDBInConn { values.Add("database", cred.database) } @@ -242,64 +268,38 @@ func (ch *Clickhouse) connectToClickhouse(cred credentials, includeDBInConn bool values.Add("timeout", fmt.Sprintf("%d", cred.timeout/time.Second)) } - dsn.RawQuery = values.Encode() - - var ( - err error - db *sql.DB - ) - - if db, err = sql.Open("clickhouse", dsn.String()); err != nil { - return nil, fmt.Errorf("clickhouse connection error : (%v)", err) + dsn := url.URL{ + Scheme: "tcp", + Host: fmt.Sprintf("%s:%s", cred.host, cred.port), + RawQuery: values.Encode(), } - return db, nil -} -func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Clickhouse { - ch := &Clickhouse{} - - ch.logger = log.Child("integrations").Child("clickhouse") - ch.stats = stat - - ch.config.queryDebugLogs = conf.GetString("Warehouse.clickhouse.queryDebugLogs", "false") - ch.config.blockSize = conf.GetString("Warehouse.clickhouse.blockSize", "1000000") - ch.config.poolSize = conf.GetString("Warehouse.clickhouse.poolSize", "100") - ch.config.readTimeout = conf.GetString("Warehouse.clickhouse.readTimeout", "300") - ch.config.writeTimeout = conf.GetString("Warehouse.clickhouse.writeTimeout", "1800") - ch.config.compress = conf.GetBool("Warehouse.clickhouse.compress", false) - ch.config.disableNullable = conf.GetBool("Warehouse.clickhouse.disableNullable", false) - ch.config.execTimeOutInSeconds = conf.GetDuration("Warehouse.clickhouse.execTimeOutInSeconds", 600, time.Second) - ch.config.commitTimeOutInSeconds = conf.GetDuration("Warehouse.clickhouse.commitTimeOutInSeconds", 600, time.Second) - ch.config.loadTableFailureRetries = conf.GetInt("Warehouse.clickhouse.loadTableFailureRetries", 3) - ch.config.numWorkersDownloadLoadFiles = conf.GetInt("Warehouse.clickhouse.numWorkersDownloadLoadFiles", 8) - ch.config.s3EngineEnabledWorkspaceIDs = conf.GetStringSlice("Warehouse.clickhouse.s3EngineEnabledWorkspaceIDs", nil) - - return ch -} - -/* -registerTLSConfig will create a global map, use different names for the different tls config. 
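For comparison with the sqlserver URL above, the ClickHouse DSN built by connectToClickhouse is a tcp:// URL whose options travel as query parameters. The sketch below uses only the keys visible in this hunk (write_timeout, compress, database, timeout); the block-size, pool-size, and read-timeout keys from the config struct are deliberately omitted, and host, port, and database are placeholders.

```go
package main

import (
	"fmt"
	"net/url"
	"strconv"
	"time"
)

// buildClickhouseDSN sketches the tcp:// DSN assembled by connectToClickhouse.
// Only the "database", "write_timeout", "compress", and "timeout" keys are taken
// from this hunk; the remaining connection options are left out.
func buildClickhouseDSN(host, port, database, writeTimeout string, compress bool, timeout time.Duration, includeDBInConn bool) string {
	values := url.Values{
		"write_timeout": []string{writeTimeout},
		"compress":      []string{strconv.FormatBool(compress)},
	}
	if includeDBInConn {
		values.Add("database", database)
	}
	if timeout > 0 {
		values.Add("timeout", fmt.Sprintf("%d", timeout/time.Second))
	}

	dsn := url.URL{
		Scheme:   "tcp",
		Host:     fmt.Sprintf("%s:%s", host, port),
		RawQuery: values.Encode(),
	}
	return dsn.String()
}

func main() {
	// Placeholder values; CreateSchema connects with includeDBInConn set to false.
	fmt.Println(buildClickhouseDSN("localhost", "9000", "rudderdb", "1800", false, 10*time.Second, true))
}
```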
-clickhouse will access the config by mentioning the key in connection string -*/ -func registerTLSConfig(key, certificate string) { - tlsConfig := &tls.Config{} // skipcq: GO-S1020 - caCert := []byte(certificate) - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - tlsConfig.RootCAs = caCertPool - _ = clickhouse.RegisterTLSConfig(key, tlsConfig) + db, err := sql.Open("clickhouse", dsn.String()) + if err != nil { + return nil, fmt.Errorf("opening connection: %w", err) + } + + middleware := sqlmw.New( + db, + sqlmw.WithStats(ch.stats), + sqlmw.WithLogger(ch.logger), + sqlmw.WithKeyAndValues( + logfield.SourceID, ch.Warehouse.Source.ID, + logfield.SourceType, ch.Warehouse.Source.SourceDefinition.Name, + logfield.DestinationID, ch.Warehouse.Destination.ID, + logfield.DestinationType, ch.Warehouse.Destination.DestinationDefinition.Name, + logfield.WorkspaceID, ch.Warehouse.WorkspaceID, + logfield.Schema, ch.Namespace, + ), + sqlmw.WithSlowQueryThreshold(ch.config.slowQueryThreshold), + ) + return middleware, nil } -// getConnectionCredentials gives clickhouse credentials -func (ch *Clickhouse) getConnectionCredentials() credentials { - tlsName := "" - certificate := warehouseutils.GetConfigValue(caCertificate, ch.Warehouse) - if strings.TrimSpace(certificate) != "" { - // each destination will have separate tls config, hence using destination id as tlsName - tlsName = ch.Warehouse.Destination.ID - registerTLSConfig(tlsName, certificate) - } - credentials := credentials{ +// connectionCredentials returns the credentials for connecting to clickhouse +// Each destination will have separate tls config, hence using destination id as tlsName +func (ch *Clickhouse) connectionCredentials() (*credentials, error) { + credentials := &credentials{ host: warehouseutils.GetConfigValue(host, ch.Warehouse), database: warehouseutils.GetConfigValue(dbName, ch.Warehouse), user: warehouseutils.GetConfigValue(user, ch.Warehouse), @@ -307,10 +307,29 @@ func (ch *Clickhouse) getConnectionCredentials() credentials { port: warehouseutils.GetConfigValue(port, ch.Warehouse), secure: warehouseutils.GetConfigValueBoolString(secure, ch.Warehouse), skipVerify: warehouseutils.GetConfigValueBoolString(skipVerify, ch.Warehouse), - tlsConfig: tlsName, timeout: ch.connectTimeout, } - return credentials + + certificate := warehouseutils.GetConfigValue(caCertificate, ch.Warehouse) + if strings.TrimSpace(certificate) != "" { + if err := registerTLSConfig(ch.Warehouse.Destination.ID, certificate); err != nil { + return nil, fmt.Errorf("registering tls config: %w", err) + } + + credentials.tlsConfig = ch.Warehouse.Destination.ID + } + return credentials, nil +} + +// registerTLSConfig will create a global map, use different names for the different tls config. 
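Once sqlmw.New wraps the raw handle, calling code keeps the plain database/sql surface. A small illustrative caller, assuming only what this diff shows: the sqlquerywrapper import path and that *sqlmw.DB exposes ExecContext. The staging-table name is hypothetical, and the expectation that statements slower than slowQueryThreshold are reported through the wired-in logger and stats is inferred from the option names rather than spelled out in these hunks.

```go
package example

import (
	"context"
	"fmt"

	sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper"
)

// dropStagingTable is a hypothetical caller: the wrapped handle is used exactly
// like *sql.DB (ExecContext here, as DropTable and the load path do), and slow
// statements are expected to be surfaced via the middleware's logger and stats.
func dropStagingTable(ctx context.Context, db *sqlmw.DB, namespace, table string) error {
	_, err := db.ExecContext(ctx, fmt.Sprintf(`DROP TABLE IF EXISTS %q.%q`, namespace, table))
	return err
}
```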
+// clickhouse will access the config by mentioning the key in connection string +func registerTLSConfig(key, certificate string) error { + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM([]byte(certificate)) + + return clickhouse.RegisterTLSConfig(key, &tls.Config{ + RootCAs: caCertPool, + }) } // ColumnsWithDataTypes creates columns and its datatype into sql format for creating table @@ -597,7 +616,7 @@ func (ch *Clickhouse) loadTablesFromFilesNamesWithRetry(ctx context.Context, tab ch.logger.Debugf("%s LoadTablesFromFilesNamesWithRetry Started", ch.GetLogIdentifier(tableName)) defer ch.logger.Debugf("%s LoadTablesFromFilesNamesWithRetry Completed", ch.GetLogIdentifier(tableName)) - var txn *sql.Tx + var txn *sqlmw.Tx var err error onError := func(err error) { @@ -753,34 +772,6 @@ func (ch *Clickhouse) schemaExists(ctx context.Context, schemaName string) (exis return } -// createSchema creates a database in clickhouse -func (ch *Clickhouse) createSchema(ctx context.Context) (err error) { - var schemaExists bool - schemaExists, err = ch.schemaExists(ctx, ch.Namespace) - if err != nil { - ch.logger.Errorf("CH: Error checking if database: %s exists: %v", ch.Namespace, err) - return err - } - if schemaExists { - ch.logger.Infof("CH: Skipping creating database: %s since it already exists", ch.Namespace) - return - } - dbHandle, err := ch.connectToClickhouse(ch.getConnectionCredentials(), false) - if err != nil { - return err - } - defer func() { _ = dbHandle.Close() }() - cluster := warehouseutils.GetConfigValue(Cluster, ch.Warehouse) - clusterClause := "" - if len(strings.TrimSpace(cluster)) > 0 { - clusterClause = fmt.Sprintf(`ON CLUSTER %q`, cluster) - } - sqlStatement := fmt.Sprintf(`CREATE DATABASE IF NOT EXISTS %q %s`, ch.Namespace, clusterClause) - ch.logger.Infof("CH: Creating database in clickhouse for ch:%s : %v", ch.Warehouse.Destination.ID, sqlStatement) - _, err = dbHandle.ExecContext(ctx, sqlStatement) - return -} - /* createUsersTable creates a user's table with engine AggregatingMergeTree, this lets us choose aggregation logic before merging records with same user id. 
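The createSchema removed above built its ON CLUSTER clause inline; the hunks that follow consolidate that logic into a clusterClause() helper reused by CreateSchema, DropTable, and AddColumns. A runnable sketch of the helper and the DDL it yields, with placeholder database and cluster names and the cluster passed as a plain argument instead of the warehouse-config lookup used in clickhouse.go:

```go
package main

import (
	"fmt"
	"strings"
)

// clusterClause mirrors the helper introduced below: it returns an
// `ON CLUSTER "<name>"` fragment when a cluster is configured and an empty
// string otherwise.
func clusterClause(cluster string) string {
	if len(strings.TrimSpace(cluster)) > 0 {
		return fmt.Sprintf(`ON CLUSTER %q`, cluster)
	}
	return ""
}

func main() {
	// With a cluster configured, schema DDL fans out to every replica.
	fmt.Printf("%s\n", fmt.Sprintf(`CREATE DATABASE IF NOT EXISTS %q %s`, "rudderdb", clusterClause("rudder_cluster")))
	// Without one, the clause collapses to an empty string.
	fmt.Printf("%s\n", fmt.Sprintf(`CREATE DATABASE IF NOT EXISTS %q %s`, "rudderdb", clusterClause("")))
}
```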
@@ -792,7 +783,7 @@ func (ch *Clickhouse) createUsersTable(ctx context.Context, name string, columns clusterClause := "" engine := "AggregatingMergeTree" engineOptions := "" - cluster := warehouseutils.GetConfigValue(Cluster, ch.Warehouse) + cluster := warehouseutils.GetConfigValue(cluster, ch.Warehouse) if len(strings.TrimSpace(cluster)) > 0 { clusterClause = fmt.Sprintf(`ON CLUSTER %q`, cluster) engine = fmt.Sprintf(`%s%s`, "Replicated", engine) @@ -834,7 +825,7 @@ func (ch *Clickhouse) CreateTable(ctx context.Context, tableName string, columns clusterClause := "" engine := "ReplacingMergeTree" engineOptions := "" - cluster := warehouseutils.GetConfigValue(Cluster, ch.Warehouse) + cluster := warehouseutils.GetConfigValue(cluster, ch.Warehouse) if len(strings.TrimSpace(cluster)) > 0 { clusterClause = fmt.Sprintf(`ON CLUSTER %q`, cluster) engine = fmt.Sprintf(`%s%s`, "Replicated", engine) @@ -858,35 +849,23 @@ func (ch *Clickhouse) CreateTable(ctx context.Context, tableName string, columns } func (ch *Clickhouse) DropTable(ctx context.Context, tableName string) (err error) { - cluster := warehouseutils.GetConfigValue(Cluster, ch.Warehouse) - clusterClause := "" - if len(strings.TrimSpace(cluster)) > 0 { - clusterClause = fmt.Sprintf(`ON CLUSTER %q`, cluster) - } - sqlStatement := fmt.Sprintf(`DROP TABLE %q.%q %s `, ch.Warehouse.Namespace, tableName, clusterClause) + sqlStatement := fmt.Sprintf(`DROP TABLE %q.%q %s `, ch.Warehouse.Namespace, tableName, ch.clusterClause()) _, err = ch.DB.ExecContext(ctx, sqlStatement) return } func (ch *Clickhouse) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { var ( - query string - queryBuilder strings.Builder - cluster string - clusterClause string + query string + queryBuilder strings.Builder ) - cluster = warehouseutils.GetConfigValue(Cluster, ch.Warehouse) - if len(strings.TrimSpace(cluster)) > 0 { - clusterClause = fmt.Sprintf(`ON CLUSTER %q`, cluster) - } - queryBuilder.WriteString(fmt.Sprintf(` ALTER TABLE %q.%q %s`, ch.Namespace, tableName, - clusterClause, + ch.clusterClause(), )) for _, columnInfo := range columnsInfo { @@ -907,12 +886,48 @@ func (ch *Clickhouse) AddColumns(ctx context.Context, tableName string, columnsI return } -func (ch *Clickhouse) CreateSchema(ctx context.Context) (err error) { +func (ch *Clickhouse) CreateSchema(ctx context.Context) error { if len(ch.Uploader.GetSchemaInWarehouse()) > 0 { return nil } - err = ch.createSchema(ctx) - return err + + if schemaExists, err := ch.schemaExists(ctx, ch.Namespace); err != nil { + return fmt.Errorf("checking if database: %s exists: %v", ch.Namespace, err) + } else if schemaExists { + return nil + } + + db, err := ch.connectToClickhouse(false) + if err != nil { + return fmt.Errorf("connecting to clickhouse: %v", err) + } + defer func() { _ = db.Close() }() + + ch.logger.Infow("Creating schema", append(ch.defaultLogFields(), "clusterClause", ch.clusterClause())) + + query := fmt.Sprintf(`CREATE DATABASE IF NOT EXISTS %q %s`, ch.Namespace, ch.clusterClause()) + if _, err = db.ExecContext(ctx, query); err != nil { + return fmt.Errorf("creating database: %v", err) + } + return nil +} + +func (ch *Clickhouse) clusterClause() string { + if cluster := warehouseutils.GetConfigValue(cluster, ch.Warehouse); len(strings.TrimSpace(cluster)) > 0 { + return fmt.Sprintf(`ON CLUSTER %q`, cluster) + } + return "" +} + +func (ch *Clickhouse) defaultLogFields() []any { + return []any{ + logfield.SourceID, ch.Warehouse.Source.ID, + 
logfield.SourceType, ch.Warehouse.Source.SourceDefinition.Name, + logfield.DestinationID, ch.Warehouse.Destination.ID, + logfield.DestinationType, ch.Warehouse.Destination.DestinationDefinition.Name, + logfield.WorkspaceID, ch.Warehouse.WorkspaceID, + logfield.Namespace, ch.Namespace, + } } func (*Clickhouse) AlterColumn(context.Context, string, string, string) (model.AlterTableResponse, error) { @@ -939,7 +954,9 @@ func (ch *Clickhouse) Setup(_ context.Context, warehouse model.Warehouse, upload ch.ObjectStorage = warehouseutils.ObjectStorageType(warehouseutils.CLICKHOUSE, warehouse.Destination.Config, ch.Uploader.UseRudderStorage()) ch.LoadFileDownloader = downloader.NewDownloader(&warehouse, uploader, ch.config.numWorkersDownloadLoadFiles) - ch.DB, err = ch.connectToClickhouse(ch.getConnectionCredentials(), true) + if ch.DB, err = ch.connectToClickhouse(true); err != nil { + return fmt.Errorf("connecting to clickhouse: %w", err) + } return err } @@ -1072,12 +1089,13 @@ func (ch *Clickhouse) Connect(_ context.Context, warehouse model.Warehouse) (cli warehouse.Destination.Config, misc.IsConfiguredToUseRudderObjectStorage(ch.Warehouse.Destination.Config), ) - dbHandle, err := ch.connectToClickhouse(ch.getConnectionCredentials(), true) + + db, err := ch.connectToClickhouse(true) if err != nil { - return client.Client{}, err + return client.Client{}, fmt.Errorf("connecting to clickhouse: %w", err) } - return client.Client{Type: client.SQLClient, SQL: dbHandle}, err + return client.Client{Type: client.SQLClient, SQL: db.DB}, err } func (ch *Clickhouse) GetLogIdentifier(args ...string) string { diff --git a/warehouse/integrations/manager/manager.go b/warehouse/integrations/manager/manager.go index c3dd7383294..fc3016d55a1 100644 --- a/warehouse/integrations/manager/manager.go +++ b/warehouse/integrations/manager/manager.go @@ -73,9 +73,9 @@ func New(destType string, conf *config.Config, logger logger.Logger, stats stats case warehouseutils.CLICKHOUSE: return clickhouse.New(conf, logger, stats), nil case warehouseutils.MSSQL: - return mssql.New(conf, logger), nil + return mssql.New(conf, logger, stats), nil case warehouseutils.AzureSynapse: - return azuresynapse.New(conf, logger), nil + return azuresynapse.New(conf, logger, stats), nil case warehouseutils.S3Datalake, warehouseutils.GCSDatalake, warehouseutils.AzureDatalake: return datalake.New(logger), nil case warehouseutils.DELTALAKE: @@ -98,9 +98,9 @@ func NewWarehouseOperations(destType string, conf *config.Config, logger logger. 
case warehouseutils.CLICKHOUSE: return clickhouse.New(conf, logger, stats), nil case warehouseutils.MSSQL: - return mssql.New(conf, logger), nil + return mssql.New(conf, logger, stats), nil case warehouseutils.AzureSynapse: - return azuresynapse.New(conf, logger), nil + return azuresynapse.New(conf, logger, stats), nil case warehouseutils.S3Datalake, warehouseutils.GCSDatalake, warehouseutils.AzureDatalake: return datalake.New(logger), nil case warehouseutils.DELTALAKE: diff --git a/warehouse/integrations/mssql/mssql.go b/warehouse/integrations/mssql/mssql.go index 4553bb16538..6f295137124 100644 --- a/warehouse/integrations/mssql/mssql.go +++ b/warehouse/integrations/mssql/mssql.go @@ -18,6 +18,10 @@ import ( "unicode/utf16" "unicode/utf8" + "github.com/rudderlabs/rudder-go-kit/stats" + sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + "github.com/rudderlabs/rudder-server/warehouse/logfield" + "github.com/rudderlabs/rudder-server/warehouse/internal/service/loadfiles/downloader" "github.com/rudderlabs/rudder-server/warehouse/internal/model" @@ -81,18 +85,21 @@ var mssqlDataTypesMapToRudder = map[string]string{ } type MSSQL struct { - DB *sql.DB + DB *sqlmw.DB Namespace string ObjectStorage string Warehouse model.Warehouse Uploader warehouseutils.Uploader connectTimeout time.Duration - logger logger.Logger LoadFileDownLoader downloader.Downloader + stats stats.Stats + logger logger.Logger + config struct { enableDeleteByJobs bool numWorkersDownloadLoadFiles int + slowQueryThreshold time.Duration } } @@ -125,38 +132,39 @@ var errorsMappings = []model.JobError{ }, } -func New(conf *config.Config, log logger.Logger) *MSSQL { - ms := &MSSQL{} - - ms.logger = log.Child("integrations").Child("mssql") - +func New(conf *config.Config, log logger.Logger, stats stats.Stats) *MSSQL { + ms := &MSSQL{ + stats: stats, + logger: log.Child("integrations").Child("mssql"), + } ms.config.enableDeleteByJobs = conf.GetBool("Warehouse.mssql.enableDeleteByJobs", false) ms.config.numWorkersDownloadLoadFiles = conf.GetInt("Warehouse.mssql.numWorkersDownloadLoadFiles", 1) + ms.config.slowQueryThreshold = conf.GetDuration("Warehouse.mssql.slowQueryThreshold", 5, time.Minute) return ms } -func Connect(cred credentials) (*sql.DB, error) { - // Create connection string - // url := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%s;database=%s;encrypt=%s;TrustServerCertificate=true", cred.host, cred.user, cred.password, cred.port, cred.dbName, cred.sslMode) - // Encryption options : disable, false, true. https://github.com/denisenkom/go-mssqldb - // TrustServerCertificate=true ; all options(disable, false, true) work with this - // if rds.forcessl=1; disable option doesn't work. true, false works alongside TrustServerCertificate=true - // https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/SQLServer.Concepts.General.SSL.Using.html - // more combination explanations here: https://docs.microsoft.com/en-us/sql/connect/odbc/linux-mac/connection-string-keywords-and-data-source-names-dsns?view=sql-server-ver15 +// connect to mssql database +// if TrustServerCertificate is set to true, all options(disable, false, true) works. +// if forceSSL is set to 1, disable option doesn't work. 
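The constructor above reads the same slow-query knob that azure-synapse and clickhouse gain elsewhere in this diff. A small helper listing the three keys and their shared five-minute default; the helper itself is illustrative, and the config import path follows the repo's rudder-go-kit convention rather than being shown in these hunks.

```go
package example

import (
	"time"

	"github.com/rudderlabs/rudder-go-kit/config"
)

// slowQueryThresholds resolves the per-integration thresholds. The keys and the
// shared 5-minute default come from this diff; the helper is not part of the change.
func slowQueryThresholds(conf *config.Config) map[string]time.Duration {
	return map[string]time.Duration{
		"Warehouse.mssql.slowQueryThreshold":         conf.GetDuration("Warehouse.mssql.slowQueryThreshold", 5, time.Minute),
		"Warehouse.azure_synapse.slowQueryThreshold": conf.GetDuration("Warehouse.azure_synapse.slowQueryThreshold", 5, time.Minute),
		"Warehouse.clickhouse.slowQueryThreshold":    conf.GetDuration("Warehouse.clickhouse.slowQueryThreshold", 5, time.Minute),
	}
}
```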
+// If forceSSL is set to true or false, it works alongside with TrustServerCertificate=true +// more about combinations in here: https://docs.microsoft.com/en-us/sql/connect/odbc/linux-mac/connection-string-keywords-and-data-source-names-dsns?view=sql-server-ver15 +func (ms *MSSQL) connect() (*sqlmw.DB, error) { + cred := ms.connectionCredentials() + + port, err := strconv.Atoi(cred.port) + if err != nil { + return nil, fmt.Errorf("invalid port: %w", err) + } + query := url.Values{} query.Add("database", cred.database) query.Add("encrypt", cred.sslMode) - + query.Add("TrustServerCertificate", "true") if cred.timeout > 0 { query.Add("dial timeout", fmt.Sprintf("%d", cred.timeout/time.Second)) } - query.Add("TrustServerCertificate", "true") - port, err := strconv.Atoi(cred.port) - if err != nil { - return nil, fmt.Errorf("invalid port: %w", err) - } connUrl := &url.URL{ Scheme: "sqlserver", User: url.UserPassword(cred.user, cred.password), @@ -164,17 +172,30 @@ func Connect(cred credentials) (*sql.DB, error) { RawQuery: query.Encode(), } - var db *sql.DB - - if db, err = sql.Open("sqlserver", connUrl.String()); err != nil { - return nil, fmt.Errorf("opening connection to mssql server: %w", err) - } - - return db, nil + db, err := sql.Open("sqlserver", connUrl.String()) + if err != nil { + return nil, fmt.Errorf("opening connection: %w", err) + } + + middleware := sqlmw.New( + db, + sqlmw.WithStats(ms.stats), + sqlmw.WithLogger(ms.logger), + sqlmw.WithKeyAndValues( + logfield.SourceID, ms.Warehouse.Source.ID, + logfield.SourceType, ms.Warehouse.Source.SourceDefinition.Name, + logfield.DestinationID, ms.Warehouse.Destination.ID, + logfield.DestinationType, ms.Warehouse.Destination.DestinationDefinition.Name, + logfield.WorkspaceID, ms.Warehouse.WorkspaceID, + logfield.Schema, ms.Namespace, + ), + sqlmw.WithSlowQueryThreshold(ms.config.slowQueryThreshold), + ) + return middleware, nil } -func (ms *MSSQL) getConnectionCredentials() credentials { - creds := credentials{ +func (ms *MSSQL) connectionCredentials() *credentials { + return &credentials{ host: warehouseutils.GetConfigValue(host, ms.Warehouse), database: warehouseutils.GetConfigValue(dbName, ms.Warehouse), user: warehouseutils.GetConfigValue(user, ms.Warehouse), @@ -183,8 +204,6 @@ func (ms *MSSQL) getConnectionCredentials() credentials { sslMode: warehouseutils.GetConfigValue(sslMode, ms.Warehouse), timeout: ms.connectTimeout, } - - return creds } func ColumnsWithDataTypes(columns model.TableSchema, prefix string) string { @@ -702,7 +721,9 @@ func (ms *MSSQL) Setup(_ context.Context, warehouse model.Warehouse, uploader wa ms.ObjectStorage = warehouseutils.ObjectStorageType(warehouseutils.MSSQL, warehouse.Destination.Config, ms.Uploader.UseRudderStorage()) ms.LoadFileDownLoader = downloader.NewDownloader(&warehouse, uploader, ms.config.numWorkersDownloadLoadFiles) - ms.DB, err = Connect(ms.getConnectionCredentials()) + if ms.DB, err = ms.connect(); err != nil { + return fmt.Errorf("connecting to mssql: %w", err) + } return err } @@ -863,12 +884,13 @@ func (ms *MSSQL) Connect(_ context.Context, warehouse model.Warehouse) (client.C warehouse.Destination.Config, misc.IsConfiguredToUseRudderObjectStorage(ms.Warehouse.Destination.Config), ) - dbHandle, err := Connect(ms.getConnectionCredentials()) + + db, err := ms.connect() if err != nil { - return client.Client{}, err + return client.Client{}, fmt.Errorf("connecting to mssql: %w", err) } - return client.Client{Type: client.SQLClient, SQL: dbHandle}, err + return client.Client{Type: 
client.SQLClient, SQL: db.DB}, err } func (ms *MSSQL) LoadTestTable(ctx context.Context, _, tableName string, payloadMap map[string]interface{}, _ string) (err error) {
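Putting the manager and integration changes together, here is a hedged sketch of the updated construction path. manager.New's parameters, the destination-type constant, and the integration Setup signature are taken from the hunks above; the helper, its argument names, the assumption that the value returned by manager.New exposes that Setup method, and the import paths for pre-existing packages not shown in this diff are all illustrative.

```go
package example

import (
	"context"
	"fmt"

	"github.com/rudderlabs/rudder-go-kit/config"
	"github.com/rudderlabs/rudder-go-kit/logger"
	"github.com/rudderlabs/rudder-go-kit/stats"

	"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
	"github.com/rudderlabs/rudder-server/warehouse/internal/model"
	warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils"
)

// setupSynapse sketches the construction path after this change: the manager
// threads the stats handle through to azuresynapse.New, and Setup opens the
// middleware-wrapped DB handle used for subsequent queries.
func setupSynapse(ctx context.Context, conf *config.Config, log logger.Logger, statsFactory stats.Stats, warehouse model.Warehouse, uploader warehouseutils.Uploader) error {
	whManager, err := manager.New(warehouseutils.AzureSynapse, conf, log, statsFactory)
	if err != nil {
		return fmt.Errorf("creating warehouse manager: %w", err)
	}
	// Assumed: the manager value exposes Setup with the same signature as the
	// integration Setup methods shown in the hunks above.
	if err := whManager.Setup(ctx, warehouse, uploader); err != nil {
		return fmt.Errorf("setting up azure synapse: %w", err)
	}
	return nil
}
```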