diff --git a/db/helpers/sqlhelpers.go b/db/helpers/sqlhelpers.go
index eb03ea1..6efc476 100644
--- a/db/helpers/sqlhelpers.go
+++ b/db/helpers/sqlhelpers.go
@@ -105,9 +105,9 @@ func AvgColumnSize(ctx context.Context, pool *pgxpool.Pool, schema, table, colum
         return 0, err
     }
 
-    schemaIdent := fmt.Sprintf("\"%s\"", schema)
-    tableIdent := fmt.Sprintf("\"%s\"", table)
-    colIdent := fmt.Sprintf("\"%s\"", column)
+    schemaIdent := fmt.Sprintf(`"%s"`, schema)
+    tableIdent := fmt.Sprintf(`"%s"`, table)
+    colIdent := fmt.Sprintf(`"%s"`, column)
 
     query := fmt.Sprintf(
         `SELECT COALESCE(AVG(pg_column_size(%s)), 0) FROM %s.%s`,
@@ -129,7 +129,7 @@ func GeneratePkeyOffsetsQuery(
     samplePercent float64,
     ntileCount int,
 ) (string, error) {
-    for _, ident := range append(keyColumns, schema, table) {
+    for _, ident := range append([]string{schema, table}, keyColumns...) {
         if err := SanitiseIdentifier(ident); err != nil {
             return "", fmt.Errorf("invalid identifier %q: %w", ident, err)
         }
     }
@@ -137,36 +137,60 @@ func GeneratePkeyOffsetsQuery(
     schemaIdent := fmt.Sprintf(`"%s"`, schema)
     tableIdent := fmt.Sprintf(`"%s"`, table)
 
-    keyColsSelect := strings.Join(keyColumns, ",\n ")
-    keyColsOrder := strings.Join(keyColumns, ", ")
+    quotedKeyColsOriginal := make([]string, len(keyColumns))
+    for i, c := range keyColumns {
+        quotedKeyColsOriginal[i] = fmt.Sprintf(`"%s"`, c)
+    }
+
+    keyColsSelect := strings.Join(quotedKeyColsOriginal, ",\n ")
+    keyColsOrder := strings.Join(quotedKeyColsOriginal, ", ")
 
     var descs []string
     for _, c := range keyColumns {
-        descs = append(descs, fmt.Sprintf("%s DESC", c))
+        descs = append(descs, fmt.Sprintf(`"%s" DESC`, c))
     }
     keyColsOrderDesc := strings.Join(descs, ", ")
 
     var firstSelects, lastSelects, firstTuples []string
     for _, c := range keyColumns {
+        quotedCol := fmt.Sprintf(`"%s"`, c)
         firstSelects = append(firstSelects,
-            fmt.Sprintf("(SELECT %s FROM first_row) AS %s", c, c))
+            fmt.Sprintf(`(SELECT %s FROM first_row) AS %s`, quotedCol, quotedCol))
         lastSelects = append(lastSelects,
-            fmt.Sprintf("(SELECT %s FROM last_row) AS %s", c, c))
+            fmt.Sprintf(`(SELECT %s FROM last_row) AS %s`, quotedCol, quotedCol))
         firstTuples = append(firstTuples,
-            fmt.Sprintf("(SELECT %s FROM first_row)", c))
+            fmt.Sprintf(`(SELECT %s FROM first_row)`, quotedCol))
     }
 
-    var rangeStarts, rangeEnds, rangeOutputs []string
+    var rangeStarts, rangeEnds []string
     for _, c := range keyColumns {
-        rangeStarts = append(rangeStarts, fmt.Sprintf("%s AS range_start_%s", c, c))
+        quotedCol := fmt.Sprintf(`"%s"`, c)
+        aliasStart := fmt.Sprintf(`range_start_%s`, c)
+        quotedAliasStart := fmt.Sprintf(`"%s"`, aliasStart)
+
+        aliasEnd := fmt.Sprintf(`range_end_%s`, c)
+        quotedAliasEnd := fmt.Sprintf(`"%s"`, aliasEnd)
+
+        rangeStarts = append(rangeStarts, fmt.Sprintf(`%s AS %s`, quotedCol, quotedAliasStart))
         rangeEnds = append(rangeEnds, fmt.Sprintf(
-            "LEAD(%s) OVER (ORDER BY seq, %s) AS range_end_%s",
-            c, keyColsOrder, c,
+            `LEAD(%s) OVER (ORDER BY seq, %s) AS %s`,
+            quotedCol, keyColsOrder, quotedAliasEnd,
         ))
-        rangeOutputs = append(rangeOutputs,
-            fmt.Sprintf("range_start_%s,\n range_end_%s", c, c))
     }
 
+    var startComponentCols []string
+    var endComponentCols []string
+    for _, c := range keyColumns {
+        aliasStart := fmt.Sprintf(`range_start_%s`, c)
+        quotedAliasStart := fmt.Sprintf(`"%s"`, aliasStart)
+        startComponentCols = append(startComponentCols, quotedAliasStart)
+
+        aliasEnd := fmt.Sprintf(`range_end_%s`, c)
+        quotedAliasEnd := fmt.Sprintf(`"%s"`, aliasEnd)
+        endComponentCols = append(endComponentCols, quotedAliasEnd)
+    }
+
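+    // The final SELECT list is every range_start_* column followed by every
+    // range_end_* column, so consumers scan 2*len(keyColumns) values per row.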
+    selectOutputCols := append(startComponentCols, endComponentCols...)
+
     data := map[string]any{
         "SchemaIdent": schemaIdent,
         "TableIdent":  tableIdent,
@@ -181,7 +205,7 @@ func GeneratePkeyOffsetsQuery(
         "FirstRowTupleSelects": strings.Join(firstTuples, ",\n "),
         "RangeStartColumns":    strings.Join(rangeStarts, ",\n "),
         "RangeEndColumns":      strings.Join(rangeEnds, ",\n "),
-        "RangeOutputColumns":   strings.Join(rangeOutputs, ",\n "),
+        "RangeOutputColumns":   strings.Join(selectOutputCols, ",\n "),
     }
 
     var buf bytes.Buffer
@@ -191,35 +215,73 @@ func GeneratePkeyOffsetsQuery(
     return buf.String(), nil
 }
 
-func BlockHashSQL(schema, table string, cols []string, primaryKey string) (string, error) {
+func BlockHashSQL(schema, table string, cols []string, primaryKeyCols []string) (string, error) {
     if err := SanitiseIdentifier(schema); err != nil {
         return "", err
     }
     if err := SanitiseIdentifier(table); err != nil {
         return "", err
     }
-    if err := SanitiseIdentifier(primaryKey); err != nil {
-        return "", err
-    }
-    for _, col := range cols {
-        if err := SanitiseIdentifier(col); err != nil {
-            return "", err
+    for _, pkCol := range primaryKeyCols {
+        if err := SanitiseIdentifier(pkCol); err != nil {
+            return "", fmt.Errorf("invalid primary key column identifier %q: %w", pkCol, err)
         }
     }
-    schemaIdent := fmt.Sprintf("\"%s\"", schema)
-    tableIdent := fmt.Sprintf("\"%s\"", table)
-    primaryIdent := fmt.Sprintf("\"%s\"", primaryKey)
-    var colIdents []string
-    for _, col := range cols {
-        colIdents = append(colIdents, fmt.Sprintf("\"%s\"", col))
+
+    schemaIdent := fmt.Sprintf(`"%s"`, schema)
+    tableIdent := fmt.Sprintf(`"%s"`, table)
+    tableAlias := "_tbl_"
+
+    quotedPKColIdents := make([]string, len(primaryKeyCols))
+    for i, pkCol := range primaryKeyCols {
+        quotedPKColIdents[i] = fmt.Sprintf(`"%s"`, pkCol)
+    }
+    pkOrderByStr := strings.Join(quotedPKColIdents, ", ")
+
+    pkComparisonExpression := ""
+    if len(primaryKeyCols) == 1 {
+        pkComparisonExpression = quotedPKColIdents[0]
+    } else {
+        pkComparisonExpression = fmt.Sprintf("ROW(%s)", strings.Join(quotedPKColIdents, ", "))
+    }
+
+    startPlaceholders := make([]string, len(primaryKeyCols))
+    for i := range primaryKeyCols {
+        startPlaceholders[i] = fmt.Sprintf("$%d", 2+i)
     }
-    colsList := strings.Join(colIdents, ", ")
+    startValueExpression := ""
+    if len(primaryKeyCols) == 1 {
+        startValueExpression = startPlaceholders[0]
+    } else {
+        startValueExpression = fmt.Sprintf("ROW(%s)", strings.Join(startPlaceholders, ", "))
+    }
+
+    skipMaxCheckPlaceholderIndex := 2 + len(primaryKeyCols)
+
+    endPlaceholders := make([]string, len(primaryKeyCols))
+    for i := range primaryKeyCols {
+        endPlaceholders[i] = fmt.Sprintf("$%d", skipMaxCheckPlaceholderIndex+1+i)
+    }
+    endValueExpression := ""
+    if len(primaryKeyCols) == 1 {
+        endValueExpression = endPlaceholders[0]
+    } else {
+        endValueExpression = fmt.Sprintf("ROW(%s)", strings.Join(endPlaceholders, ", "))
+    }
+
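+    // Example (hypothetical two-column key "pk1", "pk2" on "s"."t") of the query
+    // built below; the start key binds to $2..$3 and the end key to $5..$6,
+    // with $1/$4 as the skip flags:
+    //
+    //   SELECT encode(digest(COALESCE(string_agg(_tbl_::text, '|' ORDER BY "pk1", "pk2"), '[EMPTY_BLOCK]'), 'sha1'), 'hex')
+    //   FROM "s"."t" AS _tbl_
+    //   WHERE ($1::boolean OR ROW("pk1", "pk2") >= ROW($2, $3))
+    //   AND ($4::boolean OR ROW("pk1", "pk2") < ROW($5, $6))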
     query := fmt.Sprintf(
-        `SELECT encode(digest(COALESCE(string_agg(concat_ws('|', %s),'|' ORDER BY %s),'EMPTY_BLOCK'),'sha1'),'hex')
-    FROM %s.%s
-    WHERE ($1::boolean OR %s >= $2)
-    AND ($3::boolean OR %s < $4)`,
-        colsList, primaryIdent, schemaIdent, tableIdent, primaryIdent, primaryIdent,
+        `SELECT encode(digest(COALESCE(string_agg(%s::text, '|' ORDER BY %s), '[EMPTY_BLOCK]'), 'sha1'), 'hex')
+    FROM %s.%s AS %s
+    WHERE ($1::boolean OR %s >= %s)
+    AND ($%d::boolean OR %s < %s)`,
+        tableAlias,
+        pkOrderByStr,
+        schemaIdent, tableIdent, tableAlias,
+        pkComparisonExpression,
+        startValueExpression,
+        skipMaxCheckPlaceholderIndex,
+        pkComparisonExpression,
+        endValueExpression,
     )
     return query, nil
 }
diff --git a/internal/core/table_diff.go b/internal/core/table_diff.go
index df8ec0b..de2eb5f 100644
--- a/internal/core/table_diff.go
+++ b/internal/core/table_diff.go
@@ -223,14 +223,13 @@ func (t *TableDiffTask) fetchRows(ctx context.Context, nodeName string, r Range)
     }
     selectColsStr := strings.Join(quotedSelectCols, ", ")
 
-    pkColumn := sanitise(t.Key[0])
+    quotedKeyCols := make([]string, len(t.Key))
+    for i, k := range t.Key {
+        quotedKeyCols[i] = sanitise(k)
+    }
 
     orderByClause := ""
     if len(t.Key) > 0 {
-        quotedKeyCols := make([]string, len(t.Key))
-        for i, k := range t.Key {
-            quotedKeyCols[i] = sanitise(k)
-        }
         orderByClause = "ORDER BY " + strings.Join(quotedKeyCols, ", ")
     }
 
@@ -239,14 +238,57 @@ func (t *TableDiffTask) fetchRows(ctx context.Context, nodeName string, r Range)
     var conditions []string
     paramIndex := 1
+
     if r.Start != nil {
-        conditions = append(conditions, fmt.Sprintf("%s >= $%d", pkColumn, paramIndex))
-        args = append(args, r.Start)
-        paramIndex++
+        startVal := r.Start
+        if len(t.Key) == 1 {
+            // Simple primary key
+            conditions = append(conditions, fmt.Sprintf("%s >= $%d", quotedKeyCols[0], paramIndex))
+            args = append(args, startVal)
+            paramIndex++
+        } else {
+            // Composite primary key
+            startVals, ok := startVal.([]any)
+            if !ok || len(startVals) != len(t.Key) {
+                return nil, fmt.Errorf("r.Start is not a valid composite key for table %s.%s (expected %d values, got %T with value %v)", t.Schema, t.Table, len(t.Key), startVal, startVal)
+            }
+
+            pkTupleStr := fmt.Sprintf("ROW(%s)", strings.Join(quotedKeyCols, ", "))
+
+            placeholders := make([]string, len(t.Key))
+            for i := 0; i < len(t.Key); i++ {
+                placeholders[i] = fmt.Sprintf("$%d", paramIndex+i)
+            }
+            placeholderTupleStr := fmt.Sprintf("ROW(%s)", strings.Join(placeholders, ", "))
+
+            conditions = append(conditions, fmt.Sprintf("%s >= %s", pkTupleStr, placeholderTupleStr))
+            args = append(args, startVals...)
+            paramIndex += len(t.Key)
+        }
     }
+
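+    // Composite bounds rely on PostgreSQL row-constructor comparison:
+    // ROW(a, b) >= ROW($1, $2) compares lexicographically, which matches the
+    // ORDER BY over the same key columns.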
     if r.End != nil {
-        conditions = append(conditions, fmt.Sprintf("%s <= $%d", pkColumn, paramIndex))
-        args = append(args, r.End)
+        endVal := r.End
+        if len(t.Key) == 1 {
+            conditions = append(conditions, fmt.Sprintf("%s <= $%d", quotedKeyCols[0], paramIndex))
+            args = append(args, endVal)
+        } else {
+            endVals, ok := endVal.([]any)
+            if !ok || len(endVals) != len(t.Key) {
+                return nil, fmt.Errorf("r.End is not a valid composite key for table %s.%s (expected %d values, got %T with value %v)", t.Schema, t.Table, len(t.Key), endVal, endVal)
+            }
+
+            pkTupleStr := fmt.Sprintf("ROW(%s)", strings.Join(quotedKeyCols, ", "))
+
+            placeholders := make([]string, len(t.Key))
+            for i := 0; i < len(t.Key); i++ {
+                placeholders[i] = fmt.Sprintf("$%d", paramIndex+i)
+            }
+            placeholderTupleStr := fmt.Sprintf("ROW(%s)", strings.Join(placeholders, ", "))
+
+            conditions = append(conditions, fmt.Sprintf("%s <= %s", pkTupleStr, placeholderTupleStr))
+            args = append(args, endVals...)
+        }
     }
 
     whereClause := ""
@@ -625,20 +667,22 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error {
             user, strings.Join(missingPrivs, ", "), schema, table, hostname)
     }
 
-    // connParams = append(connParams, params)
     hostMap[hostIP+":"+fmt.Sprint(int(port))] = hostname
 
     if t.TableFilter != "" {
         viewName := fmt.Sprintf("%s_%s_filtered", t.TaskID, table)
-        viewSQL := fmt.Sprintf("CREATE VIEW %s AS SELECT * FROM %s.%s WHERE %s",
-            viewName, schema, table, t.TableFilter)
+        sanitisedViewName := sanitise(viewName)
+        sanitisedSchema := sanitise(schema)
+        sanitisedTable := sanitise(table)
+        viewSQL := fmt.Sprintf("CREATE OR REPLACE VIEW %s AS SELECT * FROM %s.%s WHERE %s",
+            sanitisedViewName, sanitisedSchema, sanitisedTable, t.TableFilter)
         _, err = conn.Exec(context.Background(), viewSQL)
         if err != nil {
             return fmt.Errorf("failed to create filtered view: %w", err)
         }
 
-        hasRowsSQL := fmt.Sprintf("SELECT EXISTS (SELECT 1 FROM %s) AS has_rows", viewName)
+        hasRowsSQL := fmt.Sprintf("SELECT EXISTS (SELECT 1 FROM %s) AS has_rows", sanitisedViewName)
         var hasRows bool
         err = conn.QueryRow(context.Background(), hasRowsSQL).Scan(&hasRows)
         if err != nil {
@@ -798,7 +842,7 @@ func (t *TableDiffTask) ExecuteTask(debugMode bool) error {
     }
     t.Pools = pools
 
-    blockHashSQL, err := helpers.BlockHashSQL(t.Schema, t.Table, t.Cols, t.Key[0])
+    blockHashSQL, err := helpers.BlockHashSQL(t.Schema, t.Table, t.Cols, t.Key)
     if err != nil {
         return fmt.Errorf("failed to build block-hash SQL: %w", err)
     }
@@ -866,7 +910,12 @@ func (t *TableDiffTask) ExecuteTask(debugMode bool) error {
     }
 
     var ranges []Range
-    // Determine if we should use direct PKey offset generation
+    /* Determine if we should use direct PKey offset generation.
+     * Essentially, we don't want to use probabilistic sampling for tables with
+     * fewer than 10,000 rows, to avoid non-deterministic results.
+     *
+     * TODO: table-filter should also support probabilistic sampling.
+     */
     if (maxCount > 0 && maxCount <= 10000) || t.TableFilter != "" {
         logger.Info("Using direct primary key offset generation for table %s.%s (maxCount: %d, tableFilter: '%s')",
             t.Schema, t.Table, maxCount, t.TableFilter)
@@ -892,18 +941,42 @@ func (t *TableDiffTask) ExecuteTask(debugMode bool) error {
     }
     defer pkRangesRows.Close()
 
+    numPKCols := len(t.Key)
+    totalScanCols := 2 * numPKCols
+    if totalScanCols == 0 {
+        return fmt.Errorf("primary key not defined, cannot determine columns to scan for ranges")
+    }
+    scanDest := make([]any, totalScanCols)
+    scanDestPtrs := make([]any, totalScanCols)
+    for i := range scanDest {
+        scanDestPtrs[i] = &scanDest[i]
+    }
+
     for pkRangesRows.Next() {
-        var startVal, endVal any
-        if err := pkRangesRows.Scan(&startVal, &endVal); err != nil {
-            return fmt.Errorf("scanning offset row failed: %w", err)
+        if err := pkRangesRows.Scan(scanDestPtrs...); err != nil {
+            return fmt.Errorf("scanning offset row failed (expected %d columns for %d PKs): %w", totalScanCols, numPKCols, err)
         }
-        ranges = append(ranges, Range{Start: startVal, End: endVal})
+
+        var rStart, rEnd any
+
+        if numPKCols == 1 {
+            rStart = scanDest[0]
+            rEnd = scanDest[1]
+        } else {
+            startKeyParts := make([]any, numPKCols)
+            copy(startKeyParts, scanDest[0:numPKCols])
+            rStart = startKeyParts
+
+            endKeyParts := make([]any, numPKCols)
+            copy(endKeyParts, scanDest[numPKCols:2*numPKCols])
+            rEnd = endKeyParts
+        }
+        ranges = append(ranges, Range{Start: rStart, End: rEnd})
     }
     if err := pkRangesRows.Err(); err != nil {
         return fmt.Errorf("offset rows iteration error: %w", err)
     }
 
-    // Prepend (nil, first_original_start) only if ranges were actually generated and the first doesn't already start with nil.
     if len(ranges) > 0 && ranges[0].Start != nil {
         firstOriginalStart := ranges[0].Start
         newInitialRange := Range{Start: nil, End: firstOriginalStart}
@@ -1062,23 +1135,56 @@ func (t *TableDiffTask) hashRange(
     }
     startTime := time.Now()
     var hash string
-    var skipMinCheck bool
-    var skipMaxCheck bool
+
+    skipMinCheck := r.Start == nil
+    skipMaxCheck := r.End == nil
+
+    numPKCols := len(t.Key)
+    sqlArgs := make([]any, 0, 2+2*numPKCols)
+
+    sqlArgs = append(sqlArgs, skipMinCheck)
     if r.Start == nil {
-        skipMinCheck = true
+        for i := 0; i < numPKCols; i++ {
+            sqlArgs = append(sqlArgs, nil)
+        }
+    } else {
+        if numPKCols == 1 {
+            sqlArgs = append(sqlArgs, r.Start)
+        } else {
+            startVals, ok := r.Start.([]any)
+            if !ok || len(startVals) != numPKCols {
+                return "", fmt.Errorf("[%s] r.Start is not a valid composite key for hashing (expected %d values, got %T with value %v)", node, numPKCols, r.Start, r.Start)
+            }
+            sqlArgs = append(sqlArgs, startVals...)
+        }
     }
+
+    sqlArgs = append(sqlArgs, skipMaxCheck)
+
     if r.End == nil {
-        skipMaxCheck = true
+        for i := 0; i < numPKCols; i++ {
+            sqlArgs = append(sqlArgs, nil)
+        }
+    } else {
+        if numPKCols == 1 {
+            sqlArgs = append(sqlArgs, r.End)
+        } else {
+            endVals, ok := r.End.([]any)
+            if !ok || len(endVals) != numPKCols {
+                return "", fmt.Errorf("[%s] r.End is not a valid composite key for hashing (expected %d values, got %T with value %v)", node, numPKCols, r.End, r.End)
+            }
+            sqlArgs = append(sqlArgs, endVals...)
+        }
     }
-    logger.Debug("[%s] Hashing range: Start=%v, End=%v", node, r.Start, r.End)
+    logger.Debug("[%s] Hashing range: Start=%v, End=%v. SQL: %s, Args: %v", node, r.Start, r.End, t.BlockHashSQL, sqlArgs)
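+    // sqlArgs layout: [skipMin, start_1..start_N, skipMax, end_1..end_N],
+    // matching the placeholder numbering produced by helpers.BlockHashSQL.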
-    err := pool.QueryRow(ctx, t.BlockHashSQL, skipMinCheck, r.Start, skipMaxCheck, r.End).Scan(&hash)
+    err := pool.QueryRow(ctx, t.BlockHashSQL, sqlArgs...).Scan(&hash)
     if err != nil {
         duration := time.Since(startTime)
-        logger.Info("[%s] ERROR after %v for range Start=%v, End=%v (using query: '%s'): %v", node, duration, r.Start, r.End, t.BlockHashSQL, err)
+        logger.Info("[%s] ERROR after %v for range Start=%v, End=%v (using query: '%s', args: %v): %v", node, duration, r.Start, r.End, t.BlockHashSQL, sqlArgs, err)
         return "", fmt.Errorf("BlockHash query failed for %s range %v-%v: %w", node, r.Start, r.End, err)
     }
@@ -1356,17 +1462,21 @@ func safeCut(s string, n int) string {
 }
 
 func (t *TableDiffTask) getPkeyOffsets(ctx context.Context, pool *pgxpool.Pool) ([]Range, error) {
-    // TODO: Add support for composite keys.
     if len(t.Key) == 0 {
         return nil, fmt.Errorf("primary key not defined for table %s.%s", t.Schema, t.Table)
     }
-    primaryKeyColumn := t.Key[0]
 
     schemaIdent := sanitise(t.Schema)
     tableIdent := sanitise(t.Table)
-    pkIdent := sanitise(primaryKeyColumn)
 
-    querySQL := fmt.Sprintf("SELECT %s FROM %s.%s ORDER BY %s", pkIdent, schemaIdent, tableIdent, pkIdent)
+    quotedKeyCols := make([]string, len(t.Key))
+    for i, k := range t.Key {
+        quotedKeyCols[i] = sanitise(k)
+    }
+    // Use the same quoted column list for both SELECT and ORDER BY to keep things simple.
+    pkStr := strings.Join(quotedKeyCols, ", ")
+
+    querySQL := fmt.Sprintf("SELECT %s FROM %s.%s ORDER BY %s", pkStr, schemaIdent, tableIdent, pkStr)
 
     pgRows, err := pool.Query(ctx, querySQL)
     if err != nil {
@@ -1375,12 +1485,27 @@ func (t *TableDiffTask) getPkeyOffsets(ctx context.Context, pool *pgxpool.Pool)
     defer pgRows.Close()
 
     var allPks []any
+    numPKCols := len(t.Key)
+    scanDest := make([]any, numPKCols)
+    scanDestPtrs := make([]any, numPKCols)
+    for i := range scanDest {
+        scanDestPtrs[i] = &scanDest[i]
+    }
+
     for pgRows.Next() {
-        var pkVal any
-        if err := pgRows.Scan(&pkVal); err != nil {
+        if err := pgRows.Scan(scanDestPtrs...); err != nil {
            return nil, fmt.Errorf("failed to scan primary key value from %s.%s: %w", t.Schema, t.Table, err)
         }
-        allPks = append(allPks, pkVal)
+        if numPKCols == 1 {
+            allPks = append(allPks, scanDest[0])
+        } else {
+            /* PKs are composite, so create a new slice for each PK to avoid allPks
+             * elements pointing to the same underlying scanDest array.
+             */
+            currentPkComposite := make([]any, numPKCols)
+            copy(currentPkComposite, scanDest)
+            allPks = append(allPks, currentPkComposite)
+        }
     }
     if err := pgRows.Err(); err != nil {
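
Note on the per-row copy in getPkeyOffsets above: Scan writes into the same scanDest backing array on every iteration, so appending scanDest itself would leave every element of allPks aliasing the final row. A minimal standalone sketch of that pitfall (hypothetical values, independent of this repo):

    package main

    import "fmt"

    func main() {
        scanDest := make([]any, 2)
        var aliased, copied [][]any

        for i := 0; i < 3; i++ {
            scanDest[0], scanDest[1] = i, i*10

            // Bug: appending scanDest stores the same backing array three times.
            aliased = append(aliased, scanDest)

            // Fix: copy into a fresh slice per row, as getPkeyOffsets does.
            row := make([]any, len(scanDest))
            copy(row, scanDest)
            copied = append(copied, row)
        }

        fmt.Println(aliased) // [[2 20] [2 20] [2 20]]
        fmt.Println(copied)  // [[0 0] [1 10] [2 20]]
    }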