From cab70461b4b5aaa03bc512aff333900e191819a2 Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Thu, 23 Feb 2023 12:48:51 +0530 Subject: [PATCH 1/7] added role support for snowflake --- warehouse/integrations/manager/manager.go | 10 +- warehouse/integrations/snowflake/snowflake.go | 300 +++++++++--------- .../integrations/snowflake/snowflake_test.go | 8 +- warehouse/integrations/testhelper/setup.go | 2 +- 4 files changed, 166 insertions(+), 154 deletions(-) diff --git a/warehouse/integrations/manager/manager.go b/warehouse/integrations/manager/manager.go index c0fd2334dc..9e53b46c15 100644 --- a/warehouse/integrations/manager/manager.go +++ b/warehouse/integrations/manager/manager.go @@ -67,8 +67,9 @@ func New(destType string) (Manager, error) { var bq bigquery.HandleT return &bq, nil case warehouseutils.SNOWFLAKE: - var sf snowflake.HandleT - return &sf, nil + sf := snowflake.NewSnowflake() + snowflake.WithConfig(sf, config.Default) + return sf, nil case warehouseutils.POSTGRES: pg := postgres.NewHandle() postgres.WithConfig(pg, config.Default) @@ -105,8 +106,9 @@ func NewWarehouseOperations(destType string) (WarehouseOperations, error) { var bq bigquery.HandleT return &bq, nil case warehouseutils.SNOWFLAKE: - var sf snowflake.HandleT - return &sf, nil + sf := snowflake.NewSnowflake() + snowflake.WithConfig(sf, config.Default) + return sf, nil case warehouseutils.POSTGRES: pg := postgres.NewHandle() postgres.WithConfig(pg, config.Default) diff --git a/warehouse/integrations/snowflake/snowflake.go b/warehouse/integrations/snowflake/snowflake.go index 3e445b9b93..bdf509eec2 100644 --- a/warehouse/integrations/snowflake/snowflake.go +++ b/warehouse/integrations/snowflake/snowflake.go @@ -26,40 +26,20 @@ const ( tableNameLimit = 127 ) -var ( - pkgLogger logger.Logger - enableDeleteByJobs bool -) - -func Init() { - loadConfig() - pkgLogger = logger.NewLogger().Child("warehouse").Child("snowflake") -} - -func loadConfig() { - config.RegisterBoolConfigVariable(false, &enableDeleteByJobs, true, "Warehouse.snowflake.enableDeleteByJobs") -} - -type HandleT struct { - DB *sql.DB - Namespace string - CloudProvider string - ObjectStorage string - Warehouse warehouseutils.Warehouse - Uploader warehouseutils.UploaderI - ConnectTimeout time.Duration -} - // String constants for snowflake destination config const ( StorageIntegration = "storageIntegration" - SFAccount = "account" - SFWarehouse = "warehouse" - SFDbName = "database" - SFUserName = "user" - SFPassword = "password" + Account = "account" + Warehouse = "warehouse" + Database = "database" + User = "user" + Role = "role" + Password = "password" + Application = "Rudderstack" ) +var pkgLogger logger.Logger + var dataTypesMap = map[string]string{ "boolean": "boolean", "int": "number", @@ -177,11 +157,53 @@ var errorsMappings = []model.JobError{ }, } -type tableLoadRespT struct { +type Credentials struct { + Account string + Warehouse string + Database string + User string + Role string + Password string + schemaName string + timeout time.Duration +} + +type tableLoadResp struct { dbHandle *sql.DB stagingTable string } +type optionalCreds struct { + schemaName string +} + +type Snowflake struct { + DB *sql.DB + Namespace string + CloudProvider string + ObjectStorage string + Warehouse warehouseutils.Warehouse + Uploader warehouseutils.UploaderI + ConnectTimeout time.Duration + Logger logger.Logger + + EnableDeleteByJobs bool +} + +func Init() { + pkgLogger = logger.NewLogger().Child("warehouse").Child("snowflake") +} + +func NewSnowflake() *Snowflake 
{ + return &Snowflake{ + Logger: pkgLogger, + } +} + +func WithConfig(h *Snowflake, config *config.Config) { + h.EnableDeleteByJobs = config.GetBool("Warehouse.snowflake.enableDeleteByJobs", false) +} + func ColumnsWithDataTypes(columns map[string]string, prefix string) string { var arr []string for name, dataType := range columns { @@ -191,21 +213,21 @@ func ColumnsWithDataTypes(columns map[string]string, prefix string) string { } // schemaIdentifier returns [DATABASE_NAME].[NAMESPACE] format to access the schema directly. -func (sf *HandleT) schemaIdentifier() string { +func (sf *Snowflake) schemaIdentifier() string { return fmt.Sprintf(`"%s"`, sf.Namespace, ) } -func (sf *HandleT) createTable(tableName string, columns map[string]string) (err error) { +func (sf *Snowflake) createTable(tableName string, columns map[string]string) (err error) { schemaIdentifier := sf.schemaIdentifier() sqlStatement := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s."%s" ( %v )`, schemaIdentifier, tableName, ColumnsWithDataTypes(columns, "")) - pkgLogger.Infof("Creating table in snowflake for SF:%s : %v", sf.Warehouse.Destination.ID, sqlStatement) + sf.Logger.Infof("Creating table in snowflake for SF:%s : %v", sf.Warehouse.Destination.ID, sqlStatement) _, err = sf.DB.Exec(sqlStatement) return } -func (sf *HandleT) tableExists(tableName string) (exists bool, err error) { +func (sf *Snowflake) tableExists(tableName string) (exists bool, err error) { sqlStatement := fmt.Sprintf(`SELECT EXISTS ( SELECT 1 FROM information_schema.tables WHERE table_schema = '%s' @@ -215,7 +237,7 @@ func (sf *HandleT) tableExists(tableName string) (exists bool, err error) { return } -func (sf *HandleT) columnExists(columnName, tableName string) (exists bool, err error) { +func (sf *Snowflake) columnExists(columnName, tableName string) (exists bool, err error) { sqlStatement := fmt.Sprintf(`SELECT EXISTS ( SELECT 1 FROM information_schema.columns WHERE table_schema = '%s' @@ -226,7 +248,7 @@ func (sf *HandleT) columnExists(columnName, tableName string) (exists bool, err return } -func (sf *HandleT) schemaExists() (exists bool, err error) { +func (sf *Snowflake) schemaExists() (exists bool, err error) { sqlStatement := fmt.Sprintf("SELECT EXISTS ( SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = '%s' )", sf.Namespace) r := sf.DB.QueryRow(sqlStatement) err = r.Scan(&exists) @@ -237,10 +259,10 @@ func (sf *HandleT) schemaExists() (exists bool, err error) { return } -func (sf *HandleT) createSchema() (err error) { +func (sf *Snowflake) createSchema() (err error) { schemaIdentifier := sf.schemaIdentifier() sqlStatement := fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s`, schemaIdentifier) - pkgLogger.Infof("SF: Creating schema name in snowflake for %s:%s : %v", sf.Warehouse.Namespace, sf.Warehouse.Destination.ID, sqlStatement) + sf.Logger.Infof("SF: Creating schema name in snowflake for %s:%s : %v", sf.Warehouse.Namespace, sf.Warehouse.Destination.ID, sqlStatement) _, err = sf.DB.Exec(sqlStatement) return } @@ -258,7 +280,7 @@ func checkAndIgnoreAlreadyExistError(err error) bool { return true } -func (sf *HandleT) authString() string { +func (sf *Snowflake) authString() string { var auth string if misc.IsConfiguredToUseRudderObjectStorage(sf.Warehouse.Destination.Config) || (sf.CloudProvider == "AWS" && warehouseutils.GetConfigValue(StorageIntegration, sf.Warehouse) == "") { tempAccessKeyId, tempSecretAccessKey, token, _ := warehouseutils.GetTemporaryS3Cred(&sf.Warehouse.Destination) @@ -269,8 +291,8 @@ func (sf *HandleT) 
authString() string { return auth } -func (sf *HandleT) DeleteBy(tableNames []string, params warehouseutils.DeleteByParams) (err error) { - pkgLogger.Infof("SF: Cleaning up the following tables in snowflake for SF:%s : %v", tableNames) +func (sf *Snowflake) DeleteBy(tableNames []string, params warehouseutils.DeleteByParams) (err error) { + sf.Logger.Infof("SF: Cleaning up the following tables in snowflake for SF:%s : %v", tableNames) for _, tb := range tableNames { sqlStatement := fmt.Sprintf(` DELETE FROM @@ -289,13 +311,13 @@ func (sf *HandleT) DeleteBy(tableNames []string, params warehouseutils.DeleteByP params.StartTime, ) - pkgLogger.Infof("SF: Deleting rows in table in snowflake for SF:%s", sf.Warehouse.Destination.ID) - pkgLogger.Debugf("SF: Executing the sql statement %v", sqlStatement) + sf.Logger.Infof("SF: Deleting rows in table in snowflake for SF:%s", sf.Warehouse.Destination.ID) + sf.Logger.Debugf("SF: Executing the sql statement %v", sqlStatement) - if enableDeleteByJobs { + if sf.EnableDeleteByJobs { _, err = sf.DB.Exec(sqlStatement) if err != nil { - pkgLogger.Errorf("Error %s", err) + sf.Logger.Errorf("Error %s", err) return err } } @@ -303,13 +325,13 @@ func (sf *HandleT) DeleteBy(tableNames []string, params warehouseutils.DeleteByP return nil } -func (sf *HandleT) loadTable(tableName string, tableSchemaInUpload warehouseutils.TableSchemaT, dbHandle *sql.DB, skipClosingDBSession bool) (tableLoadResp tableLoadRespT, err error) { - pkgLogger.Infof("SF: Starting load for table:%s\n", tableName) +func (sf *Snowflake) loadTable(tableName string, tableSchemaInUpload warehouseutils.TableSchemaT, dbHandle *sql.DB, skipClosingDBSession bool) (tableLoadResp tableLoadResp, err error) { + sf.Logger.Infof("SF: Starting load for table:%s\n", tableName) if dbHandle == nil { - dbHandle, err = Connect(sf.getConnectionCredentials(OptionalCredsT{schemaName: sf.Namespace})) + dbHandle, err = Connect(sf.getConnectionCredentials(optionalCreds{schemaName: sf.Namespace})) if err != nil { - pkgLogger.Errorf("SF: Error establishing connection for copying table:%s: %v\n", tableName, err) + sf.Logger.Errorf("SF: Error establishing connection for copying table:%s: %v\n", tableName, err) return } } @@ -328,10 +350,10 @@ func (sf *HandleT) loadTable(tableName string, tableSchemaInUpload warehouseutil stagingTableName := warehouseutils.StagingTableName(provider, tableName, tableNameLimit) sqlStatement := fmt.Sprintf(`CREATE TEMPORARY TABLE %[1]s."%[2]s" LIKE %[1]s."%[3]s"`, schemaIdentifier, stagingTableName, tableName) - pkgLogger.Debugf("SF: Creating temporary table for table:%s at %s\n", tableName, sqlStatement) + sf.Logger.Debugf("SF: Creating temporary table for table:%s at %s\n", tableName, sqlStatement) _, err = dbHandle.Exec(sqlStatement) if err != nil { - pkgLogger.Errorf("SF: Error creating temporary table for table:%s: %v\n", tableName, err) + sf.Logger.Errorf("SF: Error creating temporary table for table:%s: %v\n", tableName, err) return } tableLoadResp.stagingTable = stagingTableName @@ -352,12 +374,12 @@ func (sf *HandleT) loadTable(tableName string, tableSchemaInUpload warehouseutil "AWS_TOKEN='[^']*'": "AWS_TOKEN='***'", }) if regexErr == nil { - pkgLogger.Infof("SF: Running COPY command for table:%s at %s\n", tableName, sanitisedSQLStmt) + sf.Logger.Infof("SF: Running COPY command for table:%s at %s\n", tableName, sanitisedSQLStmt) } _, err = dbHandle.Exec(sqlStatement) if err != nil { - pkgLogger.Errorf("SF: Error running COPY command: %v\n", err) + sf.Logger.Errorf("SF: Error running 
COPY command: %v\n", err) return } @@ -409,30 +431,30 @@ func (sf *HandleT) loadTable(tableName string, tableSchemaInUpload warehouseutil INSERT (%[4]s) VALUES (%[5]s)`, tableName, stagingTableName, primaryKey, sortedColumnNames, stagingColumnNames, additionalJoinClause, partitionKey, schemaIdentifier) } - pkgLogger.Infof("SF: Dedup records for table:%s using staging table: %s\n", tableName, sqlStatement) + sf.Logger.Infof("SF: Dedup records for table:%s using staging table: %s\n", tableName, sqlStatement) _, err = dbHandle.Exec(sqlStatement) if err != nil { - pkgLogger.Errorf("SF: Error running MERGE for dedup: %v\n", err) + sf.Logger.Errorf("SF: Error running MERGE for dedup: %v\n", err) return } - pkgLogger.Infof("SF: Complete load for table:%s\n", tableName) + sf.Logger.Infof("SF: Complete load for table:%s\n", tableName) return } -func (sf *HandleT) LoadIdentityMergeRulesTable() (err error) { - pkgLogger.Infof("SF: Starting load for table:%s\n", identityMergeRulesTable) +func (sf *Snowflake) LoadIdentityMergeRulesTable() (err error) { + sf.Logger.Infof("SF: Starting load for table:%s\n", identityMergeRulesTable) - pkgLogger.Infof("SF: Fetching load file location for %s", identityMergeRulesTable) + sf.Logger.Infof("SF: Fetching load file location for %s", identityMergeRulesTable) var loadFile warehouseutils.LoadFileT loadFile, err = sf.Uploader.GetSingleLoadFile(identityMergeRulesTable) if err != nil { return err } - dbHandle, err := Connect(sf.getConnectionCredentials(OptionalCredsT{schemaName: sf.Namespace})) + dbHandle, err := Connect(sf.getConnectionCredentials(optionalCreds{schemaName: sf.Namespace})) if err != nil { - pkgLogger.Errorf("SF: Error establishing connection for copying table:%s: %v\n", identityMergeRulesTable, err) + sf.Logger.Errorf("SF: Error establishing connection for copying table:%s: %v\n", identityMergeRulesTable, err) return } @@ -448,21 +470,21 @@ func (sf *HandleT) LoadIdentityMergeRulesTable() (err error) { "AWS_TOKEN='[^']*'": "AWS_TOKEN='***'", }) if regexErr == nil { - pkgLogger.Infof("SF: Dedup records for table:%s using staging table: %s\n", identityMergeRulesTable, sanitisedSQLStmt) + sf.Logger.Infof("SF: Dedup records for table:%s using staging table: %s\n", identityMergeRulesTable, sanitisedSQLStmt) } _, err = dbHandle.Exec(sqlStatement) if err != nil { - pkgLogger.Errorf("SF: Error running MERGE for dedup: %v\n", err) + sf.Logger.Errorf("SF: Error running MERGE for dedup: %v\n", err) return } - pkgLogger.Infof("SF: Complete load for table:%s\n", identityMergeRulesTable) + sf.Logger.Infof("SF: Complete load for table:%s\n", identityMergeRulesTable) return } -func (sf *HandleT) LoadIdentityMappingsTable() (err error) { - pkgLogger.Infof("SF: Starting load for table:%s\n", identityMappingsTable) - pkgLogger.Infof("SF: Fetching load file location for %s", identityMappingsTable) +func (sf *Snowflake) LoadIdentityMappingsTable() (err error) { + sf.Logger.Infof("SF: Starting load for table:%s\n", identityMappingsTable) + sf.Logger.Infof("SF: Fetching load file location for %s", identityMappingsTable) var loadFile warehouseutils.LoadFileT loadFile, err = sf.Uploader.GetSingleLoadFile(identityMappingsTable) @@ -470,9 +492,9 @@ func (sf *HandleT) LoadIdentityMappingsTable() (err error) { return err } - dbHandle, err := Connect(sf.getConnectionCredentials(OptionalCredsT{schemaName: sf.Namespace})) + dbHandle, err := Connect(sf.getConnectionCredentials(optionalCreds{schemaName: sf.Namespace})) if err != nil { - pkgLogger.Errorf("SF: Error establishing 
connection for copying table:%s: %v\n", identityMappingsTable, err) + sf.Logger.Errorf("SF: Error establishing connection for copying table:%s: %v\n", identityMappingsTable, err) return } @@ -480,18 +502,18 @@ func (sf *HandleT) LoadIdentityMappingsTable() (err error) { stagingTableName := warehouseutils.StagingTableName(provider, identityMappingsTable, tableNameLimit) sqlStatement := fmt.Sprintf(`CREATE TEMPORARY TABLE %[1]s."%[2]s" LIKE %[1]s."%[3]s"`, schemaIdentifier, stagingTableName, identityMappingsTable) - pkgLogger.Infof("SF: Creating temporary table for table:%s at %s\n", identityMappingsTable, sqlStatement) + sf.Logger.Infof("SF: Creating temporary table for table:%s at %s\n", identityMappingsTable, sqlStatement) _, err = dbHandle.Exec(sqlStatement) if err != nil { - pkgLogger.Errorf("SF: Error creating temporary table for table:%s: %v\n", identityMappingsTable, err) + sf.Logger.Errorf("SF: Error creating temporary table for table:%s: %v\n", identityMappingsTable, err) return } sqlStatement = fmt.Sprintf(`ALTER TABLE %s."%s" ADD COLUMN "ID" int AUTOINCREMENT start 1 increment 1`, schemaIdentifier, stagingTableName) - pkgLogger.Infof("SF: Adding autoincrement column for table:%s at %s\n", stagingTableName, sqlStatement) + sf.Logger.Infof("SF: Adding autoincrement column for table:%s at %s\n", stagingTableName, sqlStatement) _, err = dbHandle.Exec(sqlStatement) if err != nil && !checkAndIgnoreAlreadyExistError(err) { - pkgLogger.Errorf("SF: Error adding autoincrement column for table:%s: %v\n", stagingTableName, err) + sf.Logger.Errorf("SF: Error adding autoincrement column for table:%s: %v\n", stagingTableName, err) return } @@ -499,10 +521,10 @@ func (sf *HandleT) LoadIdentityMappingsTable() (err error) { sqlStatement = fmt.Sprintf(`COPY INTO %v("MERGE_PROPERTY_TYPE", "MERGE_PROPERTY_VALUE", "RUDDER_ID", "UPDATED_AT") FROM '%v' %s PATTERN = '.*\.csv\.gz' FILE_FORMAT = ( TYPE = csv FIELD_OPTIONALLY_ENCLOSED_BY = '"' ESCAPE_UNENCLOSED_FIELD = NONE ) TRUNCATECOLUMNS = TRUE`, fmt.Sprintf(`%s."%s"`, schemaIdentifier, stagingTableName), loadLocation, sf.authString()) - pkgLogger.Infof("SF: Dedup records for table:%s using staging table: %s\n", identityMappingsTable, sqlStatement) + sf.Logger.Infof("SF: Dedup records for table:%s using staging table: %s\n", identityMappingsTable, sqlStatement) _, err = dbHandle.Exec(sqlStatement) if err != nil { - pkgLogger.Errorf("SF: Error running MERGE for dedup: %v\n", err) + sf.Logger.Errorf("SF: Error running MERGE for dedup: %v\n", err) return } @@ -517,23 +539,23 @@ func (sf *HandleT) LoadIdentityMappingsTable() (err error) { UPDATE SET original."RUDDER_ID" = staging."RUDDER_ID", original."UPDATED_AT" = staging."UPDATED_AT" WHEN NOT MATCHED THEN INSERT ("MERGE_PROPERTY_TYPE", "MERGE_PROPERTY_VALUE", "RUDDER_ID", "UPDATED_AT") VALUES (staging."MERGE_PROPERTY_TYPE", staging."MERGE_PROPERTY_VALUE", staging."RUDDER_ID", staging."UPDATED_AT")`, identityMappingsTable, stagingTableName, schemaIdentifier) - pkgLogger.Infof("SF: Dedup records for table:%s using staging table: %s\n", identityMappingsTable, sqlStatement) + sf.Logger.Infof("SF: Dedup records for table:%s using staging table: %s\n", identityMappingsTable, sqlStatement) _, err = dbHandle.Exec(sqlStatement) if err != nil { - pkgLogger.Errorf("SF: Error running MERGE for dedup: %v\n", err) + sf.Logger.Errorf("SF: Error running MERGE for dedup: %v\n", err) return } - pkgLogger.Infof("SF: Complete load for table:%s\n", identityMappingsTable) + sf.Logger.Infof("SF: Complete load for table:%s\n", 
identityMappingsTable) return } -func (sf *HandleT) loadUserTables() (errorMap map[string]error) { +func (sf *Snowflake) loadUserTables() (errorMap map[string]error) { identifyColMap := sf.Uploader.GetTableSchemaInUpload(identifiesTable) if len(identifyColMap) == 0 { return errorMap } errorMap = map[string]error{identifiesTable: nil} - pkgLogger.Infof("SF: Starting load for identifies and users tables\n") + sf.Logger.Infof("SF: Starting load for identifies and users tables\n") resp, err := sf.loadTable(identifiesTable, sf.Uploader.GetTableSchemaInUpload(identifiesTable), nil, true) if err != nil { @@ -587,10 +609,10 @@ func (sf *HandleT) loadUserTables() (errorMap map[string]error) { strings.Join(userColNames, ","), // 6 strings.Join(identifyColNames, ","), // 7 ) - pkgLogger.Infof("SF: Creating staging table for users: %s\n", sqlStatement) + sf.Logger.Infof("SF: Creating staging table for users: %s\n", sqlStatement) _, err = resp.dbHandle.Exec(sqlStatement) if err != nil { - pkgLogger.Errorf("SF: Error creating temporary table for table:%s: %v\n", usersTable, err) + sf.Logger.Errorf("SF: Error creating temporary table for table:%s: %v\n", usersTable, err) errorMap[usersTable] = err return errorMap } @@ -617,36 +639,27 @@ func (sf *HandleT) loadUserTables() (errorMap map[string]error) { UPDATE SET %[5]s WHEN NOT MATCHED THEN INSERT (%[3]s) VALUES (%[6]s)`, usersTable, stagingTableName, columnNamesStr, primaryKey, columnsWithValues, stagingColumnValues, schemaIdentifier) - pkgLogger.Infof("SF: Dedup records for table:%s using staging table: %s\n", usersTable, sqlStatement) + sf.Logger.Infof("SF: Dedup records for table:%s using staging table: %s\n", usersTable, sqlStatement) _, err = resp.dbHandle.Exec(sqlStatement) if err != nil { - pkgLogger.Errorf("SF: Error running MERGE for dedup: %v\n", err) + sf.Logger.Errorf("SF: Error running MERGE for dedup: %v\n", err) errorMap[usersTable] = err return errorMap } - pkgLogger.Infof("SF: Complete load for table:%s", usersTable) + sf.Logger.Infof("SF: Complete load for table:%s", usersTable) return errorMap } -type SnowflakeCredentialsT struct { - Account string - WHName string - DBName string - Username string - Password string - schemaName string - timeout time.Duration -} - -func Connect(cred SnowflakeCredentialsT) (*sql.DB, error) { +func Connect(cred Credentials) (*sql.DB, error) { urlConfig := snowflake.Config{ Account: cred.Account, - User: cred.Username, + User: cred.User, + Role: cred.Role, Password: cred.Password, - Database: cred.DBName, + Database: cred.Database, Schema: cred.schemaName, - Warehouse: cred.WHName, - Application: "Rudderstack", + Warehouse: cred.Warehouse, + Application: Application, } if cred.timeout > 0 { @@ -673,34 +686,34 @@ func Connect(cred SnowflakeCredentialsT) (*sql.DB, error) { return db, nil } -func (sf *HandleT) CreateSchema() (err error) { +func (sf *Snowflake) CreateSchema() (err error) { var schemaExists bool schemaIdentifier := sf.schemaIdentifier() schemaExists, err = sf.schemaExists() if err != nil { - pkgLogger.Errorf("SF: Error checking if schema: %s exists: %v", schemaIdentifier, err) + sf.Logger.Errorf("SF: Error checking if schema: %s exists: %v", schemaIdentifier, err) return err } if schemaExists { - pkgLogger.Infof("SF: Skipping creating schema: %s since it already exists", schemaIdentifier) + sf.Logger.Infof("SF: Skipping creating schema: %s since it already exists", schemaIdentifier) return } return sf.createSchema() } -func (sf *HandleT) CreateTable(tableName string, columnMap 
map[string]string) (err error) { +func (sf *Snowflake) CreateTable(tableName string, columnMap map[string]string) (err error) { return sf.createTable(tableName, columnMap) } -func (sf *HandleT) DropTable(tableName string) (err error) { +func (sf *Snowflake) DropTable(tableName string) (err error) { schemaIdentifier := sf.schemaIdentifier() sqlStatement := fmt.Sprintf(`DROP TABLE %[1]s."%[2]s"`, schemaIdentifier, tableName) - pkgLogger.Infof("SF: Dropping table in snowflake for SF:%s : %v", sf.Warehouse.Destination.ID, sqlStatement) + sf.Logger.Infof("SF: Dropping table in snowflake for SF:%s : %v", sf.Warehouse.Destination.ID, sqlStatement) _, err = sf.DB.Exec(sqlStatement) return } -func (sf *HandleT) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { +func (sf *Snowflake) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { var ( query string queryBuilder strings.Builder @@ -724,14 +737,14 @@ func (sf *HandleT) AddColumns(tableName string, columnsInfo []warehouseutils.Col query = strings.TrimSuffix(queryBuilder.String(), ",") query += ";" - pkgLogger.Infof("SF: Adding columns for destinationID: %s, tableName: %s with query: %v", sf.Warehouse.Destination.ID, tableName, query) + sf.Logger.Infof("SF: Adding columns for destinationID: %s, tableName: %s with query: %v", sf.Warehouse.Destination.ID, tableName, query) _, err = sf.DB.Exec(query) // Handle error in case of single column if len(columnsInfo) == 1 { if err != nil { if checkAndIgnoreAlreadyExistError(err) { - pkgLogger.Infof("SF: Column %s already exists on %s.%s \nResponse: %v", columnsInfo[0].Name, schemaIdentifier, tableName, err) + sf.Logger.Infof("SF: Column %s already exists on %s.%s \nResponse: %v", columnsInfo[0].Name, schemaIdentifier, tableName, err) err = nil } } @@ -739,12 +752,12 @@ func (sf *HandleT) AddColumns(tableName string, columnsInfo []warehouseutils.Col return } -func (*HandleT) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) { +func (*Snowflake) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) { return model.AlterTableResponse{}, nil } // DownloadIdentityRules gets distinct combinations of anonymous_id, user_id from tables in warehouse -func (sf *HandleT) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error) { +func (sf *Snowflake) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error) { getFromTable := func(tableName string) (err error) { var exists bool exists, err = sf.tableExists(tableName) @@ -778,7 +791,7 @@ func (sf *HandleT) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error) } else if hasUserID { toSelectFields = `NULL AS "ANONYMOUS_ID", "USER_ID"` } else { - pkgLogger.Infof("SF: ANONYMOUS_ID, USER_ID columns not present in table: %s", tableName) + sf.Logger.Infof("SF: ANONYMOUS_ID, USER_ID columns not present in table: %s", tableName) return nil } @@ -787,7 +800,7 @@ func (sf *HandleT) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error) for { // TODO: Handle case for missing anonymous_id, user_id columns sqlStatement = fmt.Sprintf(`SELECT DISTINCT %s FROM %s."%s" LIMIT %d OFFSET %d`, toSelectFields, schemaIdentifier, tableName, batchSize, offset) - pkgLogger.Infof("SF: Downloading distinct combinations of anonymous_id, user_id: %s, totalRows: %d", sqlStatement, totalRows) + sf.Logger.Infof("SF: Downloading distinct combinations of anonymous_id, user_id: %s, totalRows: %d", sqlStatement, totalRows) var rows *sql.Rows rows, err = sf.DB.Query(sqlStatement) if err != nil { @@ -838,16 
+851,16 @@ func (sf *HandleT) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error) return nil } -func (*HandleT) CrashRecover(_ warehouseutils.Warehouse) (err error) { +func (*Snowflake) CrashRecover(_ warehouseutils.Warehouse) (err error) { return } -func (sf *HandleT) IsEmpty(warehouse warehouseutils.Warehouse) (empty bool, err error) { +func (sf *Snowflake) IsEmpty(warehouse warehouseutils.Warehouse) (empty bool, err error) { empty = true sf.Warehouse = warehouse sf.Namespace = warehouse.Namespace - sf.DB, err = Connect(sf.getConnectionCredentials(OptionalCredsT{})) + sf.DB, err = Connect(sf.getConnectionCredentials(optionalCreds{})) if err != nil { return } @@ -878,36 +891,33 @@ func (sf *HandleT) IsEmpty(warehouse warehouseutils.Warehouse) (empty bool, err return } -type OptionalCredsT struct { - schemaName string -} - -func (sf *HandleT) getConnectionCredentials(opts OptionalCredsT) SnowflakeCredentialsT { - return SnowflakeCredentialsT{ - Account: warehouseutils.GetConfigValue(SFAccount, sf.Warehouse), - WHName: warehouseutils.GetConfigValue(SFWarehouse, sf.Warehouse), - DBName: warehouseutils.GetConfigValue(SFDbName, sf.Warehouse), - Username: warehouseutils.GetConfigValue(SFUserName, sf.Warehouse), - Password: warehouseutils.GetConfigValue(SFPassword, sf.Warehouse), +func (sf *Snowflake) getConnectionCredentials(opts optionalCreds) Credentials { + return Credentials{ + Account: warehouseutils.GetConfigValue(Account, sf.Warehouse), + Warehouse: warehouseutils.GetConfigValue(Warehouse, sf.Warehouse), + Database: warehouseutils.GetConfigValue(Database, sf.Warehouse), + User: warehouseutils.GetConfigValue(User, sf.Warehouse), + Role: warehouseutils.GetConfigValue(Role, sf.Warehouse), + Password: warehouseutils.GetConfigValue(Password, sf.Warehouse), schemaName: opts.schemaName, timeout: sf.ConnectTimeout, } } -func (sf *HandleT) Setup(warehouse warehouseutils.Warehouse, uploader warehouseutils.UploaderI) (err error) { +func (sf *Snowflake) Setup(warehouse warehouseutils.Warehouse, uploader warehouseutils.UploaderI) (err error) { sf.Warehouse = warehouse sf.Namespace = warehouse.Namespace sf.CloudProvider = warehouseutils.SnowflakeCloudProvider(warehouse.Destination.Config) sf.Uploader = uploader sf.ObjectStorage = warehouseutils.ObjectStorageType(warehouseutils.SNOWFLAKE, warehouse.Destination.Config, sf.Uploader.UseRudderStorage()) - sf.DB, err = Connect(sf.getConnectionCredentials(OptionalCredsT{})) + sf.DB, err = Connect(sf.getConnectionCredentials(optionalCreds{})) return err } -func (sf *HandleT) TestConnection(warehouse warehouseutils.Warehouse) (err error) { +func (sf *Snowflake) TestConnection(warehouse warehouseutils.Warehouse) (err error) { sf.Warehouse = warehouse - sf.DB, err = Connect(sf.getConnectionCredentials(OptionalCredsT{})) + sf.DB, err = Connect(sf.getConnectionCredentials(optionalCreds{})) if err != nil { return } @@ -928,10 +938,10 @@ func (sf *HandleT) TestConnection(warehouse warehouseutils.Warehouse) (err error } // FetchSchema queries snowflake and returns the schema associated with provided namespace -func (sf *HandleT) FetchSchema(warehouse warehouseutils.Warehouse) (schema, unrecognizedSchema warehouseutils.SchemaT, err error) { +func (sf *Snowflake) FetchSchema(warehouse warehouseutils.Warehouse) (schema, unrecognizedSchema warehouseutils.SchemaT, err error) { sf.Warehouse = warehouse sf.Namespace = warehouse.Namespace - dbHandle, err := Connect(sf.getConnectionCredentials(OptionalCredsT{})) + dbHandle, err := 
Connect(sf.getConnectionCredentials(optionalCreds{})) if err != nil { return } @@ -955,11 +965,11 @@ func (sf *HandleT) FetchSchema(warehouse warehouseutils.Warehouse) (schema, unre rows, err := dbHandle.Query(sqlStatement) if err != nil && err != sql.ErrNoRows { - pkgLogger.Errorf("SF: Error in fetching schema from snowflake destination:%v, query: %v", sf.Warehouse.Destination.ID, sqlStatement) + sf.Logger.Errorf("SF: Error in fetching schema from snowflake destination:%v, query: %v", sf.Warehouse.Destination.ID, sqlStatement) return } if err == sql.ErrNoRows { - pkgLogger.Infof("SF: No rows, while fetching schema from destination:%v, query: %v", sf.Warehouse.Identifier, sqlStatement) + sf.Logger.Infof("SF: No rows, while fetching schema from destination:%v, query: %v", sf.Warehouse.Identifier, sqlStatement) return schema, unrecognizedSchema, nil } defer rows.Close() @@ -967,7 +977,7 @@ func (sf *HandleT) FetchSchema(warehouse warehouseutils.Warehouse) (schema, unre var tName, cName, cType string err = rows.Scan(&tName, &cName, &cType) if err != nil { - pkgLogger.Errorf("SF: Error in processing fetched schema from snowflake destination:%v", sf.Warehouse.Destination.ID) + sf.Logger.Errorf("SF: Error in processing fetched schema from snowflake destination:%v", sf.Warehouse.Destination.ID) return } if _, ok := schema[tName]; !ok { @@ -987,22 +997,22 @@ func (sf *HandleT) FetchSchema(warehouse warehouseutils.Warehouse) (schema, unre return } -func (sf *HandleT) Cleanup() { +func (sf *Snowflake) Cleanup() { if sf.DB != nil { sf.DB.Close() } } -func (sf *HandleT) LoadUserTables() map[string]error { +func (sf *Snowflake) LoadUserTables() map[string]error { return sf.loadUserTables() } -func (sf *HandleT) LoadTable(tableName string) error { +func (sf *Snowflake) LoadTable(tableName string) error { _, err := sf.loadTable(tableName, sf.Uploader.GetTableSchemaInUpload(tableName), nil, false) return err } -func (sf *HandleT) GetTotalCountInTable(ctx context.Context, tableName string) (int64, error) { +func (sf *Snowflake) GetTotalCountInTable(ctx context.Context, tableName string) (int64, error) { var ( total int64 err error @@ -1018,7 +1028,7 @@ func (sf *HandleT) GetTotalCountInTable(ctx context.Context, tableName string) ( return total, err } -func (sf *HandleT) Connect(warehouse warehouseutils.Warehouse) (client.Client, error) { +func (sf *Snowflake) Connect(warehouse warehouseutils.Warehouse) (client.Client, error) { sf.Warehouse = warehouse sf.Namespace = warehouse.Namespace sf.CloudProvider = warehouseutils.SnowflakeCloudProvider(warehouse.Destination.Config) @@ -1027,7 +1037,7 @@ func (sf *HandleT) Connect(warehouse warehouseutils.Warehouse) (client.Client, e warehouse.Destination.Config, misc.IsConfiguredToUseRudderObjectStorage(sf.Warehouse.Destination.Config), ) - dbHandle, err := Connect(sf.getConnectionCredentials(OptionalCredsT{})) + dbHandle, err := Connect(sf.getConnectionCredentials(optionalCreds{})) if err != nil { return client.Client{}, err } @@ -1035,7 +1045,7 @@ func (sf *HandleT) Connect(warehouse warehouseutils.Warehouse) (client.Client, e return client.Client{Type: client.SQLClient, SQL: dbHandle}, err } -func (sf *HandleT) LoadTestTable(location, tableName string, _ map[string]interface{}, _ string) (err error) { +func (sf *Snowflake) LoadTestTable(location, tableName string, _ map[string]interface{}, _ string) (err error) { loadFolder := warehouseutils.GetObjectFolder(sf.ObjectStorage, location) schemaIdentifier := sf.schemaIdentifier() sqlStatement := fmt.Sprintf(`COPY 
INTO %v(%v) FROM '%v' %s PATTERN = '.*\.csv\.gz' @@ -1050,10 +1060,10 @@ func (sf *HandleT) LoadTestTable(location, tableName string, _ map[string]interf return } -func (sf *HandleT) SetConnectionTimeout(timeout time.Duration) { +func (sf *Snowflake) SetConnectionTimeout(timeout time.Duration) { sf.ConnectTimeout = timeout } -func (sf *HandleT) ErrorMappings() []model.JobError { +func (sf *Snowflake) ErrorMappings() []model.JobError { return errorsMappings } diff --git a/warehouse/integrations/snowflake/snowflake_test.go b/warehouse/integrations/snowflake/snowflake_test.go index 4c2dabef9a..fae70b1614 100644 --- a/warehouse/integrations/snowflake/snowflake_test.go +++ b/warehouse/integrations/snowflake/snowflake_test.go @@ -63,7 +63,7 @@ func TestIntegrationSnowflake(t *testing.T) { }{ { name: "Upload Job with Normal Database", - dbName: credentials.DBName, + dbName: credentials.Database, schema: schema, tables: []string{"identifies", "users", "tracks", "product_track", "pages", "screens", "aliases", "groups"}, writeKey: "2eSJyYtqwcFiUILzXv2fcNIrWO7", @@ -78,7 +78,7 @@ func TestIntegrationSnowflake(t *testing.T) { }, { name: "Upload Job with Case Sensitive Database", - dbName: strings.ToLower(credentials.DBName), + dbName: strings.ToLower(credentials.Database), schema: caseSensitiveSchema, tables: []string{"identifies", "users", "tracks", "product_track", "pages", "screens", "aliases", "groups"}, writeKey: "2eSJyYtqwcFYUILzXv2fcNIrWO7", @@ -93,7 +93,7 @@ func TestIntegrationSnowflake(t *testing.T) { }, { name: "Async Job with Sources", - dbName: credentials.DBName, + dbName: credentials.Database, schema: sourcesSchema, tables: []string{"tracks", "google_sheet"}, writeKey: "2eSJyYtqwcFYerwzXv2fcNIrWO7", @@ -120,7 +120,7 @@ func TestIntegrationSnowflake(t *testing.T) { t.Parallel() credentialsCopy := credentials - credentialsCopy.DBName = tc.dbName + credentialsCopy.Database = tc.dbName db, err := snowflake.Connect(credentialsCopy) require.NoError(t, err) diff --git a/warehouse/integrations/testhelper/setup.go b/warehouse/integrations/testhelper/setup.go index c82ab85744..3a695a0362 100644 --- a/warehouse/integrations/testhelper/setup.go +++ b/warehouse/integrations/testhelper/setup.go @@ -981,7 +981,7 @@ func credentialsFromKey(key string) (credentials map[string]string) { return } -func SnowflakeCredentials() (credentials snowflake.SnowflakeCredentialsT, err error) { +func SnowflakeCredentials() (credentials snowflake.Credentials, err error) { cred, exists := os.LookupEnv(SnowflakeIntegrationTestCredentials) if !exists { err = fmt.Errorf("following %s does not exists while running the Snowflake test", SnowflakeIntegrationTestCredentials) From 4b10c0b13412ce014124d04676cf8e038d441c7d Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Thu, 23 Feb 2023 12:54:15 +0530 Subject: [PATCH 2/7] enable snowflake integration tests --- warehouse/integrations/snowflake/snowflake_test.go | 2 -- warehouse/integrations/testhelper/setup.go | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/warehouse/integrations/snowflake/snowflake_test.go b/warehouse/integrations/snowflake/snowflake_test.go index fae70b1614..e18c9bf402 100644 --- a/warehouse/integrations/snowflake/snowflake_test.go +++ b/warehouse/integrations/snowflake/snowflake_test.go @@ -29,7 +29,6 @@ func TestIntegrationSnowflake(t *testing.T) { t.Skipf("Skipping %s as %s is not set", t.Name(), testhelper.SnowflakeIntegrationTestCredentials) } - t.SkipNow() t.Parallel() snowflake.Init() @@ -179,7 +178,6 @@ func 
TestConfigurationValidationSnowflake(t *testing.T) { t.Skipf("Skipping %s as %s is not set", t.Name(), testhelper.SnowflakeIntegrationTestCredentials) } - t.SkipNow() t.Parallel() misc.Init() diff --git a/warehouse/integrations/testhelper/setup.go b/warehouse/integrations/testhelper/setup.go index 3a695a0362..67d9b5cf69 100644 --- a/warehouse/integrations/testhelper/setup.go +++ b/warehouse/integrations/testhelper/setup.go @@ -968,7 +968,7 @@ func credentialsFromKey(key string) (credentials map[string]string) { log.Print(fmt.Errorf("while setting up the workspace config: env %s does not exists", key)) return } - if len(cred) == 0 { + if cred == "" { log.Print(fmt.Errorf("while setting up the workspace config: env %s is empty", key)) return } From 03e2adc0d6c228948c464cfe7df1efcf1344a5e7 Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Thu, 23 Feb 2023 13:37:20 +0530 Subject: [PATCH 3/7] deepsource changes --- warehouse/integrations/snowflake/snowflake.go | 30 +++++++++---------- .../integrations/snowflake/snowflake_test.go | 2 +- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/warehouse/integrations/snowflake/snowflake.go b/warehouse/integrations/snowflake/snowflake.go index bdf509eec2..03d26c9470 100644 --- a/warehouse/integrations/snowflake/snowflake.go +++ b/warehouse/integrations/snowflake/snowflake.go @@ -214,14 +214,14 @@ func ColumnsWithDataTypes(columns map[string]string, prefix string) string { // schemaIdentifier returns [DATABASE_NAME].[NAMESPACE] format to access the schema directly. func (sf *Snowflake) schemaIdentifier() string { - return fmt.Sprintf(`"%s"`, + return fmt.Sprintf(`%q`, sf.Namespace, ) } func (sf *Snowflake) createTable(tableName string, columns map[string]string) (err error) { schemaIdentifier := sf.schemaIdentifier() - sqlStatement := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s."%s" ( %v )`, schemaIdentifier, tableName, ColumnsWithDataTypes(columns, "")) + sqlStatement := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s.%q ( %v )`, schemaIdentifier, tableName, ColumnsWithDataTypes(columns, "")) sf.Logger.Infof("Creating table in snowflake for SF:%s : %v", sf.Warehouse.Destination.ID, sqlStatement) _, err = sf.DB.Exec(sqlStatement) return @@ -366,7 +366,7 @@ func (sf *Snowflake) loadTable(tableName string, tableSchemaInUpload warehouseut // Truncating the columns by default to avoid size limitation errors // https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions sqlStatement = fmt.Sprintf(`COPY INTO %v(%v) FROM '%v' %s PATTERN = '.*\.csv\.gz' - FILE_FORMAT = ( TYPE = csv FIELD_OPTIONALLY_ENCLOSED_BY = '"' ESCAPE_UNENCLOSED_FIELD = NONE ) TRUNCATECOLUMNS = TRUE`, fmt.Sprintf(`%s."%s"`, schemaIdentifier, stagingTableName), sortedColumnNames, loadFolder, sf.authString()) + FILE_FORMAT = ( TYPE = csv FIELD_OPTIONALLY_ENCLOSED_BY = '"' ESCAPE_UNENCLOSED_FIELD = NONE ) TRUNCATECOLUMNS = TRUE`, fmt.Sprintf(`%s.%q`, schemaIdentifier, stagingTableName), sortedColumnNames, loadFolder, sf.authString()) sanitisedSQLStmt, regexErr := misc.ReplaceMultiRegex(sqlStatement, map[string]string{ "AWS_KEY_ID='[^']*'": "AWS_KEY_ID='***'", @@ -394,7 +394,7 @@ func (sf *Snowflake) loadTable(tableName string, tableSchemaInUpload warehouseut } stagingColumnNames := warehouseutils.JoinWithFormatting(strKeys, func(_ int, name string) string { - return fmt.Sprintf(`staging."%s"`, name) + return fmt.Sprintf(`staging.%q`, name) }, ",") columnsWithValues := warehouseutils.JoinWithFormatting(strKeys, func(_ int, name string) string { return 
fmt.Sprintf(`original."%[1]s" = staging."%[1]s"`, name) @@ -462,7 +462,7 @@ func (sf *Snowflake) LoadIdentityMergeRulesTable() (err error) { loadLocation := warehouseutils.GetObjectLocation(sf.ObjectStorage, loadFile.Location) schemaIdentifier := sf.schemaIdentifier() sqlStatement := fmt.Sprintf(`COPY INTO %v(%v) FROM '%v' %s PATTERN = '.*\.csv\.gz' - FILE_FORMAT = ( TYPE = csv FIELD_OPTIONALLY_ENCLOSED_BY = '"' ESCAPE_UNENCLOSED_FIELD = NONE ) TRUNCATECOLUMNS = TRUE`, fmt.Sprintf(`%s."%s"`, schemaIdentifier, identityMergeRulesTable), sortedColumnNames, loadLocation, sf.authString()) + FILE_FORMAT = ( TYPE = csv FIELD_OPTIONALLY_ENCLOSED_BY = '"' ESCAPE_UNENCLOSED_FIELD = NONE ) TRUNCATECOLUMNS = TRUE`, fmt.Sprintf(`%s.%q`, schemaIdentifier, identityMergeRulesTable), sortedColumnNames, loadLocation, sf.authString()) sanitisedSQLStmt, regexErr := misc.ReplaceMultiRegex(sqlStatement, map[string]string{ "AWS_KEY_ID='[^']*'": "AWS_KEY_ID='***'", @@ -509,7 +509,7 @@ func (sf *Snowflake) LoadIdentityMappingsTable() (err error) { return } - sqlStatement = fmt.Sprintf(`ALTER TABLE %s."%s" ADD COLUMN "ID" int AUTOINCREMENT start 1 increment 1`, schemaIdentifier, stagingTableName) + sqlStatement = fmt.Sprintf(`ALTER TABLE %s.%q ADD COLUMN "ID" int AUTOINCREMENT start 1 increment 1`, schemaIdentifier, stagingTableName) sf.Logger.Infof("SF: Adding autoincrement column for table:%s at %s\n", stagingTableName, sqlStatement) _, err = dbHandle.Exec(sqlStatement) if err != nil && !checkAndIgnoreAlreadyExistError(err) { @@ -519,7 +519,7 @@ func (sf *Snowflake) LoadIdentityMappingsTable() (err error) { loadLocation := warehouseutils.GetObjectLocation(sf.ObjectStorage, loadFile.Location) sqlStatement = fmt.Sprintf(`COPY INTO %v("MERGE_PROPERTY_TYPE", "MERGE_PROPERTY_VALUE", "RUDDER_ID", "UPDATED_AT") FROM '%v' %s PATTERN = '.*\.csv\.gz' - FILE_FORMAT = ( TYPE = csv FIELD_OPTIONALLY_ENCLOSED_BY = '"' ESCAPE_UNENCLOSED_FIELD = NONE ) TRUNCATECOLUMNS = TRUE`, fmt.Sprintf(`%s."%s"`, schemaIdentifier, stagingTableName), loadLocation, sf.authString()) + FILE_FORMAT = ( TYPE = csv FIELD_OPTIONALLY_ENCLOSED_BY = '"' ESCAPE_UNENCLOSED_FIELD = NONE ) TRUNCATECOLUMNS = TRUE`, fmt.Sprintf(`%s.%q`, schemaIdentifier, stagingTableName), loadLocation, sf.authString()) sf.Logger.Infof("SF: Dedup records for table:%s using staging table: %s\n", identityMappingsTable, sqlStatement) _, err = dbHandle.Exec(sqlStatement) @@ -575,12 +575,12 @@ func (sf *Snowflake) loadUserTables() (errorMap map[string]error) { if colName == "ID" { continue } - userColNames = append(userColNames, fmt.Sprintf(`"%s"`, colName)) + userColNames = append(userColNames, fmt.Sprintf(`%q`, colName)) if _, ok := identifyColMap[colName]; ok { - identifyColNames = append(identifyColNames, fmt.Sprintf(`"%s"`, colName)) + identifyColNames = append(identifyColNames, fmt.Sprintf(`%q`, colName)) } else { // This is to handle cases when column in users table not present in identities table - identifyColNames = append(identifyColNames, fmt.Sprintf(`NULL as "%s"`, colName)) + identifyColNames = append(identifyColNames, fmt.Sprintf(`NULL as %q`, colName)) } firstValProps = append(firstValProps, fmt.Sprintf(`FIRST_VALUE("%[1]s" IGNORE NULLS) OVER (PARTITION BY ID ORDER BY RECEIVED_AT DESC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "%[1]s"`, colName)) } @@ -766,7 +766,7 @@ func (sf *Snowflake) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error } schemaIdentifier := sf.schemaIdentifier() - sqlStatement := fmt.Sprintf(`SELECT count(*) FROM 
%s."%s"`, schemaIdentifier, tableName) + sqlStatement := fmt.Sprintf(`SELECT count(*) FROM %s.%q`, schemaIdentifier, tableName) var totalRows int64 err = sf.DB.QueryRow(sqlStatement).Scan(&totalRows) if err != nil { @@ -799,7 +799,7 @@ func (sf *Snowflake) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error var offset int64 for { // TODO: Handle case for missing anonymous_id, user_id columns - sqlStatement = fmt.Sprintf(`SELECT DISTINCT %s FROM %s."%s" LIMIT %d OFFSET %d`, toSelectFields, schemaIdentifier, tableName, batchSize, offset) + sqlStatement = fmt.Sprintf(`SELECT DISTINCT %s FROM %s.%q LIMIT %d OFFSET %d`, toSelectFields, schemaIdentifier, tableName, batchSize, offset) sf.Logger.Infof("SF: Downloading distinct combinations of anonymous_id, user_id: %s, totalRows: %d", sqlStatement, totalRows) var rows *sql.Rows rows, err = sf.DB.Query(sqlStatement) @@ -877,7 +877,7 @@ func (sf *Snowflake) IsEmpty(warehouse warehouseutils.Warehouse) (empty bool, er continue } schemaIdentifier := sf.schemaIdentifier() - sqlStatement := fmt.Sprintf(`SELECT COUNT(*) FROM %s."%s"`, schemaIdentifier, tableName) + sqlStatement := fmt.Sprintf(`SELECT COUNT(*) FROM %s.%q`, schemaIdentifier, tableName) var count int64 err = sf.DB.QueryRow(sqlStatement).Scan(&count) if err != nil { @@ -1050,8 +1050,8 @@ func (sf *Snowflake) LoadTestTable(location, tableName string, _ map[string]inte schemaIdentifier := sf.schemaIdentifier() sqlStatement := fmt.Sprintf(`COPY INTO %v(%v) FROM '%v' %s PATTERN = '.*\.csv\.gz' FILE_FORMAT = ( TYPE = csv FIELD_OPTIONALLY_ENCLOSED_BY = '"' ESCAPE_UNENCLOSED_FIELD = NONE ) TRUNCATECOLUMNS = TRUE`, - fmt.Sprintf(`%s."%s"`, schemaIdentifier, tableName), - fmt.Sprintf(`"%s", "%s"`, "id", "val"), + fmt.Sprintf(`%s.%q`, schemaIdentifier, tableName), + fmt.Sprintf(`%q, %q`, "id", "val"), loadFolder, sf.authString(), ) diff --git a/warehouse/integrations/snowflake/snowflake_test.go b/warehouse/integrations/snowflake/snowflake_test.go index e18c9bf402..aabd375da4 100644 --- a/warehouse/integrations/snowflake/snowflake_test.go +++ b/warehouse/integrations/snowflake/snowflake_test.go @@ -128,7 +128,7 @@ func TestIntegrationSnowflake(t *testing.T) { require.NoError( t, testhelper.WithConstantBackoff(func() (err error) { - _, err = db.Exec(fmt.Sprintf(`DROP SCHEMA "%s" CASCADE;`, tc.schema)) + _, err = db.Exec(fmt.Sprintf(`DROP SCHEMA %q CASCADE;`, tc.schema)) return }), fmt.Sprintf("Failed dropping schema %s for Snowflake", tc.schema), From 87316ed591357e6674b25c4867c8c9c644c8feb7 Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Fri, 24 Feb 2023 14:49:01 +0530 Subject: [PATCH 4/7] add integration test for snowflake role --- .../integrations/snowflake/snowflake_test.go | 29 +++- .../testdata/workspaceConfig/template.json | 137 +++++++++++++++++- warehouse/integrations/testhelper/setup.go | 3 +- 3 files changed, 157 insertions(+), 12 deletions(-) diff --git a/warehouse/integrations/snowflake/snowflake_test.go b/warehouse/integrations/snowflake/snowflake_test.go index aabd375da4..569745eef0 100644 --- a/warehouse/integrations/snowflake/snowflake_test.go +++ b/warehouse/integrations/snowflake/snowflake_test.go @@ -40,13 +40,15 @@ func TestIntegrationSnowflake(t *testing.T) { provider = warehouseutils.SNOWFLAKE jobsDB = testhelper.SetUpJobsDB(t) schema = testhelper.Schema(provider, testhelper.SnowflakeIntegrationTestSchema) + roleSchema = fmt.Sprintf("%s_%s", schema, "ROLE") sourcesSchema = fmt.Sprintf("%s_%s", schema, "SOURCES") caseSensitiveSchema = fmt.Sprintf("%s_%s", schema, "CS") ) 
testcase := []struct { name string - dbName string + database string + role string schema string writeKey string sourceID string @@ -62,7 +64,7 @@ func TestIntegrationSnowflake(t *testing.T) { }{ { name: "Upload Job with Normal Database", - dbName: credentials.Database, + database: credentials.Database, schema: schema, tables: []string{"identifies", "users", "tracks", "product_track", "pages", "screens", "aliases", "groups"}, writeKey: "2eSJyYtqwcFiUILzXv2fcNIrWO7", @@ -75,9 +77,25 @@ func TestIntegrationSnowflake(t *testing.T) { "wh_staging_files": 34, // 32 + 2 (merge events because of ID resolution) }, }, + { + name: "Upload Job with Role", + database: credentials.Database, + role: credentials.Role, + schema: roleSchema, + tables: []string{"identifies", "users", "tracks", "product_track", "pages", "screens", "aliases", "groups"}, + writeKey: "2eSafstqwcFYUILzXv2fcNIrWO7", + sourceID: "24p1HhPsafaFBMKuzvx7GshCLKR", + destinationID: "24qeADObsdsJhijDnEppO6P1SNc", + stagingFilesEventsMap: testhelper.EventsCountMap{ + "wh_staging_files": 34, // 32 + 2 (merge events because of ID resolution) + }, + stagingFilesModifiedEventsMap: testhelper.EventsCountMap{ + "wh_staging_files": 34, // 32 + 2 (merge events because of ID resolution) + }, + }, { name: "Upload Job with Case Sensitive Database", - dbName: strings.ToLower(credentials.Database), + database: strings.ToLower(credentials.Database), schema: caseSensitiveSchema, tables: []string{"identifies", "users", "tracks", "product_track", "pages", "screens", "aliases", "groups"}, writeKey: "2eSJyYtqwcFYUILzXv2fcNIrWO7", @@ -92,7 +110,7 @@ func TestIntegrationSnowflake(t *testing.T) { }, { name: "Async Job with Sources", - dbName: credentials.Database, + database: credentials.Database, schema: sourcesSchema, tables: []string{"tracks", "google_sheet"}, writeKey: "2eSJyYtqwcFYerwzXv2fcNIrWO7", @@ -119,7 +137,8 @@ func TestIntegrationSnowflake(t *testing.T) { t.Parallel() credentialsCopy := credentials - credentialsCopy.Database = tc.dbName + credentialsCopy.Database = tc.database + credentialsCopy.Role = tc.role db, err := snowflake.Connect(credentialsCopy) require.NoError(t, err) diff --git a/warehouse/integrations/testdata/workspaceConfig/template.json b/warehouse/integrations/testdata/workspaceConfig/template.json index 647d9a1618..6c40b1221d 100644 --- a/warehouse/integrations/testdata/workspaceConfig/template.json +++ b/warehouse/integrations/testdata/workspaceConfig/template.json @@ -752,9 +752,9 @@ { "config": { "account": "{{.snowflakeAccount}}", - "database": "{{.snowflakeDBName}}", - "warehouse": "{{.snowflakeWHName}}", - "user": "{{.snowflakeUsername}}", + "database": "{{.snowflakeDatabase}}", + "warehouse": "{{.snowflakeWarehouse}}", + "user": "{{.snowflakeUser}}", "password": "{{.snowflakePassword}}", "cloudProvider": "AWS", "bucketName": "{{.snowflakeBucketName}}", @@ -876,9 +876,9 @@ { "config": { "account": "{{.snowflakeAccount}}", - "database": "{{.snowflakeCaseSensitiveDBName}}", - "warehouse": "{{.snowflakeWHName}}", - "user": "{{.snowflakeUsername}}", + "database": "{{.snowflakeCaseSensitiveDatabase}}", + "warehouse": "{{.snowflakeWarehouse}}", + "user": "{{.snowflakeUser}}", "password": "{{.snowflakePassword}}", "cloudProvider": "AWS", "bucketName": "{{.snowflakeBucketName}}", @@ -976,6 +976,131 @@ }, "dgSourceTrackingPlanConfig": null }, + { + "config": { + "isSampleSource": true, + "eventUpload": false, + "eventUploadTS": 1646073666353 + }, + "liveEventsConfig": { + "eventUpload": false, + "eventUploadTS": 1646073666353 + }, + 
"id": "24p1HhPsafaFBMKuzvx7GshCLKR", + "name": "snowflake-role-wh-integration", + "writeKey": "{{.snowflakeRoleWriteKey}}", + "enabled": true, + "sourceDefinitionId": "1dCzCUAtpWDzNxgGUYzq9sZdZZB", + "createdBy": "24p1CMAkx18KwNbFDXlR7sUhqaa", + "workspaceId": "{{.workspaceId}}", + "deleted": false, + "createdAt": "2022-02-08T09:30:27.073Z", + "updatedAt": "2022-02-28T18:41:06.362Z", + "destinations": [ + { + "config": { + "account": "{{.snowflakeAccount}}", + "database": "{{.snowflakeDatabase}}", + "warehouse": "{{.snowflakeWarehouse}}", + "user": "{{.snowflakeUser}}", + "role": "{{.snowflakeRole}}", + "password": "{{.snowflakePassword}}", + "cloudProvider": "AWS", + "bucketName": "{{.snowflakeBucketName}}", + "storageIntegration": "", + "accessKeyID": "{{.snowflakeAccessKeyID}}", + "accessKey": "{{.snowflakeAccessKey}}", + "namespace": "{{.snowflakeRoleNamespace}}", + "prefix": "snowflake-prefix", + "syncFrequency": "30", + "enableSSE": false, + "useRudderStorage": false + }, + "liveEventsConfig": {}, + "secretConfig": {}, + "id": "24qeADObsdsJhijDnEppO6P1SNc", + "name": "snowflake-role-demo", + "enabled": true, + "workspaceId": "{{.workspaceId}}", + "deleted": false, + "createdAt": "2022-02-08T23:19:58.278Z", + "updatedAt": "2022-05-17T08:18:33.587Z", + "revisionId": "29HgdgvsdsqFDTUESgmIZ3YSehV", + "transformations": [], + "destinationDefinition": { + "config": { + "destConfig": { + "defaultConfig": [ + "account", + "database", + "warehouse", + "user", + "password", + "cloudProvider", + "bucketName", + "containerName", + "storageIntegration", + "accessKeyID", + "accessKey", + "accountName", + "accountKey", + "credentials", + "namespace", + "prefix", + "syncFrequency", + "syncStartAt", + "enableSSE", + "excludeWindow", + "useRudderStorage" + ] + }, + "secretKeys": [ + "password", + "accessKeyID", + "accessKey" + ], + "excludeKeys": [], + "includeKeys": [], + "transformAt": "processor", + "transformAtV1": "processor", + "supportedSourceTypes": [ + "android", + "ios", + "web", + "unity", + "amp", + "cloud", + "reactnative", + "cloudSource", + "flutter", + "cordova" + ], + "saveDestinationResponse": true + }, + "responseRules": null, + "options": null, + "id": "1XjvXnzw34KjA1YOuKqL1kwzh6", + "name": "SNOWFLAKE", + "displayName": "Snowflake", + "category": "warehouse", + "createdAt": "2020-02-13T05:39:20.184Z", + "updatedAt": "2022-02-08T06:46:45.432Z" + }, + "isConnectionEnabled": true, + "isProcessorEnabled": true + } + ], + "sourceDefinition": { + "options": null, + "id": "1dCzCUAtpWDzNxgGUYzq9sZdZZB", + "name": "HTTP", + "displayName": "HTTP", + "category": "", + "createdAt": "2020-06-12T06:35:35.962Z", + "updatedAt": "2020-06-12T06:35:35.962Z" + }, + "dgSourceTrackingPlanConfig": null + }, { "config": { "isSampleSource": true, diff --git a/warehouse/integrations/testhelper/setup.go b/warehouse/integrations/testhelper/setup.go index 67d9b5cf69..648b4075ec 100644 --- a/warehouse/integrations/testhelper/setup.go +++ b/warehouse/integrations/testhelper/setup.go @@ -857,6 +857,7 @@ func PopulateTemplateConfigurations() map[string]string { "bigqueryWriteKey": "J77aX7tLFJ84qYU6UrN8ctecwZt", "snowflakeWriteKey": "2eSJyYtqwcFiUILzXv2fcNIrWO7", "snowflakeCaseSensitiveWriteKey": "2eSJyYtqwcFYUILzXv2fcNIrWO7", + "snowflakeRoleWriteKey": "2eSafstqwcFYUILzXv2fcNIrWO7", "redshiftWriteKey": "JAAwdCxmM8BIabKERsUhPNmMmdf", "deltalakeWriteKey": "sToFgoilA0U1WxNeW1gdgUVDsEW", @@ -899,7 +900,7 @@ func enhanceWithSnowflakeConfigurations(values map[string]string) { values[fmt.Sprintf("snowflake%s", k)] = v } 
- values["snowflakeCaseSensitiveDBName"] = strings.ToLower(values["snowflakeDBName"]) + values["snowflakeCaseSensitiveDatabase"] = strings.ToLower(values["snowflakeDatabase"]) values["snowflakeNamespace"] = Schema(warehouseutils.SNOWFLAKE, SnowflakeIntegrationTestSchema) values["snowflakeCaseSensitiveNamespace"] = fmt.Sprintf("%s_%s", values["snowflakeNamespace"], "CS") values["snowflakeSourcesNamespace"] = fmt.Sprintf("%s_%s", values["snowflakeNamespace"], "sources") From 7770427729d3f22baefc9b81962801034379b32f Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Fri, 24 Feb 2023 15:03:39 +0530 Subject: [PATCH 5/7] added snowflake role namespace --- warehouse/integrations/testhelper/setup.go | 1 + 1 file changed, 1 insertion(+) diff --git a/warehouse/integrations/testhelper/setup.go b/warehouse/integrations/testhelper/setup.go index 648b4075ec..8bb07b2585 100644 --- a/warehouse/integrations/testhelper/setup.go +++ b/warehouse/integrations/testhelper/setup.go @@ -902,6 +902,7 @@ func enhanceWithSnowflakeConfigurations(values map[string]string) { values["snowflakeCaseSensitiveDatabase"] = strings.ToLower(values["snowflakeDatabase"]) values["snowflakeNamespace"] = Schema(warehouseutils.SNOWFLAKE, SnowflakeIntegrationTestSchema) + values["snowflakeRoleNamespace"] = fmt.Sprintf("%s_%s", values["snowflakeNamespace"], "ROLE") values["snowflakeCaseSensitiveNamespace"] = fmt.Sprintf("%s_%s", values["snowflakeNamespace"], "CS") values["snowflakeSourcesNamespace"] = fmt.Sprintf("%s_%s", values["snowflakeNamespace"], "sources") } From 5adb5c04ea469a33483a587a7c9578e523cb969f Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Sat, 25 Feb 2023 03:42:56 +0530 Subject: [PATCH 6/7] added snowflake role integration tests --- .github/workflows/tests.yaml | 77 ++++++++++--------- .../integrations/snowflake/snowflake_test.go | 27 ++++--- .../testdata/workspaceConfig/template.json | 28 +++---- warehouse/integrations/testhelper/setup.go | 27 ++++--- 4 files changed, 88 insertions(+), 71 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 83d60c80d7..c70740fffd 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -12,29 +12,29 @@ jobs: runs-on: 'ubuntu-20.04' strategy: matrix: - FEATURES: [oss ,enterprise] + FEATURES: [ oss ,enterprise ] steps: - - name: Checkout - uses: actions/checkout@v3 - - uses: actions/setup-go@v3 - with: - go-version: '~1.20.1' - check-latest: true - cache: true - - run: go version - - run: go mod download # Not required, used to segregate module download vs test times + - name: Checkout + uses: actions/checkout@v3 + - uses: actions/setup-go@v3 + with: + go-version: '~1.20.1' + check-latest: true + cache: true + - run: go version + - run: go mod download # Not required, used to segregate module download vs test times - - name: Integration test for enterprise - if: matrix.FEATURES == 'enterprise' - run: go test -v ./integration_test/docker_test/docker_test.go -count 1 - env: - ENTERPRISE_TOKEN: ${{ secrets.ENTERPRISE_TOKEN }} + - name: Integration test for enterprise + if: matrix.FEATURES == 'enterprise' + run: go test -v ./integration_test/docker_test/docker_test.go -count 1 + env: + ENTERPRISE_TOKEN: ${{ secrets.ENTERPRISE_TOKEN }} - - name: Integration test for oss - if: matrix.FEATURES == 'oss' - run: go test -v ./integration_test/docker_test/docker_test.go -count 1 - env: - RSERVER_ENABLE_MULTITENANCY: ${{ matrix.MULTITENANCY }} + - name: Integration test for oss + if: matrix.FEATURES == 'oss' + run: go test -v 
./integration_test/docker_test/docker_test.go -count 1 + env: + RSERVER_ENABLE_MULTITENANCY: ${{ matrix.MULTITENANCY }} warehouse-integration: name: Warehouse Service Integration @@ -59,28 +59,29 @@ jobs: BIGQUERY_INTEGRATION_TEST_CREDENTIALS: ${{ secrets.BIGQUERY_INTEGRATION_TEST_CREDENTIALS }} REDSHIFT_INTEGRATION_TEST_CREDENTIALS: ${{ secrets.REDSHIFT_INTEGRATION_TEST_CREDENTIALS }} SNOWFLAKE_INTEGRATION_TEST_CREDENTIALS: ${{ secrets.SNOWFLAKE_INTEGRATION_TEST_CREDENTIALS }} + SNOWFLAKE_RBAC_INTEGRATION_TEST_CREDENTIALS: ${{ secrets.SNOWFLAKE_RBAC_INTEGRATION_TEST_CREDENTIALS }} DATABRICKS_INTEGRATION_TEST_CREDENTIALS: ${{ secrets.DATABRICKS_INTEGRATION_TEST_CREDENTIALS }} unit: name: Unit runs-on: 'ubuntu-20.04' steps: - - uses: actions/checkout@v2 - - uses: actions/setup-go@v3 - with: - go-version: '~1.20.1' - check-latest: true - cache: true + - uses: actions/checkout@v2 + - uses: actions/setup-go@v3 + with: + go-version: '~1.20.1' + check-latest: true + cache: true - - run: go version - - run: go mod download # Not required, used to segregate module download vs test times - - env: - TEST_KAFKA_CONFLUENT_CLOUD_HOST: ${{ secrets.TEST_KAFKA_CONFLUENT_CLOUD_HOST }} - TEST_KAFKA_CONFLUENT_CLOUD_KEY: ${{ secrets.TEST_KAFKA_CONFLUENT_CLOUD_KEY }} - TEST_KAFKA_CONFLUENT_CLOUD_SECRET: ${{ secrets.TEST_KAFKA_CONFLUENT_CLOUD_SECRET }} - TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_HOST: ${{ secrets.TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_HOST }} - TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_EVENTHUB_NAME: ${{ secrets.TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_EVENTHUB_NAME }} - TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_CONNECTION_STRING: ${{ secrets.TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_CONNECTION_STRING }} - TEST_S3_DATALAKE_CREDENTIALS: ${{ secrets.TEST_S3_DATALAKE_CREDENTIALS }} - run: make test - - uses: codecov/codecov-action@v2 + - run: go version + - run: go mod download # Not required, used to segregate module download vs test times + - env: + TEST_KAFKA_CONFLUENT_CLOUD_HOST: ${{ secrets.TEST_KAFKA_CONFLUENT_CLOUD_HOST }} + TEST_KAFKA_CONFLUENT_CLOUD_KEY: ${{ secrets.TEST_KAFKA_CONFLUENT_CLOUD_KEY }} + TEST_KAFKA_CONFLUENT_CLOUD_SECRET: ${{ secrets.TEST_KAFKA_CONFLUENT_CLOUD_SECRET }} + TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_HOST: ${{ secrets.TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_HOST }} + TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_EVENTHUB_NAME: ${{ secrets.TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_EVENTHUB_NAME }} + TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_CONNECTION_STRING: ${{ secrets.TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_CONNECTION_STRING }} + TEST_S3_DATALAKE_CREDENTIALS: ${{ secrets.TEST_S3_DATALAKE_CREDENTIALS }} + run: make test + - uses: codecov/codecov-action@v2 diff --git a/warehouse/integrations/snowflake/snowflake_test.go b/warehouse/integrations/snowflake/snowflake_test.go index 569745eef0..a4cd0c7991 100644 --- a/warehouse/integrations/snowflake/snowflake_test.go +++ b/warehouse/integrations/snowflake/snowflake_test.go @@ -28,12 +28,18 @@ func TestIntegrationSnowflake(t *testing.T) { if _, exists := os.LookupEnv(testhelper.SnowflakeIntegrationTestCredentials); !exists { t.Skipf("Skipping %s as %s is not set", t.Name(), testhelper.SnowflakeIntegrationTestCredentials) } + if _, exists := os.LookupEnv(testhelper.SnowflakeRBACIntegrationTestCredentials); !exists { + t.Skipf("Skipping %s as %s is not set", t.Name(), testhelper.SnowflakeRBACIntegrationTestCredentials) + } t.Parallel() snowflake.Init() - credentials, err := testhelper.SnowflakeCredentials() + credentials, err := testhelper.SnowflakeCredentials(testhelper.SnowflakeIntegrationTestCredentials) + 
require.NoError(t, err)
+
+	rbacCredentials, err := testhelper.SnowflakeCredentials(testhelper.SnowflakeRBACIntegrationTestCredentials)
 	require.NoError(t, err)
 
 	var (
@@ -43,12 +49,13 @@ func TestIntegrationSnowflake(t *testing.T) {
 		roleSchema          = fmt.Sprintf("%s_%s", schema, "ROLE")
 		sourcesSchema       = fmt.Sprintf("%s_%s", schema, "SOURCES")
 		caseSensitiveSchema = fmt.Sprintf("%s_%s", schema, "CS")
+		database            = credentials.Database
 	)
 
 	testcase := []struct {
 		name                          string
+		credentials                   snowflake.Credentials
 		database                      string
-		role                          string
 		schema                        string
 		writeKey                      string
 		sourceID                      string
@@ -64,7 +71,8 @@ func TestIntegrationSnowflake(t *testing.T) {
 	}{
 		{
 			name:          "Upload Job with Normal Database",
-			database:      credentials.Database,
+			credentials:   credentials,
+			database:      database,
 			schema:        schema,
 			tables:        []string{"identifies", "users", "tracks", "product_track", "pages", "screens", "aliases", "groups"},
 			writeKey:      "2eSJyYtqwcFiUILzXv2fcNIrWO7",
@@ -79,8 +87,8 @@ func TestIntegrationSnowflake(t *testing.T) {
 		},
 		{
 			name:          "Upload Job with Role",
-			database:      credentials.Database,
-			role:          credentials.Role,
+			credentials:   rbacCredentials,
+			database:      database,
 			schema:        roleSchema,
 			tables:        []string{"identifies", "users", "tracks", "product_track", "pages", "screens", "aliases", "groups"},
 			writeKey:      "2eSafstqwcFYUILzXv2fcNIrWO7",
@@ -95,7 +103,8 @@ func TestIntegrationSnowflake(t *testing.T) {
 		},
 		{
 			name:          "Upload Job with Case Sensitive Database",
-			database:      strings.ToLower(credentials.Database),
+			credentials:   credentials,
+			database:      strings.ToLower(database),
 			schema:        caseSensitiveSchema,
 			tables:        []string{"identifies", "users", "tracks", "product_track", "pages", "screens", "aliases", "groups"},
 			writeKey:      "2eSJyYtqwcFYUILzXv2fcNIrWO7",
@@ -110,7 +119,8 @@ func TestIntegrationSnowflake(t *testing.T) {
 		},
 		{
 			name:          "Async Job with Sources",
-			database:      credentials.Database,
+			credentials:   credentials,
+			database:      database,
 			schema:        sourcesSchema,
 			tables:        []string{"tracks", "google_sheet"},
 			writeKey:      "2eSJyYtqwcFYerwzXv2fcNIrWO7",
@@ -136,9 +146,8 @@ func TestIntegrationSnowflake(t *testing.T) {
 		t.Run(tc.name, func(t *testing.T) {
 			t.Parallel()
 
-			credentialsCopy := credentials
+			credentialsCopy := tc.credentials
 			credentialsCopy.Database = tc.database
-			credentialsCopy.Role = tc.role
 
 			db, err := snowflake.Connect(credentialsCopy)
 			require.NoError(t, err)
diff --git a/warehouse/integrations/testdata/workspaceConfig/template.json b/warehouse/integrations/testdata/workspaceConfig/template.json
index 6c40b1221d..ad3d77e98b 100644
--- a/warehouse/integrations/testdata/workspaceConfig/template.json
+++ b/warehouse/integrations/testdata/workspaceConfig/template.json
@@ -987,8 +987,8 @@
         "eventUploadTS": 1646073666353
       },
       "id": "24p1HhPsafaFBMKuzvx7GshCLKR",
-      "name": "snowflake-role-wh-integration",
-      "writeKey": "{{.snowflakeRoleWriteKey}}",
+      "name": "snowflake-rbac-wh-integration",
+      "writeKey": "{{.snowflakeRBACWriteKey}}",
       "enabled": true,
       "sourceDefinitionId": "1dCzCUAtpWDzNxgGUYzq9sZdZZB",
       "createdBy": "24p1CMAkx18KwNbFDXlR7sUhqaa",
@@ -999,19 +999,19 @@
       "destinations": [
         {
           "config": {
-            "account": "{{.snowflakeAccount}}",
-            "database": "{{.snowflakeDatabase}}",
-            "warehouse": "{{.snowflakeWarehouse}}",
-            "user": "{{.snowflakeUser}}",
-            "role": "{{.snowflakeRole}}",
-            "password": "{{.snowflakePassword}}",
+            "account": "{{.snowflakeRBACAccount}}",
+            "database": "{{.snowflakeRBACDatabase}}",
+            "warehouse": "{{.snowflakeRBACWarehouse}}",
+            "user": "{{.snowflakeRBACUser}}",
+            "role": "{{.snowflakeRBACRole}}",
+            "password": 
"{{.snowflakeRBACPassword}}", "cloudProvider": "AWS", - "bucketName": "{{.snowflakeBucketName}}", + "bucketName": "{{.snowflakeRBACBucketName}}", "storageIntegration": "", - "accessKeyID": "{{.snowflakeAccessKeyID}}", - "accessKey": "{{.snowflakeAccessKey}}", - "namespace": "{{.snowflakeRoleNamespace}}", - "prefix": "snowflake-prefix", + "accessKeyID": "{{.snowflakeRBACAccessKeyID}}", + "accessKey": "{{.snowflakeRBACAccessKey}}", + "namespace": "{{.snowflakeRBACNamespace}}", + "prefix": "snowflake-rbac-prefix", "syncFrequency": "30", "enableSSE": false, "useRudderStorage": false @@ -1019,7 +1019,7 @@ "liveEventsConfig": {}, "secretConfig": {}, "id": "24qeADObsdsJhijDnEppO6P1SNc", - "name": "snowflake-role-demo", + "name": "snowflake-rbac-demo", "enabled": true, "workspaceId": "{{.workspaceId}}", "deleted": false, diff --git a/warehouse/integrations/testhelper/setup.go b/warehouse/integrations/testhelper/setup.go index 8bb07b2585..54b49c28cb 100644 --- a/warehouse/integrations/testhelper/setup.go +++ b/warehouse/integrations/testhelper/setup.go @@ -56,10 +56,11 @@ const ( ) const ( - SnowflakeIntegrationTestCredentials = "SNOWFLAKE_INTEGRATION_TEST_CREDENTIALS" - RedshiftIntegrationTestCredentials = "REDSHIFT_INTEGRATION_TEST_CREDENTIALS" - DeltalakeIntegrationTestCredentials = "DATABRICKS_INTEGRATION_TEST_CREDENTIALS" - BigqueryIntegrationTestCredentials = "BIGQUERY_INTEGRATION_TEST_CREDENTIALS" + SnowflakeIntegrationTestCredentials = "SNOWFLAKE_INTEGRATION_TEST_CREDENTIALS" + SnowflakeRBACIntegrationTestCredentials = "SNOWFLAKE_RBAC_INTEGRATION_TEST_CREDENTIALS" + RedshiftIntegrationTestCredentials = "REDSHIFT_INTEGRATION_TEST_CREDENTIALS" + DeltalakeIntegrationTestCredentials = "DATABRICKS_INTEGRATION_TEST_CREDENTIALS" + BigqueryIntegrationTestCredentials = "BIGQUERY_INTEGRATION_TEST_CREDENTIALS" ) const ( @@ -857,7 +858,7 @@ func PopulateTemplateConfigurations() map[string]string { "bigqueryWriteKey": "J77aX7tLFJ84qYU6UrN8ctecwZt", "snowflakeWriteKey": "2eSJyYtqwcFiUILzXv2fcNIrWO7", "snowflakeCaseSensitiveWriteKey": "2eSJyYtqwcFYUILzXv2fcNIrWO7", - "snowflakeRoleWriteKey": "2eSafstqwcFYUILzXv2fcNIrWO7", + "snowflakeRBACWriteKey": "2eSafstqwcFYUILzXv2fcNIrWO7", "redshiftWriteKey": "JAAwdCxmM8BIabKERsUhPNmMmdf", "deltalakeWriteKey": "sToFgoilA0U1WxNeW1gdgUVDsEW", @@ -895,16 +896,22 @@ func enhanceWithSnowflakeConfigurations(values map[string]string) { if _, exists := os.LookupEnv(SnowflakeIntegrationTestCredentials); !exists { return } + if _, exists := os.LookupEnv(SnowflakeRBACIntegrationTestCredentials); !exists { + return + } for k, v := range credentialsFromKey(SnowflakeIntegrationTestCredentials) { values[fmt.Sprintf("snowflake%s", k)] = v } + for k, v := range credentialsFromKey(SnowflakeRBACIntegrationTestCredentials) { + values[fmt.Sprintf("snowflakeRBAC%s", k)] = v + } values["snowflakeCaseSensitiveDatabase"] = strings.ToLower(values["snowflakeDatabase"]) values["snowflakeNamespace"] = Schema(warehouseutils.SNOWFLAKE, SnowflakeIntegrationTestSchema) - values["snowflakeRoleNamespace"] = fmt.Sprintf("%s_%s", values["snowflakeNamespace"], "ROLE") + values["snowflakeRBACNamespace"] = fmt.Sprintf("%s_%s", values["snowflakeNamespace"], "ROLE") values["snowflakeCaseSensitiveNamespace"] = fmt.Sprintf("%s_%s", values["snowflakeNamespace"], "CS") - values["snowflakeSourcesNamespace"] = fmt.Sprintf("%s_%s", values["snowflakeNamespace"], "sources") + values["snowflakeSourcesNamespace"] = fmt.Sprintf("%s_%s", values["snowflakeNamespace"], "SOURCES") } func 
enhanceWithRedshiftConfigurations(values map[string]string) {
@@ -983,10 +990,10 @@ func credentialsFromKey(key string) (credentials map[string]string) {
 	return
 }
 
-func SnowflakeCredentials() (credentials snowflake.Credentials, err error) {
-	cred, exists := os.LookupEnv(SnowflakeIntegrationTestCredentials)
+func SnowflakeCredentials(env string) (credentials snowflake.Credentials, err error) {
+	cred, exists := os.LookupEnv(env)
 	if !exists {
-		err = fmt.Errorf("following %s does not exists while running the Snowflake test", SnowflakeIntegrationTestCredentials)
+		err = fmt.Errorf("following %s does not exist while running the Snowflake test", env)
 		return
 	}
 

From 32d0199821c0d420f28a61973e6e7c0bfd42fea9 Mon Sep 17 00:00:00 2001
From: Akash Chetty
Date: Sat, 25 Feb 2023 04:06:28 +0530
Subject: [PATCH 7/7] sources template snowflake fix

---
 .../integrations/testdata/workspaceConfig/template.json | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/warehouse/integrations/testdata/workspaceConfig/template.json b/warehouse/integrations/testdata/workspaceConfig/template.json
index ad3d77e98b..4d33d83a0b 100644
--- a/warehouse/integrations/testdata/workspaceConfig/template.json
+++ b/warehouse/integrations/testdata/workspaceConfig/template.json
@@ -1997,9 +1997,9 @@
         {
           "config": {
             "account": "{{.snowflakeAccount}}",
-            "database": "{{.snowflakeDBName}}",
-            "warehouse": "{{.snowflakeWHName}}",
-            "user": "{{.snowflakeUsername}}",
+            "database": "{{.snowflakeDatabase}}",
+            "warehouse": "{{.snowflakeWarehouse}}",
+            "user": "{{.snowflakeUser}}",
             "password": "{{.snowflakePassword}}",
             "cloudProvider": "AWS",
             "bucketName": "{{.snowflakeBucketName}}",