diff --git a/db/queries/queries.go b/db/queries/queries.go index d13e4b2..58ac6f9 100644 --- a/db/queries/queries.go +++ b/db/queries/queries.go @@ -19,24 +19,17 @@ import ( "strings" "text/template" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/pkg/types" ) -type DBTX interface { +type DBQuerier interface { Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row } -// For mocking -type DBQuerier interface { - QueryRow(ctx context.Context, sql string, args ...any) pgx.Row -} - var validIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) func SanitiseIdentifier(ident string) error { @@ -172,7 +165,7 @@ func GeneratePkeyOffsetsQuery( return RenderSQL(SQLTemplates.GetPkeyOffsets, data) } -func CreateXORFunction(ctx context.Context, db *pgxpool.Pool) error { +func CreateXORFunction(ctx context.Context, db DBQuerier) error { sql, err := RenderSQL(SQLTemplates.CreateXORFunction, nil) if err != nil { return fmt.Errorf("failed to render CreateXORFunction SQL: %w", err) @@ -186,7 +179,7 @@ func CreateXORFunction(ctx context.Context, db *pgxpool.Pool) error { return nil } -func CreateMetadataTable(ctx context.Context, db *pgxpool.Pool) error { +func CreateMetadataTable(ctx context.Context, db DBQuerier) error { sql, err := RenderSQL(SQLTemplates.CreateMetadataTable, nil) if err != nil { return fmt.Errorf("failed to render CreateMetadataTable SQL: %w", err) @@ -200,7 +193,7 @@ func CreateMetadataTable(ctx context.Context, db *pgxpool.Pool) error { return nil } -func CreateCDCMetadataTable(ctx context.Context, db *pgxpool.Pool) error { +func CreateCDCMetadataTable(ctx context.Context, db DBQuerier) error { sql, err := RenderSQL(SQLTemplates.CreateCDCMetadataTable, nil) if err != nil { return fmt.Errorf("failed to render 
CreateCDCMetadataTable SQL: %w", err) @@ -214,7 +207,7 @@ func CreateCDCMetadataTable(ctx context.Context, db *pgxpool.Pool) error { return nil } -func CreateSimpleMtreeTable(ctx context.Context, db *pgxpool.Pool, mtreeTable, pkeyType string) error { +func CreateSimpleMtreeTable(ctx context.Context, db DBQuerier, mtreeTable, pkeyType string) error { data := map[string]interface{}{ "MtreeTable": mtreeTable, "PkeyType": pkeyType, @@ -233,7 +226,7 @@ func CreateSimpleMtreeTable(ctx context.Context, db *pgxpool.Pool, mtreeTable, p return nil } -func DropCompositeType(ctx context.Context, db *pgxpool.Pool, compositeTypeName string) error { +func DropCompositeType(ctx context.Context, db DBQuerier, compositeTypeName string) error { data := map[string]interface{}{ "CompositeTypeName": compositeTypeName, } @@ -251,7 +244,7 @@ func DropCompositeType(ctx context.Context, db *pgxpool.Pool, compositeTypeName return nil } -func CreateCompositeType(ctx context.Context, db *pgxpool.Pool, compositeTypeName, keyTypeColumns string) error { +func CreateCompositeType(ctx context.Context, db DBQuerier, compositeTypeName, keyTypeColumns string) error { data := map[string]interface{}{ "CompositeTypeName": compositeTypeName, "KeyTypeColumns": keyTypeColumns, @@ -270,7 +263,7 @@ func CreateCompositeType(ctx context.Context, db *pgxpool.Pool, compositeTypeNam return nil } -func CreateCompositeMtreeTable(ctx context.Context, db *pgxpool.Pool, mtreeTable, compositeTypeName string) error { +func CreateCompositeMtreeTable(ctx context.Context, db DBQuerier, mtreeTable, compositeTypeName string) error { data := map[string]interface{}{ "MtreeTable": mtreeTable, "CompositeTypeName": compositeTypeName, @@ -289,7 +282,7 @@ func CreateCompositeMtreeTable(ctx context.Context, db *pgxpool.Pool, mtreeTable return nil } -func InsertBlockRanges(ctx context.Context, db *pgxpool.Pool, mtreeTable string, nodePosition int64, rangeStart, rangeEnd interface{}) error { +func InsertBlockRanges(ctx context.Context, 
db DBQuerier, mtreeTable string, nodePosition int64, rangeStart, rangeEnd interface{}) error { data := map[string]interface{}{ "MtreeTable": mtreeTable, } @@ -307,27 +300,49 @@ func InsertBlockRanges(ctx context.Context, db *pgxpool.Pool, mtreeTable string, return nil } -func InsertCompositeBlockRanges(ctx context.Context, db *pgxpool.Pool, mtreeTable string, nodePosition int64, startTupleValues, endTupleValues string) error { +func InsertCompositeBlockRanges(ctx context.Context, db DBQuerier, mtreeTable string, nodePosition int64, startVals, endVals []any) error { + startPh := make([]string, len(startVals)) + args := make([]any, 0, 1+len(startVals)+len(endVals)) + args = append(args, nodePosition) + argIdx := 2 + for i := range startVals { + startPh[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, startVals[i]) + argIdx++ + } + + var endExpr string + if endVals == nil { + endExpr = "NULL" + } else { + endPh := make([]string, len(endVals)) + for i := range endVals { + endPh[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, endVals[i]) + argIdx++ + } + endExpr = fmt.Sprintf("ROW(%s)", strings.Join(endPh, ", ")) + } + startExpr := fmt.Sprintf("ROW(%s)", strings.Join(startPh, ", ")) + data := map[string]interface{}{ - "MtreeTable": mtreeTable, - "StartTupleValues": startTupleValues, - "EndTupleValues": endTupleValues, + "MtreeTable": mtreeTable, + "StartExpr": startExpr, + "EndExpr": endExpr, } - sql, err := RenderSQL(SQLTemplates.InsertCompositeBlockRanges, data) + stmt, err := RenderSQL(SQLTemplates.InsertCompositeBlockRanges, data) if err != nil { return fmt.Errorf("failed to render InsertCompositeBlockRanges SQL: %w", err) } - _, err = db.Exec(ctx, sql, nodePosition) - if err != nil { + if _, err := db.Exec(ctx, stmt, args...); err != nil { return fmt.Errorf("query to insert composite block ranges for '%s' failed: %w", mtreeTable, err) } - return nil } -func InsertBlockRangesBatchSimple(ctx context.Context, db *pgxpool.Pool, mtreeTable string, ranges 
[]types.BlockRange) error { +func InsertBlockRangesBatchSimple(ctx context.Context, db DBQuerier, mtreeTable string, ranges []types.BlockRange) error { if len(ranges) == 0 { return nil } @@ -393,7 +408,7 @@ func InsertBlockRangesBatchSimple(ctx context.Context, db *pgxpool.Pool, mtreeTa return nil } -func InsertBlockRangesBatchComposite(ctx context.Context, db *pgxpool.Pool, mtreeTable string, ranges []types.BlockRange, keyLen int) error { +func InsertBlockRangesBatchComposite(ctx context.Context, db DBQuerier, mtreeTable string, ranges []types.BlockRange, keyLen int) error { if len(ranges) == 0 { return nil } @@ -480,7 +495,7 @@ func InsertBlockRangesBatchComposite(ctx context.Context, db *pgxpool.Pool, mtre return nil } -func GetPkeyOffsets(ctx context.Context, db *pgxpool.Pool, schema, table string, keyColumns []string, tableSampleMethod string, samplePercent float64, ntileCount int) ([]types.PkeyOffset, error) { +func GetPkeyOffsets(ctx context.Context, db DBQuerier, schema, table string, keyColumns []string, tableSampleMethod string, samplePercent float64, ntileCount int) ([]types.PkeyOffset, error) { sql, err := GeneratePkeyOffsetsQuery(schema, table, keyColumns, tableSampleMethod, samplePercent, ntileCount) if err != nil { return nil, fmt.Errorf("failed to generate GetPkeyOffsets SQL: %w", err) @@ -604,7 +619,7 @@ func BlockHashSQL(schema, table string, primaryKeyCols []string, mode string) (s } // GetColumns retrieves the column names for a given table. 
-func GetColumns(ctx context.Context, db *pgxpool.Pool, schema, table string) ([]string, error) { +func GetColumns(ctx context.Context, db DBQuerier, schema, table string) ([]string, error) { sql, err := RenderSQL(SQLTemplates.GetColumns, nil) if err != nil { return nil, fmt.Errorf("failed to render GetColumns SQL: %w", err) @@ -636,12 +651,12 @@ func GetColumns(ctx context.Context, db *pgxpool.Pool, schema, table string) ([] return columns, nil } -func GetPrimaryKey(ctx context.Context, db *pgxpool.Pool, schema, table string) ([]string, error) { +func GetPrimaryKey(ctx context.Context, db DBQuerier, schema, table string) ([]string, error) { sql, err := RenderSQL(SQLTemplates.GetPrimaryKey, nil) if err != nil { return nil, err } - rows, err := db.Query(context.Background(), sql, schema, table) + rows, err := db.Query(ctx, sql, schema, table) if err != nil { return nil, err } @@ -667,7 +682,7 @@ func GetPrimaryKey(ctx context.Context, db *pgxpool.Pool, schema, table string) return keys, nil } -func GetColumnTypes(ctx context.Context, db *pgxpool.Pool, schema, table string) (map[string]string, error) { +func GetColumnTypes(ctx context.Context, db DBQuerier, schema, table string) (map[string]string, error) { sql, err := RenderSQL(SQLTemplates.GetColumnTypes, nil) if err != nil { return nil, err @@ -699,7 +714,7 @@ func GetColumnTypes(ctx context.Context, db *pgxpool.Pool, schema, table string) return types, nil } -func CheckUserPrivileges(ctx context.Context, db *pgxpool.Pool, username, schema, table string) (*types.UserPrivileges, error) { +func CheckUserPrivileges(ctx context.Context, db DBQuerier, username, schema, table string) (*types.UserPrivileges, error) { sql, err := RenderSQL(SQLTemplates.CheckUserPrivileges, nil) if err != nil { return nil, err @@ -723,7 +738,7 @@ func CheckUserPrivileges(ctx context.Context, db *pgxpool.Pool, username, schema return &privileges, nil } -func GetSpockNodeAndSubInfo(ctx context.Context, db *pgxpool.Pool) 
([]types.SpockNodeAndSubInfo, error) { +func GetSpockNodeAndSubInfo(ctx context.Context, db DBQuerier) ([]types.SpockNodeAndSubInfo, error) { sql, err := RenderSQL(SQLTemplates.SpockNodeAndSubInfo, nil) if err != nil { return nil, err @@ -760,7 +775,7 @@ func GetSpockNodeAndSubInfo(ctx context.Context, db *pgxpool.Pool) ([]types.Spoc return infos, nil } -func GetSpockRepSetInfo(ctx context.Context, db *pgxpool.Pool) ([]types.SpockRepSetInfo, error) { +func GetSpockRepSetInfo(ctx context.Context, db DBQuerier) ([]types.SpockRepSetInfo, error) { sql, err := RenderSQL(SQLTemplates.SpockRepSetInfo, nil) if err != nil { return nil, err @@ -791,7 +806,7 @@ func GetSpockRepSetInfo(ctx context.Context, db *pgxpool.Pool) ([]types.SpockRep return infos, nil } -func CheckSchemaExists(ctx context.Context, db *pgxpool.Pool, schema string) (bool, error) { +func CheckSchemaExists(ctx context.Context, db DBQuerier, schema string) (bool, error) { sql, err := RenderSQL(SQLTemplates.CheckSchemaExists, nil) if err != nil { return false, fmt.Errorf("failed to render CheckSchemaExists SQL: %w", err) @@ -806,7 +821,7 @@ func CheckSchemaExists(ctx context.Context, db *pgxpool.Pool, schema string) (bo return exists, nil } -func GetTablesInSchema(ctx context.Context, db *pgxpool.Pool, schema string) ([]string, error) { +func GetTablesInSchema(ctx context.Context, db DBQuerier, schema string) ([]string, error) { sql, err := RenderSQL(SQLTemplates.GetTablesInSchema, nil) if err != nil { return nil, fmt.Errorf("failed to render GetTablesInSchema SQL: %w", err) @@ -834,7 +849,7 @@ func GetTablesInSchema(ctx context.Context, db *pgxpool.Pool, schema string) ([] return tables, nil } -func GetViewsInSchema(ctx context.Context, db *pgxpool.Pool, schema string) ([]string, error) { +func GetViewsInSchema(ctx context.Context, db DBQuerier, schema string) ([]string, error) { sql, err := RenderSQL(SQLTemplates.GetViewsInSchema, nil) if err != nil { return nil, fmt.Errorf("failed to render 
GetViewsInSchema SQL: %w", err) @@ -862,7 +877,7 @@ func GetViewsInSchema(ctx context.Context, db *pgxpool.Pool, schema string) ([]s return views, nil } -func GetFunctionsInSchema(ctx context.Context, db *pgxpool.Pool, schema string) ([]string, error) { +func GetFunctionsInSchema(ctx context.Context, db DBQuerier, schema string) ([]string, error) { sql, err := RenderSQL(SQLTemplates.GetFunctionsInSchema, nil) if err != nil { return nil, fmt.Errorf("failed to render GetFunctionsInSchema SQL: %w", err) @@ -890,7 +905,7 @@ func GetFunctionsInSchema(ctx context.Context, db *pgxpool.Pool, schema string) return functions, nil } -func GetIndicesInSchema(ctx context.Context, db *pgxpool.Pool, schema string) ([]string, error) { +func GetIndicesInSchema(ctx context.Context, db DBQuerier, schema string) ([]string, error) { sql, err := RenderSQL(SQLTemplates.GetIndicesInSchema, nil) if err != nil { return nil, fmt.Errorf("failed to render GetIndicesInSchema SQL: %w", err) @@ -918,7 +933,7 @@ func GetIndicesInSchema(ctx context.Context, db *pgxpool.Pool, schema string) ([ return indices, nil } -func CheckRepSetExists(ctx context.Context, db *pgxpool.Pool, repSet string) (bool, error) { +func CheckRepSetExists(ctx context.Context, db DBQuerier, repSet string) (bool, error) { sql, err := RenderSQL(SQLTemplates.CheckRepSetExists, nil) if err != nil { return false, fmt.Errorf("failed to render CheckRepSetExists SQL: %w", err) @@ -933,7 +948,7 @@ func CheckRepSetExists(ctx context.Context, db *pgxpool.Pool, repSet string) (bo return exists, nil } -func GetTablesInRepSet(ctx context.Context, db *pgxpool.Pool, repSet string) ([]string, error) { +func GetTablesInRepSet(ctx context.Context, db DBQuerier, repSet string) ([]string, error) { sql, err := RenderSQL(SQLTemplates.GetTablesInRepSet, nil) if err != nil { return nil, fmt.Errorf("failed to render GetTablesInRepSet SQL: %w", err) @@ -961,7 +976,7 @@ func GetTablesInRepSet(ctx context.Context, db *pgxpool.Pool, repSet string) ([] 
return tables, nil } -func GetRowCountEstimate(ctx context.Context, db *pgxpool.Pool, schema, table string) (int64, error) { +func GetRowCountEstimate(ctx context.Context, db DBQuerier, schema, table string) (int64, error) { sql, err := RenderSQL(SQLTemplates.EstimateRowCount, nil) if err != nil { return 0, fmt.Errorf("failed to render EstimateRowCount SQL: %w", err) @@ -976,7 +991,7 @@ func GetRowCountEstimate(ctx context.Context, db *pgxpool.Pool, schema, table st return count, nil } -func GetPkeyColumnTypes(ctx context.Context, db *pgxpool.Pool, schema, table string, pkeys []string) (map[string]string, error) { +func GetPkeyColumnTypes(ctx context.Context, db DBQuerier, schema, table string, pkeys []string) (map[string]string, error) { sql, err := RenderSQL(SQLTemplates.GetPkeyColumnTypes, nil) if err != nil { return nil, fmt.Errorf("failed to render GetPkeyColumnTypes SQL: %w", err) @@ -1004,22 +1019,20 @@ func GetPkeyColumnTypes(ctx context.Context, db *pgxpool.Pool, schema, table str return types, nil } -func GetPkeyType(ctx context.Context, db *pgxpool.Pool, schema, table, pkey string) (string, error) { +func GetPkeyType(ctx context.Context, db DBQuerier, schema, table, pkey string) (string, error) { sql, err := RenderSQL(SQLTemplates.GetPkeyType, nil) if err != nil { return "", fmt.Errorf("failed to render GetPkeyType SQL: %w", err) } var pkeyType string - err = db.QueryRow(ctx, sql, schema, table, pkey).Scan(&pkeyType) - if err != nil { + if err := db.QueryRow(ctx, sql, schema, table, pkey).Scan(&pkeyType); err != nil { return "", fmt.Errorf("query to get pkey type for '%s.%s.%s' failed: %w", schema, table, pkey, err) } - return pkeyType, nil } -func UpdateMetadata(ctx context.Context, db *pgxpool.Pool, schema, table string, totalRows int64, blockSize, numBlocks int, isComposite bool) error { +func UpdateMetadata(ctx context.Context, db DBQuerier, schema, table string, totalRows int64, blockSize, numBlocks int, isComposite bool) error { sql, err := 
RenderSQL(SQLTemplates.UpdateMetadata, nil) if err != nil { return fmt.Errorf("failed to render UpdateMetadata SQL: %w", err) @@ -1033,19 +1046,16 @@ func UpdateMetadata(ctx context.Context, db *pgxpool.Pool, schema, table string, return nil } -func ComputeLeafHashes(ctx context.Context, db *pgxpool.Pool, schema, table string, simpleKey bool, key []string, start []any, end []any) ([]byte, error) { - sql, err := BlockHashSQL(schema, table, key, "MTREE_LEAF_HASH" /* mode */) +func ComputeLeafHashes(ctx context.Context, db DBQuerier, schema, table string, simpleKey bool, key []string, start []any, end []any) ([]byte, error) { + sql, err := BlockHashSQL(schema, table, key, "MTREE_LEAF_HASH") if err != nil { return nil, err } - // Build args: $1 skipMinCheck, [start values], $skipMaxCheck, [end values] - // When no start, set skipMinCheck=true; likewise for end args := make([]any, 0, 2+len(start)+len(end)) skipMin := len(start) == 0 || start[0] == nil args = append(args, skipMin) args = append(args, start...) - skipMax := len(end) == 0 || end[0] == nil args = append(args, skipMax) args = append(args, end...) 
@@ -1057,7 +1067,7 @@ func ComputeLeafHashes(ctx context.Context, db *pgxpool.Pool, schema, table stri return leafHash, nil } -func UpdateLeafHashes(ctx context.Context, db *pgxpool.Pool, mtreeTable string, leafHash []byte, nodePosition int64) (int64, error) { +func UpdateLeafHashes(ctx context.Context, db DBQuerier, mtreeTable string, leafHash []byte, nodePosition int64) (int64, error) { data := map[string]interface{}{ "MtreeTable": mtreeTable, } @@ -1076,7 +1086,7 @@ func UpdateLeafHashes(ctx context.Context, db *pgxpool.Pool, mtreeTable string, return updatedNodePosition, nil } -func GetBlockRanges(ctx context.Context, db *pgxpool.Pool, mtreeTable string) ([]types.BlockRange, error) { +func GetBlockRanges(ctx context.Context, db DBQuerier, mtreeTable string) ([]types.BlockRange, error) { data := map[string]interface{}{ "MtreeTable": mtreeTable, } @@ -1108,7 +1118,7 @@ func GetBlockRanges(ctx context.Context, db *pgxpool.Pool, mtreeTable string) ([ return blockRanges, nil } -func ClearDirtyFlags(ctx context.Context, db *pgxpool.Pool, mtreeTable string, nodePositions []int64) error { +func ClearDirtyFlags(ctx context.Context, db DBQuerier, mtreeTable string, nodePositions []int64) error { data := map[string]interface{}{ "MtreeTable": mtreeTable, } @@ -1126,7 +1136,7 @@ func ClearDirtyFlags(ctx context.Context, db *pgxpool.Pool, mtreeTable string, n return nil } -func BuildParentNodes(ctx context.Context, db *pgxpool.Pool, mtreeTable string, nodeLevel int) (int, error) { +func BuildParentNodes(ctx context.Context, db DBQuerier, mtreeTable string, nodeLevel int) (int, error) { data := map[string]interface{}{ "MtreeTable": mtreeTable, } @@ -1145,7 +1155,7 @@ func BuildParentNodes(ctx context.Context, db *pgxpool.Pool, mtreeTable string, return count, nil } -func GetRootNode(ctx context.Context, db *pgxpool.Pool, mtreeTable string) (*types.RootNode, error) { +func GetRootNode(ctx context.Context, db DBQuerier, mtreeTable string) (*types.RootNode, error) { data := 
map[string]interface{}{ "MtreeTable": mtreeTable, } @@ -1164,7 +1174,7 @@ func GetRootNode(ctx context.Context, db *pgxpool.Pool, mtreeTable string) (*typ return &rootNode, nil } -func GetNodeChildren(ctx context.Context, db *pgxpool.Pool, mtreeTable string, nodeLevel, nodePosition int) ([]types.NodeChild, error) { +func GetNodeChildren(ctx context.Context, db DBQuerier, mtreeTable string, nodeLevel, nodePosition int) ([]types.NodeChild, error) { data := map[string]interface{}{ "MtreeTable": mtreeTable, } @@ -1196,7 +1206,7 @@ func GetNodeChildren(ctx context.Context, db *pgxpool.Pool, mtreeTable string, n return children, nil } -func GetLeafRanges(ctx context.Context, db *pgxpool.Pool, mtreeTable string, nodePositions []int64, simplePrimaryKey bool, key []string) ([]types.LeafRange, error) { +func GetLeafRanges(ctx context.Context, db DBQuerier, mtreeTable string, nodePositions []int64, simplePrimaryKey bool, key []string) ([]types.LeafRange, error) { if simplePrimaryKey { data := map[string]interface{}{ "MtreeTable": mtreeTable, @@ -1282,7 +1292,7 @@ func GetLeafRanges(ctx context.Context, db *pgxpool.Pool, mtreeTable string, nod return ranges, nil } -func GetRowCountEstimateFromMetadata(ctx context.Context, db *pgxpool.Pool, schema, table string) (int64, error) { +func GetRowCountEstimateFromMetadata(ctx context.Context, db DBQuerier, schema, table string) (int64, error) { sql, err := RenderSQL(SQLTemplates.GetRowCountEstimate, nil) if err != nil { return 0, fmt.Errorf("failed to render GetRowCountEstimate SQL: %w", err) @@ -1297,34 +1307,46 @@ func GetRowCountEstimateFromMetadata(ctx context.Context, db *pgxpool.Pool, sche return count, nil } -func GetMaxValComposite(ctx context.Context, db *pgxpool.Pool, schema, table, pkeyCols, pkeyValues string) ([]interface{}, error) { +func GetMaxValComposite(ctx context.Context, db DBQuerier, schema, table string, pkeyCols []string, pkeyValues []any) ([]interface{}, error) { + cols := make([]string, len(pkeyCols)) + for 
i, c := range pkeyCols { + cols[i] = pgx.Identifier{c}.Sanitize() + } + colsStr := strings.Join(cols, ", ") + + valsPh := make([]string, len(pkeyValues)) + args := make([]any, len(pkeyValues)) + for i, v := range pkeyValues { + valsPh[i] = fmt.Sprintf("$%d", i+1) + args[i] = v + } + valsStr := strings.Join(valsPh, ", ") + data := map[string]interface{}{ "SchemaIdent": pgx.Identifier{schema}.Sanitize(), "TableIdent": pgx.Identifier{table}.Sanitize(), - "PkeyCols": pkeyCols, - "PkeyValues": pkeyValues, + "PkeyCols": colsStr, + "PkeyValues": fmt.Sprintf("ROW(%s)", valsStr), } - sql, err := RenderSQL(SQLTemplates.GetMaxValComposite, data) if err != nil { return nil, fmt.Errorf("failed to render GetMaxValComposite SQL: %w", err) } - - // Scan each key attribute into separate destinations - // Note: pkeyCols is already a comma-separated list of sanitised identifiers - numCols := strings.Count(pkeyCols, ",") + 1 - dest := make([]interface{}, numCols) - destPtrs := make([]interface{}, numCols) + dest := make([]interface{}, len(pkeyCols)) + destPtrs := make([]interface{}, len(pkeyCols)) for i := range destPtrs { destPtrs[i] = &dest[i] } - if err := db.QueryRow(ctx, sql).Scan(destPtrs...); err != nil { + if err := db.QueryRow(ctx, sql, args...).Scan(destPtrs...); err != nil { + if err == pgx.ErrNoRows { + return nil, nil + } return nil, fmt.Errorf("query to get max val composite for '%s.%s' failed: %w", schema, table, err) } return dest, nil } -func UpdateMaxVal(ctx context.Context, db *pgxpool.Pool, mtreeTable string, rangeEnd interface{}, nodePosition int64) error { +func UpdateMaxVal(ctx context.Context, db DBQuerier, mtreeTable string, rangeEnd interface{}, nodePosition int64) error { data := map[string]interface{}{ "MtreeTable": mtreeTable, } @@ -1342,7 +1364,7 @@ func UpdateMaxVal(ctx context.Context, db *pgxpool.Pool, mtreeTable string, rang return nil } -func GetMaxValSimple(ctx context.Context, db *pgxpool.Pool, schema, table, key string, rangeStart interface{}) 
(interface{}, error) { +func GetMaxValSimple(ctx context.Context, db DBQuerier, schema, table, key string, rangeStart interface{}) (interface{}, error) { data := map[string]interface{}{ "SchemaIdent": pgx.Identifier{schema}.Sanitize(), "TableIdent": pgx.Identifier{table}.Sanitize(), @@ -1363,7 +1385,7 @@ func GetMaxValSimple(ctx context.Context, db *pgxpool.Pool, schema, table, key s return maxVal, nil } -func GetCountComposite(ctx context.Context, db *pgxpool.Pool, schema, table, whereClause string) (int64, error) { +func GetCountComposite(ctx context.Context, db DBQuerier, schema, table, whereClause string) (int64, error) { data := map[string]interface{}{ "SchemaIdent": pgx.Identifier{schema}.Sanitize(), "TableIdent": pgx.Identifier{table}.Sanitize(), @@ -1384,7 +1406,7 @@ func GetCountComposite(ctx context.Context, db *pgxpool.Pool, schema, table, whe return count, nil } -func GetCountSimple(ctx context.Context, db *pgxpool.Pool, schema, table, key, pkeyType string, rangeStart, rangeEnd interface{}) (int64, error) { +func GetCountSimple(ctx context.Context, db DBQuerier, schema, table, key, pkeyType string, rangeStart, rangeEnd interface{}) (int64, error) { data := map[string]interface{}{ "SchemaIdent": pgx.Identifier{schema}.Sanitize(), "TableIdent": pgx.Identifier{table}.Sanitize(), @@ -1406,7 +1428,7 @@ func GetCountSimple(ctx context.Context, db *pgxpool.Pool, schema, table, key, p return count, nil } -func GetBlockRowCount(ctx context.Context, tx pgx.Tx, schema string, table string, keyColumns []string, isComposite bool, start, end []any) (int64, error) { +func GetBlockRowCount(ctx context.Context, db DBQuerier, schema string, table string, keyColumns []string, isComposite bool, start, end []any) (int64, error) { var whereClause string var args []any @@ -1459,7 +1481,7 @@ func GetBlockRowCount(ctx context.Context, tx pgx.Tx, schema string, table strin } var count int64 - err = tx.QueryRow(ctx, sql, args...).Scan(&count) + err = db.QueryRow(ctx, sql, 
args...).Scan(&count) if err != nil { return 0, fmt.Errorf("query to get block row count for '%s.%s' failed: %w", schema, table, err) } @@ -1467,60 +1489,7 @@ func GetBlockRowCount(ctx context.Context, tx pgx.Tx, schema string, table strin return count, nil } -func FindBlocksToSplit(ctx context.Context, conn *pgx.Conn, mtreeTable string, insertsSinceUpdate int, nodePositions []int64, simplePrimaryKey bool) ([]types.BlockRange, error) { - data := map[string]interface{}{ - "MtreeTable": mtreeTable, - } - sql, err := RenderSQL(SQLTemplates.FindBlocksToSplit, data) - if err != nil { - return nil, fmt.Errorf("failed to render FindBlocksToSplit SQL: %w", err) - } - - rows, err := conn.Query(ctx, sql, insertsSinceUpdate, nodePositions) - if err != nil { - return nil, fmt.Errorf("query to find blocks to split for '%s' failed: %w", mtreeTable, err) - } - defer rows.Close() - - var blocks []types.BlockRange - for rows.Next() { - var br types.BlockRange - if simplePrimaryKey { - var start, end any - if err := rows.Scan(&br.NodePosition, &start, &end); err != nil { - return nil, fmt.Errorf("failed to scan block to split: %w", err) - } - if start != nil { - br.RangeStart = []any{start} - } - if end != nil { - br.RangeEnd = []any{end} - } - } else { - var start, end pgtype.CompositeType - if err := rows.Scan(&br.NodePosition, &start, &end); err != nil { - return nil, fmt.Errorf("failed to scan block to split: %w", err) - } - if start.Get() != nil { - var values []any - start.AssignTo(&values) - br.RangeStart = values - } - if end.Get() != nil { - var values []any - end.AssignTo(&values) - br.RangeEnd = values - } - } - blocks = append(blocks, br) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating over blocks to split: %w", err) - } - return blocks, nil -} - -func GetDirtyAndNewBlocksTx(ctx context.Context, tx pgx.Tx, mtreeTable string, simplePrimaryKey bool, key []string) ([]types.BlockRange, error) { +func GetDirtyAndNewBlocks(ctx 
context.Context, db DBQuerier, mtreeTable string, simplePrimaryKey bool, key []string) ([]types.BlockRange, error) { if simplePrimaryKey { data := map[string]interface{}{ "MtreeTable": mtreeTable, @@ -1529,7 +1498,7 @@ func GetDirtyAndNewBlocksTx(ctx context.Context, tx pgx.Tx, mtreeTable string, s if err != nil { return nil, fmt.Errorf("failed to render GetDirtyAndNewBlocks SQL: %w", err) } - rows, err := tx.Query(ctx, sql) + rows, err := db.Query(ctx, sql) if err != nil { return nil, fmt.Errorf("query to get dirty and new blocks for '%s' failed: %w", mtreeTable, err) } @@ -1575,7 +1544,7 @@ func GetDirtyAndNewBlocksTx(ctx context.Context, tx pgx.Tx, mtreeTable string, s return nil, fmt.Errorf("failed to render GetDirtyAndNewBlocksExpanded SQL: %w", err) } - rows, err := tx.Query(ctx, sql) + rows, err := db.Query(ctx, sql) if err != nil { return nil, fmt.Errorf("query to get dirty and new blocks for '%s' failed: %w", mtreeTable, err) } @@ -1620,7 +1589,7 @@ func GetDirtyAndNewBlocksTx(ctx context.Context, tx pgx.Tx, mtreeTable string, s return blocks, nil } -func FindBlocksToSplitTx(ctx context.Context, tx pgx.Tx, mtreeTable string, insertsSinceUpdate int, nodePositions []int64, simplePrimaryKey bool, key []string) ([]types.BlockRange, error) { +func FindBlocksToSplit(ctx context.Context, db DBQuerier, mtreeTable string, insertsSinceUpdate int, nodePositions []int64, simplePrimaryKey bool, key []string) ([]types.BlockRange, error) { if simplePrimaryKey { data := map[string]interface{}{ "MtreeTable": mtreeTable, @@ -1629,7 +1598,7 @@ func FindBlocksToSplitTx(ctx context.Context, tx pgx.Tx, mtreeTable string, inse if err != nil { return nil, fmt.Errorf("failed to render FindBlocksToSplit SQL: %w", err) } - rows, err := tx.Query(ctx, sql, insertsSinceUpdate, nodePositions) + rows, err := db.Query(ctx, sql, insertsSinceUpdate, nodePositions) if err != nil { return nil, fmt.Errorf("query to find blocks to split for '%s' failed: %w", mtreeTable, err) } @@ -1675,7 
+1644,7 @@ func FindBlocksToSplitTx(ctx context.Context, tx pgx.Tx, mtreeTable string, inse return nil, fmt.Errorf("failed to render FindBlocksToSplitExpanded SQL: %w", err) } - rows, err := tx.Query(ctx, sql, insertsSinceUpdate, nodePositions) + rows, err := db.Query(ctx, sql, insertsSinceUpdate, nodePositions) if err != nil { return nil, fmt.Errorf("query to find blocks to split for '%s' failed: %w", mtreeTable, err) } @@ -1719,7 +1688,7 @@ func FindBlocksToSplitTx(ctx context.Context, tx pgx.Tx, mtreeTable string, inse return blocks, nil } -func GetMaxNodePositionTx(ctx context.Context, tx pgx.Tx, mtreeTable string) (int64, error) { +func GetMaxNodePosition(ctx context.Context, db DBQuerier, mtreeTable string) (int64, error) { data := map[string]interface{}{ "MtreeTable": mtreeTable, } @@ -1728,13 +1697,13 @@ func GetMaxNodePositionTx(ctx context.Context, tx pgx.Tx, mtreeTable string) (in return 0, fmt.Errorf("failed to render GetMaxNodePosition SQL: %w", err) } var pos int64 - if err := tx.QueryRow(ctx, sql).Scan(&pos); err != nil { + if err := db.QueryRow(ctx, sql).Scan(&pos); err != nil { return 0, fmt.Errorf("query to get max node position for '%s' failed: %w", mtreeTable, err) } return pos, nil } -func UpdateBlockRangeEndTx(ctx context.Context, tx pgx.Tx, mtreeTable string, rangeEnd any, nodePosition int64) error { +func UpdateBlockRangeEnd(ctx context.Context, db DBQuerier, mtreeTable string, rangeEnd any, nodePosition int64) error { data := map[string]interface{}{ "MtreeTable": mtreeTable, "RangeEndExpr": "$1", @@ -1744,69 +1713,13 @@ func UpdateBlockRangeEndTx(ctx context.Context, tx pgx.Tx, mtreeTable string, ra if err != nil { return fmt.Errorf("failed to render UpdateBlockRangeEnd SQL: %w", err) } - if _, err := tx.Exec(ctx, sql, rangeEnd, nodePosition); err != nil { + if _, err := db.Exec(ctx, sql, rangeEnd, nodePosition); err != nil { return fmt.Errorf("query to update block range end for '%s' failed: %w", mtreeTable, err) } return nil } -func 
InsertBlockRangesTx(ctx context.Context, tx pgx.Tx, mtreeTable string, nodePosition int64, rangeStart, rangeEnd interface{}) error { - data := map[string]interface{}{ - "MtreeTable": mtreeTable, - } - sql, err := RenderSQL(SQLTemplates.InsertBlockRanges, data) - if err != nil { - return fmt.Errorf("failed to render InsertBlockRanges SQL: %w", err) - } - if _, err := tx.Exec(ctx, sql, nodePosition, rangeStart, rangeEnd); err != nil { - return fmt.Errorf("query to insert block ranges for '%s' failed: %w", mtreeTable, err) - } - return nil -} - -func InsertCompositeBlockRangesTx(ctx context.Context, tx pgx.Tx, mtreeTable string, nodePosition int64, startVals, endVals []any) error { - startPh := make([]string, len(startVals)) - args := make([]any, 0, 1+len(startVals)+len(endVals)) - args = append(args, nodePosition) - argIdx := 2 - for i := range startVals { - startPh[i] = fmt.Sprintf("$%d", argIdx) - args = append(args, startVals[i]) - argIdx++ - } - - var endExpr string - if endVals == nil { - endExpr = "NULL" - } else { - endPh := make([]string, len(endVals)) - for i := range endVals { - endPh[i] = fmt.Sprintf("$%d", argIdx) - args = append(args, endVals[i]) - argIdx++ - } - endExpr = fmt.Sprintf("ROW(%s)", strings.Join(endPh, ", ")) - } - startExpr := fmt.Sprintf("ROW(%s)", strings.Join(startPh, ", ")) - - data := map[string]interface{}{ - "MtreeTable": mtreeTable, - "StartExpr": startExpr, - "EndExpr": endExpr, - } - - stmt, err := RenderSQL(SQLTemplates.InsertCompositeBlockRanges, data) - if err != nil { - return fmt.Errorf("failed to render InsertCompositeBlockRanges SQL: %w", err) - } - - if _, err := tx.Exec(ctx, stmt, args...); err != nil { - return fmt.Errorf("query to insert composite block ranges for '%s' failed: %w", mtreeTable, err) - } - return nil -} - -func UpdateBlockRangeEndCompositeTx(ctx context.Context, tx pgx.Tx, mtreeTable string, compositeTypeName string, endVals []any, pos int64) error { +func UpdateBlockRangeEndComposite(ctx context.Context, 
db DBQuerier, mtreeTable string, compositeTypeName string, endVals []any, pos int64) error { args := []any{} isNull := len(endVals) == 0 @@ -1832,13 +1745,13 @@ func UpdateBlockRangeEndCompositeTx(ctx context.Context, tx pgx.Tx, mtreeTable s return fmt.Errorf("failed to render UpdateBlockRangeEndCompositeTx SQL: %w", err) } - if _, err := tx.Exec(ctx, sql, args...); err != nil { + if _, err := db.Exec(ctx, sql, args...); err != nil { return fmt.Errorf("query to update composite block range end for '%s' failed: %w", mtreeTable, err) } return nil } -func UpdateBlockRangeStartTx(ctx context.Context, tx pgx.Tx, mtreeTable string, rangeStart any, nodePosition int64) error { +func UpdateBlockRangeStart(ctx context.Context, db DBQuerier, mtreeTable string, rangeStart any, nodePosition int64) error { data := map[string]interface{}{ "MtreeTable": mtreeTable, "RangeStartExpr": "$1", @@ -1848,13 +1761,13 @@ func UpdateBlockRangeStartTx(ctx context.Context, tx pgx.Tx, mtreeTable string, if err != nil { return fmt.Errorf("failed to render UpdateBlockRangeStart SQL: %w", err) } - if _, err := tx.Exec(ctx, sql, rangeStart, nodePosition); err != nil { + if _, err := db.Exec(ctx, sql, rangeStart, nodePosition); err != nil { return fmt.Errorf("query to update block range start for '%s' failed: %w", mtreeTable, err) } return nil } -func UpdateBlockRangeStartCompositeTx(ctx context.Context, tx pgx.Tx, mtreeTable string, compositeTypeName string, startVals []any, pos int64) error { +func UpdateBlockRangeStartComposite(ctx context.Context, db DBQuerier, mtreeTable string, compositeTypeName string, startVals []any, pos int64) error { args := []any{} isNull := len(startVals) == 0 @@ -1880,13 +1793,13 @@ func UpdateBlockRangeStartCompositeTx(ctx context.Context, tx pgx.Tx, mtreeTable return fmt.Errorf("failed to render UpdateBlockRangeStartCompositeTx SQL: %w", err) } - if _, err := tx.Exec(ctx, sql, args...); err != nil { + if _, err := db.Exec(ctx, sql, args...); err != nil { return 
fmt.Errorf("query to update composite block range start for '%s' failed: %w", mtreeTable, err) } return nil } -func GetSplitPointSimpleTx(ctx context.Context, tx pgx.Tx, schema, table, key, pkeyType string, rangeStart, rangeEnd interface{}, offset int64) (interface{}, error) { +func GetSplitPointSimple(ctx context.Context, db DBQuerier, schema, table, key, pkeyType string, rangeStart, rangeEnd interface{}, offset int64) (interface{}, error) { data := map[string]interface{}{ "SchemaIdent": pgx.Identifier{schema}.Sanitize(), "TableIdent": pgx.Identifier{table}.Sanitize(), @@ -1898,13 +1811,13 @@ func GetSplitPointSimpleTx(ctx context.Context, tx pgx.Tx, schema, table, key, p return nil, fmt.Errorf("failed to render GetSplitPointSimple SQL: %w", err) } var splitPoint interface{} - if err := tx.QueryRow(ctx, sql, rangeStart, rangeEnd, pkeyType, offset).Scan(&splitPoint); err != nil { + if err := db.QueryRow(ctx, sql, rangeStart, rangeEnd, pkeyType, offset).Scan(&splitPoint); err != nil { return nil, fmt.Errorf("query to get split point simple for '%s.%s' failed: %w", schema, table, err) } return splitPoint, nil } -func GetSplitPointCompositeTx(ctx context.Context, tx pgx.Tx, schema, table string, pkeyCols []string, startVals, endVals []any, offset int64) ([]interface{}, error) { +func GetSplitPointComposite(ctx context.Context, db DBQuerier, schema, table string, pkeyCols []string, startVals, endVals []any, offset int64) ([]interface{}, error) { cols := make([]string, len(pkeyCols)) for i, c := range pkeyCols { cols[i] = pgx.Identifier{c}.Sanitize() @@ -1971,7 +1884,7 @@ func GetSplitPointCompositeTx(ctx context.Context, tx pgx.Tx, schema, table stri destPtrs[i] = &sp[i] } - if err := tx.QueryRow(ctx, sql, args...).Scan(destPtrs...); err != nil { + if err := db.QueryRow(ctx, sql, args...).Scan(destPtrs...); err != nil { if err == pgx.ErrNoRows { return nil, nil } @@ -1980,63 +1893,7 @@ func GetSplitPointCompositeTx(ctx context.Context, tx pgx.Tx, schema, table stri 
return sp, nil } -func GetMaxValSimpleTx(ctx context.Context, tx pgx.Tx, schema, table, key string, rangeStart interface{}) (interface{}, error) { - data := map[string]interface{}{ - "SchemaIdent": pgx.Identifier{schema}.Sanitize(), - "TableIdent": pgx.Identifier{table}.Sanitize(), - "Key": key, - } - sql, err := RenderSQL(SQLTemplates.GetMaxValSimple, data) - if err != nil { - return nil, fmt.Errorf("failed to render GetMaxValSimple SQL: %w", err) - } - var maxVal interface{} - if err := tx.QueryRow(ctx, sql, rangeStart).Scan(&maxVal); err != nil { - return nil, fmt.Errorf("query to get max val simple for '%s.%s' failed: %w", schema, table, err) - } - return maxVal, nil -} - -func GetMaxValCompositeTx(ctx context.Context, tx pgx.Tx, schema, table string, pkeyCols []string, pkeyValues []any) ([]interface{}, error) { - cols := make([]string, len(pkeyCols)) - for i, c := range pkeyCols { - cols[i] = pgx.Identifier{c}.Sanitize() - } - colsStr := strings.Join(cols, ", ") - - valsPh := make([]string, len(pkeyValues)) - args := make([]any, len(pkeyValues)) - for i, v := range pkeyValues { - valsPh[i] = fmt.Sprintf("$%d", i+1) - args[i] = v - } - valsStr := strings.Join(valsPh, ", ") - - data := map[string]interface{}{ - "SchemaIdent": pgx.Identifier{schema}.Sanitize(), - "TableIdent": pgx.Identifier{table}.Sanitize(), - "PkeyCols": colsStr, - "PkeyValues": fmt.Sprintf("ROW(%s)", valsStr), - } - sql, err := RenderSQL(SQLTemplates.GetMaxValComposite, data) - if err != nil { - return nil, fmt.Errorf("failed to render GetMaxValComposite SQL: %w", err) - } - dest := make([]interface{}, len(pkeyCols)) - destPtrs := make([]interface{}, len(pkeyCols)) - for i := range destPtrs { - destPtrs[i] = &dest[i] - } - if err := tx.QueryRow(ctx, sql, args...).Scan(destPtrs...); err != nil { - if err == pgx.ErrNoRows { - return nil, nil - } - return nil, fmt.Errorf("query to get max val composite for '%s.%s' failed: %w", schema, table, err) - } - return dest, nil -} - -func 
GetMinValCompositeTx(ctx context.Context, tx pgx.Tx, schema, table string, pkeyCols []string) ([]interface{}, error) { +func GetMinValComposite(ctx context.Context, db DBQuerier, schema, table string, pkeyCols []string) ([]interface{}, error) { cols := make([]string, len(pkeyCols)) for i, c := range pkeyCols { cols[i] = pgx.Identifier{c}.Sanitize() @@ -2058,7 +1915,7 @@ func GetMinValCompositeTx(ctx context.Context, tx pgx.Tx, schema, table string, for i := range destPtrs { destPtrs[i] = &dest[i] } - if err := tx.QueryRow(ctx, sql).Scan(destPtrs...); err != nil { + if err := db.QueryRow(ctx, sql).Scan(destPtrs...); err != nil { if err == pgx.ErrNoRows { return nil, nil } @@ -2067,7 +1924,7 @@ func GetMinValCompositeTx(ctx context.Context, tx pgx.Tx, schema, table string, return dest, nil } -func GetMinValSimpleTx(ctx context.Context, tx pgx.Tx, schema, table, key string) (interface{}, error) { +func GetMinValSimple(ctx context.Context, db DBQuerier, schema, table, key string) (interface{}, error) { data := map[string]interface{}{ "SchemaIdent": pgx.Identifier{schema}.Sanitize(), "TableIdent": pgx.Identifier{table}.Sanitize(), @@ -2078,7 +1935,7 @@ func GetMinValSimpleTx(ctx context.Context, tx pgx.Tx, schema, table, key string return nil, fmt.Errorf("failed to render GetMinValSimple SQL: %w", err) } var minVal interface{} - if err := tx.QueryRow(ctx, sql).Scan(&minVal); err != nil { + if err := db.QueryRow(ctx, sql).Scan(&minVal); err != nil { if err == pgx.ErrNoRows { return nil, nil } @@ -2087,70 +1944,15 @@ func GetMinValSimpleTx(ctx context.Context, tx pgx.Tx, schema, table, key string return minVal, nil } -func ClearDirtyFlagsTx(ctx context.Context, tx pgx.Tx, mtreeTable string, nodePositions []int64) error { - data := map[string]interface{}{ - "MtreeTable": mtreeTable, - } - sql, err := RenderSQL(SQLTemplates.ClearDirtyFlags, data) - if err != nil { - return fmt.Errorf("failed to render ClearDirtyFlags SQL: %w", err) - } - if _, err := tx.Exec(ctx, sql, 
nodePositions); err != nil { - return fmt.Errorf("query to clear dirty flags for '%s' failed: %w", mtreeTable, err) - } - return nil -} - -func DeleteParentNodesTx(ctx context.Context, tx pgx.Tx, mtreeTable string) error { - data := map[string]interface{}{ - "MtreeTable": mtreeTable, - } - sql, err := RenderSQL(SQLTemplates.DeleteParentNodes, data) - if err != nil { - return fmt.Errorf("failed to render DeleteParentNodes SQL: %w", err) - } - if _, err := tx.Exec(ctx, sql); err != nil { - return fmt.Errorf("query to delete parent nodes for '%s' failed: %w", mtreeTable, err) - } - return nil -} - -func BuildParentNodesTx(ctx context.Context, tx pgx.Tx, mtreeTable string, nodeLevel int) (int, error) { - data := map[string]interface{}{ - "MtreeTable": mtreeTable, - } - sql, err := RenderSQL(SQLTemplates.BuildParentNodes, data) - if err != nil { - return 0, fmt.Errorf("failed to render BuildParentNodes SQL: %w", err) - } - var count int - if err := tx.QueryRow(ctx, sql, nodeLevel).Scan(&count); err != nil { - return 0, fmt.Errorf("query to build parent nodes for '%s' failed: %w", mtreeTable, err) - } - return count, nil -} - -func GetPkeyTypeTx(ctx context.Context, tx pgx.Tx, schema, table, pkey string) (string, error) { - sql, err := RenderSQL(SQLTemplates.GetPkeyType, nil) - if err != nil { - return "", fmt.Errorf("failed to render GetPkeyType SQL: %w", err) - } - var pkeyType string - if err := tx.QueryRow(ctx, sql, schema, table, pkey).Scan(&pkeyType); err != nil { - return "", fmt.Errorf("query to get pkey type for '%s.%s.%s' failed: %w", schema, table, pkey, err) - } - return pkeyType, nil -} - -func FindBlocksToMergeComposite(ctx context.Context, db *pgxpool.Pool, mtreeTable, schema, table string, keyColumns []string, nodePositions []int64, mergeThreshold float64) ([]types.BlockRange, error) { +func FindBlocksToMergeComposite(ctx context.Context, db DBQuerier, mtreeTable, schema, table string, keyColumns []string, nodePositions []int64, mergeThreshold float64) 
([]types.BlockRange, error) { return findBlocksToMerge(ctx, db, mtreeTable, schema, table, keyColumns, false, nodePositions, mergeThreshold) } -func FindBlocksToMergeSimple(ctx context.Context, db *pgxpool.Pool, mtreeTable, schema, table, key string, nodePositions []int64, mergeThreshold float64) ([]types.BlockRange, error) { +func FindBlocksToMergeSimple(ctx context.Context, db DBQuerier, mtreeTable, schema, table, key string, nodePositions []int64, mergeThreshold float64) ([]types.BlockRange, error) { return findBlocksToMerge(ctx, db, mtreeTable, schema, table, []string{key}, true, nodePositions, mergeThreshold) } -func findBlocksToMerge(ctx context.Context, db DBTX, mtreeTable, schema, table string, key []string, simplePrimaryKey bool, nodePositions []int64, mergeThreshold float64) ([]types.BlockRange, error) { +func findBlocksToMerge(ctx context.Context, db DBQuerier, mtreeTable, schema, table string, key []string, simplePrimaryKey bool, nodePositions []int64, mergeThreshold float64) ([]types.BlockRange, error) { var queryArgs []any usePositionFilter := len(nodePositions) > 0 @@ -2283,7 +2085,7 @@ func findBlocksToMerge(ctx context.Context, db DBTX, mtreeTable, schema, table s return blocks, nil } -func GetBlockCountComposite(ctx context.Context, db *pgxpool.Pool, mtreeTable, schema, table, pkeyCols string, nodePosition int64) (*types.BlockCountComposite, error) { +func GetBlockCountComposite(ctx context.Context, db DBQuerier, mtreeTable, schema, table, pkeyCols string, nodePosition int64) (*types.BlockCountComposite, error) { data := map[string]interface{}{ "MtreeTable": mtreeTable, "SchemaIdent": pgx.Identifier{schema}.Sanitize(), @@ -2305,7 +2107,7 @@ func GetBlockCountComposite(ctx context.Context, db *pgxpool.Pool, mtreeTable, s return &blockCount, nil } -func GetBlockCountSimple(ctx context.Context, db *pgxpool.Pool, mtreeTable, schema, table, key string, nodePosition int64) (*types.BlockCountSimple, error) { +func GetBlockCountSimple(ctx context.Context, 
db DBQuerier, mtreeTable, schema, table, key string, nodePosition int64) (*types.BlockCountSimple, error) { data := map[string]interface{}{ "MtreeTable": mtreeTable, "SchemaIdent": pgx.Identifier{schema}.Sanitize(), @@ -2327,21 +2129,21 @@ func GetBlockCountSimple(ctx context.Context, db *pgxpool.Pool, mtreeTable, sche return &blockCount, nil } -func GetBlockSizeFromMetadata(ctx context.Context, pool DBTX, schema, table string) (int, error) { +func GetBlockSizeFromMetadata(ctx context.Context, db DBQuerier, schema, table string) (int, error) { data := map[string]interface{}{} query, err := RenderSQL(SQLTemplates.GetBlockSizeFromMetadata, data) if err != nil { return 0, fmt.Errorf("failed to render GetBlockSizeFromMetadata SQL: %w", err) } var blockSize int - err = pool.QueryRow(ctx, query, schema, table).Scan(&blockSize) + err = db.QueryRow(ctx, query, schema, table).Scan(&blockSize) if err != nil { return 0, fmt.Errorf("query to get block size from metadata for '%s.%s' failed: %w", schema, table, err) } return blockSize, nil } -func GetMaxNodeLevel(ctx context.Context, pool DBTX, mtreeTable string) (int, error) { +func GetMaxNodeLevel(ctx context.Context, db DBQuerier, mtreeTable string) (int, error) { data := map[string]interface{}{ "MtreeTable": mtreeTable, } @@ -2352,7 +2154,7 @@ func GetMaxNodeLevel(ctx context.Context, pool DBTX, mtreeTable string) (int, er } var maxLevel int - err = pool.QueryRow(ctx, sql).Scan(&maxLevel) + err = db.QueryRow(ctx, sql).Scan(&maxLevel) if err != nil { return 0, fmt.Errorf("query to get max node level for '%s' failed: %w", mtreeTable, err) } @@ -2360,13 +2162,13 @@ func GetMaxNodeLevel(ctx context.Context, pool DBTX, mtreeTable string) (int, er return maxLevel, nil } -func DropXORFunction(ctx context.Context, pool DBTX) error { +func DropXORFunction(ctx context.Context, db DBQuerier) error { sql, err := RenderSQL(SQLTemplates.DropXORFunction, nil) if err != nil { return fmt.Errorf("failed to render DropXORFunction SQL: %w", 
err) } - _, err = pool.Exec(ctx, sql) + _, err = db.Exec(ctx, sql) if err != nil { return fmt.Errorf("query to drop xor function failed: %w", err) } @@ -2374,13 +2176,13 @@ func DropXORFunction(ctx context.Context, pool DBTX) error { return nil } -func DropMetadataTable(ctx context.Context, pool DBTX) error { +func DropMetadataTable(ctx context.Context, db DBQuerier) error { sql, err := RenderSQL(SQLTemplates.DropMetadataTable, nil) if err != nil { return fmt.Errorf("failed to render DropMetadataTable SQL: %w", err) } - _, err = pool.Exec(ctx, sql) + _, err = db.Exec(ctx, sql) if err != nil { return fmt.Errorf("query to drop metadata table failed: %w", err) } @@ -2388,7 +2190,7 @@ func DropMetadataTable(ctx context.Context, pool DBTX) error { return nil } -func DropMtreeTable(ctx context.Context, pool DBTX, mtreeTable string) error { +func DropMtreeTable(ctx context.Context, db DBQuerier, mtreeTable string) error { data := map[string]interface{}{ "MtreeTable": mtreeTable, } @@ -2398,7 +2200,7 @@ func DropMtreeTable(ctx context.Context, pool DBTX, mtreeTable string) error { return fmt.Errorf("failed to render DropMtreeTable SQL: %w", err) } - _, err = pool.Exec(ctx, sql) + _, err = db.Exec(ctx, sql) if err != nil { return fmt.Errorf("query to drop mtree table for '%s' failed: %w", mtreeTable, err) } @@ -2406,7 +2208,7 @@ func DropMtreeTable(ctx context.Context, pool DBTX, mtreeTable string) error { return nil } -func DeleteParentNodes(ctx context.Context, pool DBTX, mtreeTable string) error { +func DeleteParentNodes(ctx context.Context, db DBQuerier, mtreeTable string) error { data := map[string]interface{}{ "MtreeTable": mtreeTable, } @@ -2416,7 +2218,7 @@ func DeleteParentNodes(ctx context.Context, pool DBTX, mtreeTable string) error return fmt.Errorf("failed to render DeleteParentNodes SQL: %w", err) } - _, err = pool.Exec(ctx, sql) + _, err = db.Exec(ctx, sql) if err != nil { return fmt.Errorf("query to delete parent nodes for '%s' failed: %w", mtreeTable, err) 
} @@ -2424,11 +2226,11 @@ func DeleteParentNodes(ctx context.Context, pool DBTX, mtreeTable string) error return nil } -func FindBlocksToMergeTx(ctx context.Context, tx pgx.Tx, mtreeTableName string, simplePrimaryKey bool, schema string, table string, key []string, mergeThreshold float64, blockPositions []int64) ([]types.BlockRange, error) { - return findBlocksToMerge(ctx, tx, mtreeTableName, schema, table, key, simplePrimaryKey, blockPositions, mergeThreshold) +func FindBlocksToMerge(ctx context.Context, db DBQuerier, mtreeTableName string, simplePrimaryKey bool, schema string, table string, key []string, mergeThreshold float64, blockPositions []int64) ([]types.BlockRange, error) { + return findBlocksToMerge(ctx, db, mtreeTableName, schema, table, key, simplePrimaryKey, blockPositions, mergeThreshold) } -func GetBlockWithCountTx(ctx context.Context, tx pgx.Tx, mtreeTable, schema, table string, key []string, isComposite bool, position int64) (*types.BlockRangeWithCount, error) { +func GetBlockWithCount(ctx context.Context, db DBQuerier, mtreeTable, schema, table string, key []string, isComposite bool, position int64) (*types.BlockRangeWithCount, error) { sanitizedKeys := make([]string, len(key)) for i, k := range key { sanitizedKeys[i] = pgx.Identifier{k}.Sanitize() @@ -2467,7 +2269,7 @@ func GetBlockWithCountTx(ctx context.Context, tx pgx.Tx, mtreeTable, schema, tab var count int64 var start, end any - row := tx.QueryRow(ctx, query, position) + row := db.QueryRow(ctx, query, position) if isComposite { // node_position, start attrs..., end attrs..., count @@ -2521,7 +2323,7 @@ func GetBlockWithCountTx(ctx context.Context, tx pgx.Tx, mtreeTable, schema, tab return &block, nil } -func UpdateNodePositionTx(ctx context.Context, tx pgx.Tx, mtreeTable string, oldPosition, newPosition int64) error { +func UpdateNodePosition(ctx context.Context, db DBQuerier, mtreeTable string, oldPosition, newPosition int64) error { data := map[string]interface{}{ "MtreeTable": 
pgx.Identifier{mtreeTable}.Sanitize(), } @@ -2529,11 +2331,11 @@ func UpdateNodePositionTx(ctx context.Context, tx pgx.Tx, mtreeTable string, old if err != nil { return fmt.Errorf("failed to render UpdateNodePosition SQL: %w", err) } - _, err = tx.Exec(ctx, query, newPosition, oldPosition) + _, err = db.Exec(ctx, query, newPosition, oldPosition) return err } -func DeleteBlockTx(ctx context.Context, tx pgx.Tx, mtreeTable string, position int64) error { +func DeleteBlock(ctx context.Context, db DBQuerier, mtreeTable string, position int64) error { data := map[string]interface{}{ "MtreeTable": pgx.Identifier{mtreeTable}.Sanitize(), } @@ -2541,11 +2343,11 @@ func DeleteBlockTx(ctx context.Context, tx pgx.Tx, mtreeTable string, position i if err != nil { return fmt.Errorf("failed to render DeleteBlock SQL: %w", err) } - _, err = tx.Exec(ctx, query, position) + _, err = db.Exec(ctx, query, position) return err } -func UpdateNodePositionsSequentialTx(ctx context.Context, tx pgx.Tx, mtreeTable string, startPosition int64) error { +func UpdateNodePositionsSequential(ctx context.Context, db DBQuerier, mtreeTable string, startPosition int64) error { data := map[string]interface{}{ "MtreeTable": pgx.Identifier{mtreeTable}.Sanitize(), } @@ -2553,11 +2355,11 @@ func UpdateNodePositionsSequentialTx(ctx context.Context, tx pgx.Tx, mtreeTable if err != nil { return fmt.Errorf("failed to render UpdateNodePositionsSequential SQL: %w", err) } - _, err = tx.Exec(ctx, query, startPosition, startPosition) + _, err = db.Exec(ctx, query, startPosition, startPosition) return err } -func ResetPositionsByStartTx(ctx context.Context, tx pgx.Tx, mtreeTable string, key []string, isComposite bool) error { +func ResetPositionsByStart(ctx context.Context, db DBQuerier, mtreeTable string, key []string, isComposite bool) error { data := map[string]any{ "MtreeTable": pgx.Identifier{mtreeTable}.Sanitize(), } @@ -2571,51 +2373,13 @@ func ResetPositionsByStartTx(ctx context.Context, tx pgx.Tx, mtreeTable 
string, if err != nil { return fmt.Errorf("failed to render ResetPositionsByStart SQL: %w", err) } - if _, err := tx.Exec(ctx, query); err != nil { + if _, err := db.Exec(ctx, query); err != nil { return fmt.Errorf("query to reset positions failed: %w", err) } return nil } -func ComputeLeafHashesTx(ctx context.Context, tx pgx.Tx, schema, table string, cols []string, simpleKey bool, key []string, start []any, end []any) ([]byte, error) { - // Re-use BlockHashSQL logic. Ensure to pass tx to Scan, not db. - sql, err := BlockHashSQL(schema, table, key, "MTREE_LEAF_HASH") - if err != nil { - return nil, err - } - - args := make([]any, 0, 2+len(start)+len(end)) - skipMin := len(start) == 0 || start[0] == nil - args = append(args, skipMin) - args = append(args, start...) - skipMax := len(end) == 0 || end[0] == nil - args = append(args, skipMax) - args = append(args, end...) - - var leafHash []byte - if err := tx.QueryRow(ctx, sql, args...).Scan(&leafHash); err != nil { - return nil, fmt.Errorf("query to compute leaf hashes for '%s.%s' failed: %w", schema, table, err) - } - return leafHash, nil -} - -func UpdateLeafHashesTx(ctx context.Context, tx pgx.Tx, mtreeTable string, leafHash []byte, nodePosition int64) (int64, error) { - data := map[string]interface{}{ - "MtreeTable": mtreeTable, - } - sql, err := RenderSQL(SQLTemplates.UpdateLeafHashes, data) - if err != nil { - return 0, fmt.Errorf("failed to render UpdateLeafHashes SQL: %w", err) - } - var updatedPos int64 - err = tx.QueryRow(ctx, sql, leafHash, nodePosition).Scan(&updatedPos) - if err != nil { - return 0, fmt.Errorf("query to update leaf hashes for '%s' failed: %w", mtreeTable, err) - } - return updatedPos, nil -} - -func UpdateNodePositionsTempTx(ctx context.Context, tx pgx.Tx, mtreeTable string, offset int64) error { +func UpdateNodePositionsTemp(ctx context.Context, db DBQuerier, mtreeTable string, offset int64) error { data := map[string]any{ "MtreeTable": pgx.Identifier{mtreeTable}.Sanitize(), } @@ 
-2623,13 +2387,13 @@ func UpdateNodePositionsTempTx(ctx context.Context, tx pgx.Tx, mtreeTable string if err != nil { return fmt.Errorf("failed to render UpdateNodePositionsTempTx SQL: %w", err) } - if _, err := tx.Exec(ctx, sql, offset); err != nil { + if _, err := db.Exec(ctx, sql, offset); err != nil { return fmt.Errorf("query to update node positions to temp failed: %w", err) } return nil } -func GetBulkSplitPointsTx(ctx context.Context, tx pgx.Tx, schema, table string, key []string, isComposite bool, start, end []any, blockSize int) ([][]any, error) { +func GetBulkSplitPoints(ctx context.Context, db DBQuerier, schema, table string, key []string, pkeyType string, isComposite bool, start, end []any, blockSize int) ([][]any, error) { args := []any{} paramIndex := 1 @@ -2692,7 +2456,7 @@ func GetBulkSplitPointsTx(ctx context.Context, tx pgx.Tx, schema, table string, return nil, fmt.Errorf("failed to render GetBulkSplitPoints SQL: %w", err) } - rows, err := tx.Query(ctx, query, args...) + rows, err := db.Query(ctx, query, args...) 
if err != nil { return nil, fmt.Errorf("failed to execute bulk split points query: %w", err) } @@ -2718,7 +2484,7 @@ func GetBulkSplitPointsTx(ctx context.Context, tx pgx.Tx, schema, table string, return splitPoints, nil } -func CreatePublication(ctx context.Context, db *pgxpool.Pool, publicationName string) error { +func CreatePublication(ctx context.Context, db DBQuerier, publicationName string) error { data := map[string]interface{}{ "PublicationName": publicationName, } @@ -2735,7 +2501,7 @@ func CreatePublication(ctx context.Context, db *pgxpool.Pool, publicationName st return nil } -func CreateReplicationSlot(ctx context.Context, db *pgxpool.Pool, slotName string) error { +func CreateReplicationSlot(ctx context.Context, db DBQuerier, slotName string) error { data := map[string]interface{}{ "SlotName": slotName, } @@ -2752,7 +2518,7 @@ func CreateReplicationSlot(ctx context.Context, db *pgxpool.Pool, slotName strin return nil } -func UpdateCDCMetadata(ctx context.Context, db *pgxpool.Pool, publicationName, slotName, startLSN string, tables []string) error { +func UpdateCDCMetadata(ctx context.Context, db DBQuerier, publicationName, slotName, startLSN string, tables []string) error { sql, err := RenderSQL(SQLTemplates.UpdateCDCMetadata, nil) if err != nil { return fmt.Errorf("failed to render UpdateCDCMetadata SQL: %w", err) @@ -2766,7 +2532,7 @@ func UpdateCDCMetadata(ctx context.Context, db *pgxpool.Pool, publicationName, s return nil } -func AlterPublicationAddTable(ctx context.Context, db *pgxpool.Pool, publicationName, tableName string) error { +func AlterPublicationAddTable(ctx context.Context, db DBQuerier, publicationName, tableName string) error { data := map[string]interface{}{ "PublicationName": publicationName, "TableName": tableName, @@ -2784,7 +2550,7 @@ func AlterPublicationAddTable(ctx context.Context, db *pgxpool.Pool, publication return nil } -func MarkBlockDirty(ctx context.Context, db *pgxpool.Pool, mtreeTable, pkeyValue string) error { 
+func MarkBlockDirty(ctx context.Context, db DBQuerier, mtreeTable, pkeyValue string) error { data := map[string]interface{}{ "MtreeTable": mtreeTable, "PkeyValue": pkeyValue, @@ -2802,7 +2568,7 @@ func MarkBlockDirty(ctx context.Context, db *pgxpool.Pool, mtreeTable, pkeyValue return nil } -func DropPublication(ctx context.Context, db *pgxpool.Pool, publicationName string) error { +func DropPublication(ctx context.Context, db DBQuerier, publicationName string) error { data := map[string]interface{}{ "PublicationName": publicationName, } @@ -2819,7 +2585,7 @@ func DropPublication(ctx context.Context, db *pgxpool.Pool, publicationName stri return nil } -func DropReplicationSlot(ctx context.Context, db *pgxpool.Pool, slotName string) error { +func DropReplicationSlot(ctx context.Context, db DBQuerier, slotName string) error { data := map[string]interface{}{ "SlotName": slotName, } @@ -2836,7 +2602,7 @@ func DropReplicationSlot(ctx context.Context, db *pgxpool.Pool, slotName string) return nil } -func DropCDCMetadataTable(ctx context.Context, db *pgxpool.Pool) error { +func DropCDCMetadataTable(ctx context.Context, db DBQuerier) error { sql, err := RenderSQL(SQLTemplates.DropCDCMetadataTable, nil) if err != nil { return fmt.Errorf("failed to render DropCDCMetadataTable SQL: %w", err) @@ -2850,7 +2616,7 @@ func DropCDCMetadataTable(ctx context.Context, db *pgxpool.Pool) error { return nil } -func GetCDCMetadata(ctx context.Context, db *pgxpool.Pool, publicationName string) (string, string, []string, error) { +func GetCDCMetadata(ctx context.Context, db DBQuerier, publicationName string) (string, string, []string, error) { sql, err := RenderSQL(SQLTemplates.GetCDCMetadata, nil) if err != nil { return "", "", nil, err @@ -2864,7 +2630,7 @@ func GetCDCMetadata(ctx context.Context, db *pgxpool.Pool, publicationName strin return slotName, startLSN, tables, nil } -func UpdateMtreeCounters(ctx context.Context, db *pgxpool.Pool, mtreeTable string, isComposite bool, 
compositeTypeName string, inserts, deletes, updates []string) error { +func UpdateMtreeCounters(ctx context.Context, db DBQuerier, mtreeTable string, isComposite bool, compositeTypeName string, inserts, deletes, updates []string) error { sql, err := RenderSQL(SQLTemplates.UpdateMtreeCounters, struct { MtreeTable string IsComposite bool diff --git a/db/queries/templates.go b/db/queries/templates.go index 44e3e26..c1af3ad 100644 --- a/db/queries/templates.go +++ b/db/queries/templates.go @@ -220,7 +220,7 @@ var SQLTemplates = Templates{ {{if .IsComposite}} p.pkey::{{.CompositeTypeName}} < (SELECT range_start FROM first_block) {{else}} - p.pkey < (SELECT range_start FROM first_block) + p.pkey::{{.PkeyType}} < (SELECT range_start FROM first_block) {{end}} ) ), @@ -236,14 +236,14 @@ var SQLTemplates = Templates{ {{if .IsComposite}} p.pkey::{{.CompositeTypeName}} >= mt.range_start AND (mt.range_end IS NULL OR p.pkey::{{.CompositeTypeName}} <= mt.range_end) {{else}} - p.pkey >= mt.range_start AND (mt.range_end IS NULL OR p.pkey <= mt.range_end) + p.pkey::{{.PkeyType}} >= mt.range_start AND (mt.range_end IS NULL OR p.pkey::{{.PkeyType}} <= mt.range_end) {{end}} ) OR ( mt.node_position = (SELECT node_position FROM first_block) AND {{if .IsComposite}} p.pkey::{{.CompositeTypeName}} < (SELECT range_start FROM first_block) {{else}} - p.pkey < (SELECT range_start FROM first_block) + p.pkey::{{.PkeyType}} < (SELECT range_start FROM first_block) {{end}} ) WHERE diff --git a/internal/cdc/listen.go b/internal/cdc/listen.go index f6fabf9..fd6271b 100644 --- a/internal/cdc/listen.go +++ b/internal/cdc/listen.go @@ -57,9 +57,24 @@ func processReplicationStream(nodeInfo map[string]any, continuous bool) { cfg := config.Cfg.MTree.CDC publication := cfg.PublicationName slotName := cfg.SlotName - _, startLSNStr, _, err := queries.GetCDCMetadata(context.Background(), pool, publication) - if err != nil { - logger.Error("failed to get cdc metadata: %v", err) + var startLSNStr string + func() 
{ + tx, err := pool.Begin(context.Background()) + if err != nil { + logger.Error("failed to begin transaction: %v", err) + return + } + defer tx.Rollback(context.Background()) + _, startLSNStr, _, err = queries.GetCDCMetadata(context.Background(), tx, publication) + if err != nil { + logger.Error("failed to get cdc metadata: %v", err) + startLSNStr = "" + } + }() + + if startLSNStr == "" { + logger.Error("Could not retrieve LSN. Aborting CDC processing for this node.") + return } startLSN, err := pglogrepl.ParseLSN(startLSNStr) @@ -268,6 +283,13 @@ var quotedOIDs = map[uint32]bool{ } func processChanges(pool *pgxpool.Pool, changes []cdcMsg) { + tx, err := pool.Begin(context.Background()) + if err != nil { + logger.Error("failed to begin transaction for processing changes: %v", err) + return + } + defer tx.Rollback(context.Background()) + tables := make(map[string][]cdcMsg) for _, change := range changes { key := fmt.Sprintf("%s.%s", change.schema, change.table) @@ -307,12 +329,16 @@ func processChanges(pool *pgxpool.Pool, changes []cdcMsg) { } } - err := queries.UpdateMtreeCounters(context.Background(), pool, mtreeTable, isComposite, compositeTypeName, inserts, deletes, updates) + err := queries.UpdateMtreeCounters(context.Background(), tx, mtreeTable, isComposite, compositeTypeName, inserts, deletes, updates) if err != nil { logger.Error("failed to update mtree counters for %s: %v", mtreeTable, err) + return } } } + if err := tx.Commit(context.Background()); err != nil { + logger.Error("failed to commit CDC changes: %v", err) + } } func getPKs(rel *pglogrepl.RelationMessage, tup *pglogrepl.TupleData) []string { diff --git a/internal/core/merkle_trees.go b/internal/core/merkle_trees.go index 489a0ed..f49e489 100644 --- a/internal/core/merkle_trees.go +++ b/internal/core/merkle_trees.go @@ -321,12 +321,18 @@ func (m *MerkleTreeTask) MtreeInit() error { } defer pool.Close() - err = queries.CreateXORFunction(context.Background(), pool) + tx, err := 
pool.Begin(context.Background()) + if err != nil { + return fmt.Errorf("failed to begin transaction on node %s: %w", nodeInfo["Name"], err) + } + defer tx.Rollback(context.Background()) + + err = queries.CreateXORFunction(context.Background(), tx) if err != nil { return fmt.Errorf("failed to create xor function: %w", err) } - err = queries.CreateCDCMetadataTable(context.Background(), pool) + err = queries.CreateCDCMetadataTable(context.Background(), tx) if err != nil { return fmt.Errorf("failed to create cdc metadata table: %w", err) } @@ -336,11 +342,14 @@ func (m *MerkleTreeTask) MtreeInit() error { return fmt.Errorf("failed to setup replication: %w", err) } - err = queries.UpdateCDCMetadata(context.Background(), pool, cfg.PublicationName, cfg.SlotName, lsn.String(), []string{}) + err = queries.UpdateCDCMetadata(context.Background(), tx, cfg.PublicationName, cfg.SlotName, lsn.String(), []string{}) if err != nil { return fmt.Errorf("failed to update cdc metadata: %w", err) } + if err := tx.Commit(context.Background()); err != nil { + return fmt.Errorf("failed to commit transaction on node %s: %w", nodeInfo["Name"], err) + } logger.Info("Merkle tree objects initialised on node: %s", nodeInfo["Name"]) } return nil @@ -536,12 +545,17 @@ func (m *MerkleTreeTask) RunChecks(skipValidation bool) error { if _, err := pool.Exec(context.Background(), "CREATE EXTENSION IF NOT EXISTS pgcrypto;"); err != nil { return fmt.Errorf("failed to ensure pgcrypto is installed on %s: %w", nodeInfo["Name"], err) } + tx, err := pool.Begin(context.Background()) + if err != nil { + return fmt.Errorf("failed to begin transaction for checks on node %s: %w", nodeInfo["Name"], err) + } + defer tx.Rollback(context.Background()) - currentColsSlice, err := queries.GetColumns(context.Background(), pool, m.Schema, m.Table) + currentColsSlice, err := queries.GetColumns(context.Background(), tx, m.Schema, m.Table) if err != nil { return fmt.Errorf("failed to get columns on node %s: %w", 
nodeInfo["Name"], err) } - currentKeySlice, err := queries.GetPrimaryKey(context.Background(), pool, m.Schema, m.Table) + currentKeySlice, err := queries.GetPrimaryKey(context.Background(), tx, m.Schema, m.Table) if err != nil { return fmt.Errorf("failed to get primary key on node %s: %w", nodeInfo["Name"], err) } @@ -667,8 +681,14 @@ func (m *MerkleTreeTask) BuildMtree() error { return fmt.Errorf("failed to connect to node %s for mtree build: %w", nodeInfo["Name"], err) } + tx, err := pool.Begin(context.Background()) + if err != nil { + return fmt.Errorf("failed to begin transaction on node %s: %w", nodeInfo["Name"], err) + } + defer tx.Rollback(context.Background()) + publicationName := cfg.PublicationName - err = queries.AlterPublicationAddTable(context.Background(), pool, publicationName, m.QualifiedTableName) + err = queries.AlterPublicationAddTable(context.Background(), tx, publicationName, m.QualifiedTableName) if err != nil { var pgErr *pgconn.PgError if errors.As(err, &pgErr) && pgErr.Code == tableAlreadyInPublicationError { @@ -681,7 +701,7 @@ func (m *MerkleTreeTask) BuildMtree() error { logger.Info("Added table %s to publication %s on node %s", m.QualifiedTableName, publicationName, nodeInfo["Name"]) } - slotName, startLSN, tables, err := queries.GetCDCMetadata(context.Background(), pool, publicationName) + slotName, startLSN, tables, err := queries.GetCDCMetadata(context.Background(), tx, publicationName) if err != nil { pool.Close() return fmt.Errorf("failed to get cdc metadata on node %s: %w", nodeInfo["Name"], err) @@ -699,7 +719,7 @@ func (m *MerkleTreeTask) BuildMtree() error { tables = append(tables, m.QualifiedTableName) } - err = queries.UpdateCDCMetadata(context.Background(), pool, publicationName, slotName, startLSN, tables) + err = queries.UpdateCDCMetadata(context.Background(), tx, publicationName, slotName, startLSN, tables) if err != nil { pool.Close() return fmt.Errorf("failed to update cdc metadata on node %s: %w", nodeInfo["Name"], 
err) @@ -707,34 +727,35 @@ func (m *MerkleTreeTask) BuildMtree() error { logger.Info("Updated CDC metadata for table %s on node %s", m.QualifiedTableName, nodeInfo["Name"]) logger.Info("Creating Merkle Tree objects on %s...", nodeInfo["Name"]) - err = m.createMtreeObjects(pool, maxRows, numBlocks) + err = m.createMtreeObjects(tx, maxRows, numBlocks) if err != nil { pool.Close() return fmt.Errorf("failed to create mtree objects on node %s: %w", nodeInfo["Name"], err) } logger.Info("Inserting block ranges on %s...", nodeInfo["Name"]) - err = m.insertBlockRanges(pool, blockRanges) + err = m.insertBlockRanges(tx, blockRanges) if err != nil { pool.Close() return fmt.Errorf("failed to insert block ranges on node %s: %w", nodeInfo["Name"], err) } logger.Info("Computing leaf hashes on %s...", nodeInfo["Name"]) - err = m.computeLeafHashes(pool, blockRanges) + err = m.computeLeafHashes(pool, tx, blockRanges) if err != nil { pool.Close() return fmt.Errorf("failed to compute leaf hashes on node %s: %w", nodeInfo["Name"], err) } logger.Info("Building parent nodes on %s...", nodeInfo["Name"]) - err = m.buildParentNodes(pool) + err = m.buildParentNodes(tx) if err != nil { pool.Close() return fmt.Errorf("failed to build parent nodes on node %s: %w", nodeInfo["Name"], err) } logger.Info("Merkle tree built successfully on %s", nodeInfo["Name"]) + tx.Commit(context.Background()) pool.Close() } @@ -807,7 +828,7 @@ func (m *MerkleTreeTask) UpdateMtree(skipAllChecks bool) error { mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", m.Schema, m.Table) - blocksToUpdate, err := queries.GetDirtyAndNewBlocksTx(context.Background(), tx, mtreeTableName, m.SimplePrimaryKey, m.Key) + blocksToUpdate, err := queries.GetDirtyAndNewBlocks(context.Background(), tx, mtreeTableName, m.SimplePrimaryKey, m.Key) if err != nil { return fmt.Errorf("error getting dirty blocks on node %s: %w", nodeInfo["Name"], err) } @@ -824,7 +845,7 @@ func (m *MerkleTreeTask) UpdateMtree(skipAllChecks bool) error { 
blockPositionsToSplit = append(blockPositionsToSplit, b.NodePosition) } - blocksToSplit, err := queries.FindBlocksToSplitTx(context.Background(), tx, mtreeTableName, splitThreshold, blockPositionsToSplit, m.SimplePrimaryKey, m.Key) + blocksToSplit, err := queries.FindBlocksToSplit(context.Background(), tx, mtreeTableName, splitThreshold, blockPositionsToSplit, m.SimplePrimaryKey, m.Key) if err != nil { return fmt.Errorf("query to find blocks to split for '%s' failed: %w", mtreeTableName, err) } @@ -843,7 +864,7 @@ func (m *MerkleTreeTask) UpdateMtree(skipAllChecks bool) error { } } - blocksToUpdate, err = queries.GetDirtyAndNewBlocksTx(context.Background(), tx, mtreeTableName, m.SimplePrimaryKey, m.Key) + blocksToUpdate, err = queries.GetDirtyAndNewBlocks(context.Background(), tx, mtreeTableName, m.SimplePrimaryKey, m.Key) if err != nil { return err } @@ -876,11 +897,11 @@ func (m *MerkleTreeTask) UpdateMtree(skipAllChecks bool) error { ) for _, block := range blocksToUpdate { - leafHash, err := queries.ComputeLeafHashesTx(context.Background(), tx, m.Schema, m.Table, m.Cols, m.SimplePrimaryKey, m.Key, block.RangeStart, block.RangeEnd) + leafHash, err := queries.ComputeLeafHashes(context.Background(), tx, m.Schema, m.Table, m.SimplePrimaryKey, m.Key, block.RangeStart, block.RangeEnd) if err != nil { return fmt.Errorf("failed to recompute hash for block %d: %w", block.NodePosition, err) } - if _, err := queries.UpdateLeafHashesTx(context.Background(), tx, mtreeTableName, leafHash, block.NodePosition); err != nil { + if _, err := queries.UpdateLeafHashes(context.Background(), tx, mtreeTableName, leafHash, block.NodePosition); err != nil { return fmt.Errorf("failed to update leaf hash for block %d: %w", block.NodePosition, err) } bar.Increment() @@ -893,7 +914,7 @@ func (m *MerkleTreeTask) UpdateMtree(skipAllChecks bool) error { } fmt.Println("Clearing dirty flags for affected blocks") - err = queries.ClearDirtyFlagsTx(context.Background(), tx, mtreeTableName, 
affectedPositions) + err = queries.ClearDirtyFlags(context.Background(), tx, mtreeTableName, affectedPositions) if err != nil { return err } @@ -919,7 +940,7 @@ func (m *MerkleTreeTask) splitBlocks(tx pgx.Tx, blocksToSplit []types.BlockRange currentBlocks := make([]types.BlockRange, len(blocksToSplit)) copy(currentBlocks, blocksToSplit) - if err := queries.DeleteParentNodesTx(ctx, tx, mtreeTableName); err != nil { + if err := queries.DeleteParentNodes(ctx, tx, mtreeTableName); err != nil { return nil, fmt.Errorf("failed to delete parent nodes: %w", err) } @@ -933,10 +954,10 @@ func (m *MerkleTreeTask) splitBlocks(tx pgx.Tx, blocksToSplit []types.BlockRange var maxVal []any var err error if isComposite { - maxVal, err = queries.GetMaxValCompositeTx(ctx, tx, m.Schema, m.Table, m.Key, start) + maxVal, err = queries.GetMaxValComposite(ctx, tx, m.Schema, m.Table, m.Key, start) } else { var simpleMaxVal any - simpleMaxVal, err = queries.GetMaxValSimpleTx(ctx, tx, m.Schema, m.Table, m.Key[0], start[0]) + simpleMaxVal, err = queries.GetMaxValSimple(ctx, tx, m.Schema, m.Table, m.Key[0], start[0]) if err == nil && simpleMaxVal != nil { maxVal = []any{simpleMaxVal} } @@ -955,7 +976,11 @@ func (m *MerkleTreeTask) splitBlocks(tx pgx.Tx, blocksToSplit []types.BlockRange continue } - splitPoints, err := queries.GetBulkSplitPointsTx(ctx, tx, m.Schema, m.Table, m.Key, isComposite, start, end, m.BlockSize) + pkeyType, err := queries.GetPkeyType(ctx, tx, m.Schema, m.Table, m.Key[0]) + if err != nil { + return nil, fmt.Errorf("failed to get pkey type: %w", err) + } + splitPoints, err := queries.GetBulkSplitPoints(ctx, tx, m.Schema, m.Table, m.Key, pkeyType, isComposite, start, end, m.BlockSize) if err != nil { return nil, fmt.Errorf("failed to get bulk split points for block %d: %w", pos, err) } @@ -978,22 +1003,22 @@ func (m *MerkleTreeTask) splitBlocks(tx pgx.Tx, blocksToSplit []types.BlockRange for _, sp := range splitPoints { if isComposite { - err = 
queries.UpdateBlockRangeEndCompositeTx(ctx, tx, mtreeTableName, compositeTypeName, sp, pos) + err = queries.UpdateBlockRangeEndComposite(ctx, tx, mtreeTableName, compositeTypeName, sp, pos) } else { - err = queries.UpdateBlockRangeEndTx(ctx, tx, mtreeTableName, sp[0], pos) + err = queries.UpdateBlockRangeEnd(ctx, tx, mtreeTableName, sp[0], pos) } if err != nil { return nil, err } - newPos, err := queries.GetMaxNodePositionTx(ctx, tx, mtreeTableName) + newPos, err := queries.GetMaxNodePosition(ctx, tx, mtreeTableName) if err != nil { return nil, err } if isComposite { - err = queries.InsertCompositeBlockRangesTx(ctx, tx, mtreeTableName, newPos, sp, nil) + err = queries.InsertCompositeBlockRanges(ctx, tx, mtreeTableName, newPos, sp, nil) } else { - err = queries.InsertBlockRangesTx(ctx, tx, mtreeTableName, newPos, sp[0], nil) + err = queries.InsertBlockRanges(ctx, tx, mtreeTableName, newPos, sp[0], nil) } if err != nil { return nil, err @@ -1004,15 +1029,15 @@ func (m *MerkleTreeTask) splitBlocks(tx pgx.Tx, blocksToSplit []types.BlockRange if originallyUnbounded { if isComposite { - err = queries.UpdateBlockRangeEndCompositeTx(ctx, tx, mtreeTableName, compositeTypeName, nil, pos) + err = queries.UpdateBlockRangeEndComposite(ctx, tx, mtreeTableName, compositeTypeName, nil, pos) } else { - err = queries.UpdateBlockRangeEndTx(ctx, tx, mtreeTableName, nil, pos) + err = queries.UpdateBlockRangeEnd(ctx, tx, mtreeTableName, nil, pos) } } else { if isComposite { - err = queries.UpdateBlockRangeEndCompositeTx(ctx, tx, mtreeTableName, compositeTypeName, end, pos) + err = queries.UpdateBlockRangeEndComposite(ctx, tx, mtreeTableName, compositeTypeName, end, pos) } else { - err = queries.UpdateBlockRangeEndTx(ctx, tx, mtreeTableName, end[0], pos) + err = queries.UpdateBlockRangeEnd(ctx, tx, mtreeTableName, end[0], pos) } } if err != nil { @@ -1029,7 +1054,7 @@ func (m *MerkleTreeTask) performMerges(tx pgx.Tx) ([]int64, error) { mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", 
m.Schema, m.Table) for { - blocksToMerge, err := queries.FindBlocksToMergeTx(context.Background(), tx, mtreeTableName, m.SimplePrimaryKey, m.Schema, m.Table, m.Key, 0.25, []int64{}) + blocksToMerge, err := queries.FindBlocksToMerge(context.Background(), tx, mtreeTableName, m.SimplePrimaryKey, m.Schema, m.Table, m.Key, 0.25, []int64{}) if err != nil { return nil, fmt.Errorf("query to find blocks to merge for '%s' failed: %w", mtreeTableName, err) } @@ -1051,7 +1076,7 @@ func (m *MerkleTreeTask) performMerges(tx pgx.Tx) ([]int64, error) { allModifiedPositions = append(allModifiedPositions, modified...) // Reset positions after each pass to ensure pos+1 logic works correctly in the next iteration. - if err := queries.ResetPositionsByStartTx(context.Background(), tx, mtreeTableName, m.Key, !m.SimplePrimaryKey); err != nil { + if err := queries.ResetPositionsByStart(context.Background(), tx, mtreeTableName, m.Key, !m.SimplePrimaryKey); err != nil { return nil, fmt.Errorf("failed to reset positions after merges: %w", err) } } @@ -1335,7 +1360,7 @@ func (m *MerkleTreeTask) mergeBlocks(tx pgx.Tx, blocksToMerge []types.BlockRange compositeTypeName := fmt.Sprintf("%s_%s_key_type", m.Schema, m.Table) - if err := queries.DeleteParentNodesTx(ctx, tx, mtreeTableName); err != nil { + if err := queries.DeleteParentNodes(ctx, tx, mtreeTableName); err != nil { return nil, fmt.Errorf("failed to delete parent nodes: %w", err) } @@ -1348,7 +1373,7 @@ func (m *MerkleTreeTask) mergeBlocks(tx pgx.Tx, blocksToMerge []types.BlockRange continue } - currentBlock, err := queries.GetBlockWithCountTx(ctx, tx, mtreeTableName, m.Schema, m.Table, m.Key, isComposite, pos) + currentBlock, err := queries.GetBlockWithCount(ctx, tx, mtreeTableName, m.Schema, m.Table, m.Key, isComposite, pos) if err != nil { return nil, fmt.Errorf("failed to get current block %d with count: %w", pos, err) } @@ -1357,22 +1382,22 @@ func (m *MerkleTreeTask) mergeBlocks(tx pgx.Tx, blocksToMerge []types.BlockRange } // 
Attempt to merge with the next block - nextBlock, err := queries.GetBlockWithCountTx(ctx, tx, mtreeTableName, m.Schema, m.Table, m.Key, isComposite, pos+1) + nextBlock, err := queries.GetBlockWithCount(ctx, tx, mtreeTableName, m.Schema, m.Table, m.Key, isComposite, pos+1) if err != nil { return nil, fmt.Errorf("failed to get next block for %d: %w", pos, err) } if nextBlock != nil && (currentBlock.Count+nextBlock.Count < int64(float64(m.BlockSize)*1.5)) { if isComposite { - err = queries.UpdateBlockRangeEndCompositeTx(ctx, tx, mtreeTableName, compositeTypeName, nextBlock.RangeEnd, currentBlock.NodePosition) + err = queries.UpdateBlockRangeEndComposite(ctx, tx, mtreeTableName, compositeTypeName, nextBlock.RangeEnd, currentBlock.NodePosition) } else { - err = queries.UpdateBlockRangeEndTx(ctx, tx, mtreeTableName, valueOrNil(nextBlock.RangeEnd), currentBlock.NodePosition) + err = queries.UpdateBlockRangeEnd(ctx, tx, mtreeTableName, valueOrNil(nextBlock.RangeEnd), currentBlock.NodePosition) } if err != nil { return nil, err } - if err := queries.DeleteBlockTx(ctx, tx, mtreeTableName, nextBlock.NodePosition); err != nil { + if err := queries.DeleteBlock(ctx, tx, mtreeTableName, nextBlock.NodePosition); err != nil { return nil, err } modifiedPositions = append(modifiedPositions, currentBlock.NodePosition) @@ -1382,12 +1407,12 @@ func (m *MerkleTreeTask) mergeBlocks(tx pgx.Tx, blocksToMerge []types.BlockRange return modifiedPositions, nil } -func (m *MerkleTreeTask) buildParentNodes(conn queries.DBTX) error { +func (m *MerkleTreeTask) buildParentNodes(conn queries.DBQuerier) error { mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", m.Schema, m.Table) var err error if tx, ok := conn.(pgx.Tx); ok { - err = queries.DeleteParentNodesTx(context.Background(), tx, mtreeTableName) + err = queries.DeleteParentNodes(context.Background(), tx, mtreeTableName) } else if pool, ok := conn.(*pgxpool.Pool); ok { err = queries.DeleteParentNodes(context.Background(), pool, mtreeTableName) } 
else { @@ -1403,7 +1428,7 @@ func (m *MerkleTreeTask) buildParentNodes(conn queries.DBTX) error { var count int var buildErr error if tx, ok := conn.(pgx.Tx); ok { - count, buildErr = queries.BuildParentNodesTx(context.Background(), tx, mtreeTableName, level) + count, buildErr = queries.BuildParentNodes(context.Background(), tx, mtreeTableName, level) } else if pool, ok := conn.(*pgxpool.Pool); ok { count, buildErr = queries.BuildParentNodes(context.Background(), pool, mtreeTableName, level) } else { @@ -1428,7 +1453,8 @@ type LeafHashResult struct { Err error } -func (m *MerkleTreeTask) computeLeafHashes(pool *pgxpool.Pool, ranges []types.BlockRange) error { +func (m *MerkleTreeTask) computeLeafHashes(pool *pgxpool.Pool, tx pgx.Tx, ranges []types.BlockRange) error { + mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", m.Schema, m.Table) numWorkers := int(float64(runtime.NumCPU()) * m.MaxCpuRatio) if numWorkers < 1 { @@ -1475,36 +1501,13 @@ func (m *MerkleTreeTask) computeLeafHashes(pool *pgxpool.Pool, ranges []types.Bl leafHashes[result.BlockID] = result.Hash } - mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", m.Schema, m.Table) - - batch := &pgx.Batch{} for blockID, hash := range leafHashes { - _, err := queries.UpdateLeafHashes(context.Background(), pool, mtreeTableName, hash, blockID) + _, err := queries.UpdateLeafHashes(context.Background(), tx, mtreeTableName, hash, blockID) if err != nil { return err } } - - tx, err := pool.Begin(context.Background()) - if err != nil { - return err - } - defer tx.Rollback(context.Background()) - - br := tx.SendBatch(context.Background(), batch) - defer br.Close() - - for i := 0; i < batch.Len(); i++ { - if _, err := br.Exec(); err != nil { - return fmt.Errorf("failed to execute batch update for leaf hashes: %w", err) - } - } - - if err := br.Close(); err != nil { - return fmt.Errorf("failed to close batch: %w", err) - } - - return tx.Commit(context.Background()) + return nil } func (m *MerkleTreeTask) leafHashWorker(wg 
*sync.WaitGroup, jobs <-chan types.BlockRange, results chan<- LeafHashResult, pool *pgxpool.Pool, bar *mpb.Bar) { @@ -1522,16 +1525,16 @@ func (m *MerkleTreeTask) leafHashWorker(wg *sync.WaitGroup, jobs <-chan types.Bl } } -func (m *MerkleTreeTask) insertBlockRanges(pool *pgxpool.Pool, ranges []types.BlockRange) error { +func (m *MerkleTreeTask) insertBlockRanges(conn queries.DBQuerier, ranges []types.BlockRange) error { mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", m.Schema, m.Table) mtreeTableIdent := pgx.Identifier{mtreeTableName} if m.SimplePrimaryKey { - if err := queries.InsertBlockRangesBatchSimple(context.Background(), pool, mtreeTableIdent.Sanitize(), ranges); err != nil { + if err := queries.InsertBlockRangesBatchSimple(context.Background(), conn, mtreeTableIdent.Sanitize(), ranges); err != nil { return err } } else { - if err := queries.InsertBlockRangesBatchComposite(context.Background(), pool, mtreeTableIdent.Sanitize(), ranges, len(m.Key)); err != nil { + if err := queries.InsertBlockRangesBatchComposite(context.Background(), conn, mtreeTableIdent.Sanitize(), ranges, len(m.Key)); err != nil { return err } } @@ -1539,47 +1542,42 @@ func (m *MerkleTreeTask) insertBlockRanges(pool *pgxpool.Pool, ranges []types.Bl return nil } -func (m *MerkleTreeTask) createMtreeObjects(pool *pgxpool.Pool, totalRows int64, numBlocks int) error { - tx, err := pool.Begin(context.Background()) - if err != nil { - return err - } - defer tx.Rollback(context.Background()) +func (m *MerkleTreeTask) createMtreeObjects(tx pgx.Tx, totalRows int64, numBlocks int) error { - err = queries.CreateXORFunction(context.Background(), pool) + err := queries.CreateXORFunction(context.Background(), tx) if err != nil { return fmt.Errorf("failed to create xor function: %w", err) } - err = queries.CreateMetadataTable(context.Background(), pool) + err = queries.CreateMetadataTable(context.Background(), tx) if err != nil { return fmt.Errorf("failed to create metadata table: %w", err) } - err = 
queries.UpdateMetadata(context.Background(), pool, m.Schema, m.Table, totalRows, m.BlockSize, numBlocks, !m.SimplePrimaryKey) + err = queries.UpdateMetadata(context.Background(), tx, m.Schema, m.Table, totalRows, m.BlockSize, numBlocks, !m.SimplePrimaryKey) if err != nil { return fmt.Errorf("failed to update metadata: %w", err) } mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", m.Schema, m.Table) - err = queries.DropMtreeTable(context.Background(), pool, mtreeTableName) + err = queries.DropMtreeTable(context.Background(), tx, mtreeTableName) if err != nil { return fmt.Errorf("failed to render drop mtree table sql: %w", err) } if m.SimplePrimaryKey { - pkeyType, err := queries.GetPkeyType(context.Background(), pool, m.Schema, m.Table, m.Key[0]) + pkeyType, err := queries.GetPkeyType(context.Background(), tx, m.Schema, m.Table, m.Key[0]) if err != nil { return err } - err = queries.CreateSimpleMtreeTable(context.Background(), pool, mtreeTableName, pkeyType) + err = queries.CreateSimpleMtreeTable(context.Background(), tx, mtreeTableName, pkeyType) if err != nil { return fmt.Errorf("failed to render create simple mtree table sql: %w", err) } } else { keyTypeColumns := make([]string, len(m.Key)) for i, col := range m.Key { - colType, err := queries.GetPkeyType(context.Background(), pool, m.Schema, m.Table, col) + colType, err := queries.GetPkeyType(context.Background(), tx, m.Schema, m.Table, col) if err != nil { return err } @@ -1588,27 +1586,27 @@ func (m *MerkleTreeTask) createMtreeObjects(pool *pgxpool.Pool, totalRows int64, compositeTypeName := fmt.Sprintf("%s_%s_key_type", m.Schema, m.Table) - err = queries.DropCompositeType(context.Background(), pool, compositeTypeName) + err = queries.DropCompositeType(context.Background(), tx, compositeTypeName) if err != nil { return fmt.Errorf("failed to render drop composite type sql: %w", err) } - err = queries.CreateCompositeType(context.Background(), pool, compositeTypeName, strings.Join(keyTypeColumns, ", ")) + err = 
queries.CreateCompositeType(context.Background(), tx, compositeTypeName, strings.Join(keyTypeColumns, ", ")) if err != nil { return fmt.Errorf("failed to render create composite type sql: %w", err) } - err = queries.CreateCompositeMtreeTable(context.Background(), pool, mtreeTableName, compositeTypeName) + err = queries.CreateCompositeMtreeTable(context.Background(), tx, mtreeTableName, compositeTypeName) if err != nil { return fmt.Errorf("failed to render create composite mtree table sql: %w", err) } } - err = m.buildParentNodes(pool) + err = m.buildParentNodes(tx) if err != nil { return err } - return tx.Commit(context.Background()) + return nil } func computeSamplingParameters(rowCount int64) (string, float64) { diff --git a/internal/core/schema_diff.go b/internal/core/schema_diff.go index 11f29f7..4123855 100644 --- a/internal/core/schema_diff.go +++ b/internal/core/schema_diff.go @@ -143,9 +143,7 @@ func (c *SchemaDiffCmd) RunChecks(skipValidation bool) error { } c.tableList = []string{} - for _, table := range tables { - c.tableList = append(c.tableList, table) - } + c.tableList = append(c.tableList, tables...) if len(c.tableList) == 0 { return fmt.Errorf("no tables found in schema %s", c.SchemaName) @@ -309,36 +307,28 @@ func getObjectsForSchema(pool *pgxpool.Pool, schemaName string) (*SchemaObjects, return nil, fmt.Errorf("could not query tables: %w", err) } var tableNames []string - for _, t := range tables { - tableNames = append(tableNames, t) - } + tableNames = append(tableNames, tables...) views, err := queries.GetViewsInSchema(context.Background(), pool, schemaName) if err != nil { return nil, fmt.Errorf("could not query views: %w", err) } var viewNames []string - for _, v := range views { - viewNames = append(viewNames, v) - } + viewNames = append(viewNames, views...) 
functions, err := queries.GetFunctionsInSchema(context.Background(), pool, schemaName) if err != nil { return nil, fmt.Errorf("could not query functions: %w", err) } var functionSignatures []string - for _, f := range functions { - functionSignatures = append(functionSignatures, f) - } + functionSignatures = append(functionSignatures, functions...) indices, err := queries.GetIndicesInSchema(context.Background(), pool, schemaName) if err != nil { return nil, fmt.Errorf("could not query indices: %w", err) } var indexNames []string - for _, i := range indices { - indexNames = append(indexNames, i) - } + indexNames = append(indexNames, indices...) return &SchemaObjects{ Tables: tableNames,