From 02812fd411f875c969707979206c189bce93224b Mon Sep 17 00:00:00 2001 From: Tej Kashi Date: Mon, 21 Jul 2025 16:46:52 -0400 Subject: [PATCH] Refactor common funcs into utils.go * Move logger into pkg/ --- cmd/server/main.go | 2 +- internal/cli/cli.go | 4 +- internal/core/repset_diff.go | 5 +- internal/core/schema_diff.go | 227 ++++++++++--------------- internal/core/spock_diff.go | 17 +- internal/core/table_diff.go | 227 ++++--------------------- internal/core/table_repair.go | 46 +++-- internal/core/table_rerun.go | 24 +-- {internal/core => pkg/common}/utils.go | 124 +++++++++++--- {internal => pkg}/logger/logger.go | 0 10 files changed, 272 insertions(+), 404 deletions(-) rename {internal/core => pkg/common}/utils.go (81%) rename {internal => pkg}/logger/logger.go (100%) diff --git a/cmd/server/main.go b/cmd/server/main.go index 16d68a3..2c82455 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -16,8 +16,8 @@ import ( "path/filepath" "github.com/pgedge/ace/internal/cli" - "github.com/pgedge/ace/internal/logger" "github.com/pgedge/ace/pkg/config" + "github.com/pgedge/ace/pkg/logger" ) func main() { diff --git a/internal/cli/cli.go b/internal/cli/cli.go index ec792a2..44e59ff 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -18,7 +18,7 @@ import ( "github.com/charmbracelet/log" "github.com/google/uuid" "github.com/pgedge/ace/internal/core" - "github.com/pgedge/ace/internal/logger" + "github.com/pgedge/ace/pkg/logger" "github.com/urfave/cli/v2" ) @@ -431,7 +431,7 @@ func SchemaDiffCLI(ctx *cli.Context) error { task.Output = ctx.String("output") task.OverrideBlockSize = ctx.Bool("override-block-size") - if err := core.SchemaDiff(task); err != nil { + if err := task.SchemaTableDiff(); err != nil { return fmt.Errorf("error during schema diff: %w", err) } return nil diff --git a/internal/core/repset_diff.go b/internal/core/repset_diff.go index 753e682..a4b5548 100644 --- a/internal/core/repset_diff.go +++ b/internal/core/repset_diff.go @@ -24,7 
+24,8 @@ import ( "github.com/jackc/pgx/v4/pgxpool" "github.com/pgedge/ace/db/queries" "github.com/pgedge/ace/internal/auth" - "github.com/pgedge/ace/internal/logger" + utils "github.com/pgedge/ace/pkg/common" + "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/types" ) @@ -94,7 +95,7 @@ func (c *RepsetDiffCmd) RunChecks(skipValidation bool) error { return err } - if err := readClusterInfo(c); err != nil { + if err := utils.ReadClusterInfo(c); err != nil { return err } if len(c.clusterNodes) == 0 { diff --git a/internal/core/schema_diff.go b/internal/core/schema_diff.go index df9480f..924df4b 100644 --- a/internal/core/schema_diff.go +++ b/internal/core/schema_diff.go @@ -16,14 +16,14 @@ import ( "encoding/json" "fmt" "maps" - "sort" "strconv" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4/pgxpool" "github.com/pgedge/ace/db/queries" "github.com/pgedge/ace/internal/auth" - "github.com/pgedge/ace/internal/logger" + utils "github.com/pgedge/ace/pkg/common" + "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/types" ) @@ -50,6 +50,32 @@ type SchemaDiffCmd struct { OverrideBlockSize bool } +type SchemaObjects struct { + Tables []string `json:"tables"` + Views []string `json:"views"` + Functions []string `json:"functions"` + Indices []string `json:"indices"` +} + +func (so SchemaObjects) IsEmpty() bool { + return len(so.Tables) == 0 && len(so.Views) == 0 && len(so.Functions) == 0 && len(so.Indices) == 0 +} + +type NodeSchemaReport struct { + NodeName string `json:"node_name"` + Objects SchemaObjects `json:"objects"` +} + +type NodeComparisonReport struct { + Status string `json:"status"` + Diffs map[string]NodeDiff `json:"diffs,omitempty"` +} + +type NodeDiff struct { + MissingObjects SchemaObjects `json:"missing_objects"` + ExtraObjects SchemaObjects `json:"extra_objects"` +} + func (c *SchemaDiffCmd) GetClusterName() string { return c.ClusterName } func (c *SchemaDiffCmd) GetDBName() string { return c.DBName } func (c *SchemaDiffCmd) 
SetDBName(name string) { c.DBName = name } @@ -60,25 +86,13 @@ func (c *SchemaDiffCmd) SetDatabase(db types.Database) { c.database = db } func (c *SchemaDiffCmd) GetClusterNodes() []map[string]any { return c.clusterNodes } func (c *SchemaDiffCmd) SetClusterNodes(cn []map[string]any) { c.clusterNodes = cn } -func (c *SchemaDiffCmd) getTablesInSchema(db *pgxpool.Pool) error { - rows, err := db.Query(context.Background(), - "SELECT table_name FROM information_schema.tables WHERE table_schema = $1 AND table_type = 'BASE TABLE'", - c.SchemaName) - if err != nil { - return fmt.Errorf("could not query tables in schema: %w", err) +func (c *SchemaDiffCmd) Validate() error { + if c.ClusterName == "" { + return fmt.Errorf("cluster name is required") } - defer rows.Close() - - var tables []string - for rows.Next() { - var tableName string - if err := rows.Scan(&tableName); err != nil { - return fmt.Errorf("could not scan table name: %w", err) - } - tables = append(tables, tableName) + if c.SchemaName == "" { + return fmt.Errorf("schema name is required") } - - c.tableList = tables return nil } @@ -89,7 +103,7 @@ func (c *SchemaDiffCmd) RunChecks(skipValidation bool) error { } } - if err := readClusterInfo(c); err != nil { + if err := utils.ReadClusterInfo(c); err != nil { return err } if len(c.clusterNodes) == 0 { @@ -143,119 +157,7 @@ func (c *SchemaDiffCmd) RunChecks(skipValidation bool) error { return nil } -func (c *SchemaDiffCmd) Validate() error { - if c.ClusterName == "" { - return fmt.Errorf("cluster name is required") - } - if c.SchemaName == "" { - return fmt.Errorf("schema name is required") - } - return nil -} - -type SchemaObjects struct { - Tables []string `json:"tables"` - Views []string `json:"views"` - Functions []string `json:"functions"` - Indices []string `json:"indices"` -} - -func (so SchemaObjects) IsEmpty() bool { - return len(so.Tables) == 0 && len(so.Views) == 0 && len(so.Functions) == 0 && len(so.Indices) == 0 -} - -type NodeSchemaReport struct { - 
NodeName string `json:"node_name"` - Objects SchemaObjects `json:"objects"` -} - -type NodeComparisonReport struct { - Status string `json:"status"` - Diffs map[string]NodeDiff `json:"diffs,omitempty"` -} - -type NodeDiff struct { - MissingObjects SchemaObjects `json:"missing_objects"` - ExtraObjects SchemaObjects `json:"extra_objects"` -} - -func diffStringSlices(a, b []string) (missing, extra []string) { - sort.Strings(a) - sort.Strings(b) - - aMap := make(map[string]struct{}, len(a)) - for _, s := range a { - aMap[s] = struct{}{} - } - - bMap := make(map[string]struct{}, len(b)) - for _, s := range b { - bMap[s] = struct{}{} - } - - for _, s := range a { - if _, found := bMap[s]; !found { - missing = append(missing, s) - } - } - - for _, s := range b { - if _, found := aMap[s]; !found { - extra = append(extra, s) - } - } - - return missing, extra -} - -func getObjectsForSchema(db queries.Querier, schemaName string) (*SchemaObjects, error) { - pgSchemaName := pgtype.Name{String: schemaName, Status: pgtype.Present} - - tables, err := db.GetTablesInSchema(context.Background(), pgSchemaName) - if err != nil { - return nil, fmt.Errorf("could not query tables: %w", err) - } - var tableNames []string - for _, t := range tables { - tableNames = append(tableNames, t.String) - } - - views, err := db.GetViewsInSchema(context.Background(), pgSchemaName) - if err != nil { - return nil, fmt.Errorf("could not query views: %w", err) - } - var viewNames []string - for _, v := range views { - viewNames = append(viewNames, v.String) - } - - functions, err := db.GetFunctionsInSchema(context.Background(), pgSchemaName) - if err != nil { - return nil, fmt.Errorf("could not query functions: %w", err) - } - var functionSignatures []string - for _, f := range functions { - functionSignatures = append(functionSignatures, *f) - } - - indices, err := db.GetIndicesInSchema(context.Background(), pgSchemaName) - if err != nil { - return nil, fmt.Errorf("could not query indices: %w", err) - } 
- var indexNames []string - for _, i := range indices { - indexNames = append(indexNames, i.String) - } - - return &SchemaObjects{ - Tables: tableNames, - Views: viewNames, - Functions: functionSignatures, - Indices: indexNames, - }, nil -} - -func schemaDDLDiff(task *SchemaDiffCmd) error { +func (task *SchemaDiffCmd) schemaObjectDiff() error { var allNodeObjects []NodeSchemaReport for _, nodeInfo := range task.clusterNodes { @@ -305,10 +207,10 @@ func schemaDDLDiff(task *SchemaDiffCmd) error { refObjects := referenceNode.Objects cmpObjects := compareNode.Objects - missingTables, extraTables := diffStringSlices(refObjects.Tables, cmpObjects.Tables) - missingViews, extraViews := diffStringSlices(refObjects.Views, cmpObjects.Views) - missingFunctions, extraFunctions := diffStringSlices(refObjects.Functions, cmpObjects.Functions) - missingIndices, extraIndices := diffStringSlices(refObjects.Indices, cmpObjects.Indices) + missingTables, extraTables := utils.DiffStringSlices(refObjects.Tables, cmpObjects.Tables) + missingViews, extraViews := utils.DiffStringSlices(refObjects.Views, cmpObjects.Views) + missingFunctions, extraFunctions := utils.DiffStringSlices(refObjects.Functions, cmpObjects.Functions) + missingIndices, extraIndices := utils.DiffStringSlices(refObjects.Indices, cmpObjects.Indices) refExtraObjects := SchemaObjects{ Tables: missingTables, @@ -358,13 +260,13 @@ func schemaDDLDiff(task *SchemaDiffCmd) error { return nil } -func SchemaDiff(task *SchemaDiffCmd) error { +func (task *SchemaDiffCmd) SchemaTableDiff() error { if err := task.RunChecks(false); err != nil { return err } if task.DDLOnly { - return schemaDDLDiff(task) + return task.schemaObjectDiff() } for _, tableName := range task.tableList { @@ -404,3 +306,50 @@ func SchemaDiff(task *SchemaDiffCmd) error { return nil } + +func getObjectsForSchema(db queries.Querier, schemaName string) (*SchemaObjects, error) { + pgSchemaName := pgtype.Name{String: schemaName, Status: pgtype.Present} + + tables, err 
:= db.GetTablesInSchema(context.Background(), pgSchemaName) + if err != nil { + return nil, fmt.Errorf("could not query tables: %w", err) + } + var tableNames []string + for _, t := range tables { + tableNames = append(tableNames, t.String) + } + + views, err := db.GetViewsInSchema(context.Background(), pgSchemaName) + if err != nil { + return nil, fmt.Errorf("could not query views: %w", err) + } + var viewNames []string + for _, v := range views { + viewNames = append(viewNames, v.String) + } + + functions, err := db.GetFunctionsInSchema(context.Background(), pgSchemaName) + if err != nil { + return nil, fmt.Errorf("could not query functions: %w", err) + } + var functionSignatures []string + for _, f := range functions { + functionSignatures = append(functionSignatures, *f) + } + + indices, err := db.GetIndicesInSchema(context.Background(), pgSchemaName) + if err != nil { + return nil, fmt.Errorf("could not query indices: %w", err) + } + var indexNames []string + for _, i := range indices { + indexNames = append(indexNames, i.String) + } + + return &SchemaObjects{ + Tables: tableNames, + Views: viewNames, + Functions: functionSignatures, + Indices: indexNames, + }, nil +} diff --git a/internal/core/spock_diff.go b/internal/core/spock_diff.go index 595006e..b06ac3c 100644 --- a/internal/core/spock_diff.go +++ b/internal/core/spock_diff.go @@ -24,7 +24,8 @@ import ( "github.com/jackc/pgx/v4/pgxpool" "github.com/pgedge/ace/db/queries" "github.com/pgedge/ace/internal/auth" - "github.com/pgedge/ace/internal/logger" + utils "github.com/pgedge/ace/pkg/common" + "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/types" ) @@ -81,7 +82,7 @@ func (t *SpockDiffTask) Validate() error { return fmt.Errorf("cluster_name is a required argument") } - nodeList, err := ParseNodes(t.Nodes) + nodeList, err := utils.ParseNodes(t.Nodes) if err != nil { return fmt.Errorf("nodes should be a comma-separated list of nodenames. E.g., nodes=\"n1,n2\". 
Error: %w", err) } @@ -91,7 +92,7 @@ func (t *SpockDiffTask) Validate() error { return fmt.Errorf("spock-diff needs at least two nodes to compare") } - err = readClusterInfo(t) + err = utils.ReadClusterInfo(t) if err != nil { return fmt.Errorf("error loading cluster information: %w", err) } @@ -102,7 +103,7 @@ func (t *SpockDiffTask) Validate() error { for _, nodeMap := range t.ClusterNodes { if len(nodeList) > 0 { nameVal, _ := nodeMap["Name"].(string) - if !Contains(nodeList, nameVal) { + if !utils.Contains(nodeList, nameVal) { continue } } @@ -160,7 +161,7 @@ func (t *SpockDiffTask) RunChecks(skipValidation bool) error { port = "5432" } - if !Contains(t.NodeList, hostname) { + if !utils.Contains(t.NodeList, hostname) { continue } @@ -228,7 +229,7 @@ func (t *SpockDiffTask) ExecuteTask() error { sub.ReplicationSets = ni.SubReplicationSets if len(ni.SubReplicationSets) == 0 { hint := fmt.Sprintf("Subscription '%s' has no replication sets.", sub.SubName) - if !Contains(config.Hints, hint) { + if !utils.Contains(config.Hints, hint) { config.Hints = append(config.Hints, hint) } } @@ -321,10 +322,10 @@ func (t *SpockDiffTask) ExecuteTask() error { if !diff.Mismatch { diff.Message = fmt.Sprintf("Replication rules are the same for %s and %s", refNodeName, compareNodeName) - fmt.Printf("%s No differences found.\n", CheckMark) + fmt.Printf("%s No differences found.\n", utils.CheckMark) } else { diff.Message = fmt.Sprintf("Difference in Replication Rules between %s and %s", refNodeName, compareNodeName) - fmt.Printf("%s Differences found:\n", CrossMark) + fmt.Printf("%s Differences found:\n", utils.CrossMark) printDiffDetails(diff.Details, refNodeName, compareNodeName) } t.DiffResult.Diffs[pairKey] = diff diff --git a/internal/core/table_diff.go b/internal/core/table_diff.go index 97ae249..29f44d4 100644 --- a/internal/core/table_diff.go +++ b/internal/core/table_diff.go @@ -32,8 +32,9 @@ import ( "github.com/pgedge/ace/db/helpers" "github.com/pgedge/ace/db/queries" 
"github.com/pgedge/ace/internal/auth" - "github.com/pgedge/ace/internal/logger" + utils "github.com/pgedge/ace/pkg/common" "github.com/pgedge/ace/pkg/config" + "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/types" "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" @@ -133,60 +134,6 @@ func NewTableDiffTask() *TableDiffTask { } } -func sanitise(identifier string) string { - return pgx.Identifier{identifier}.Sanitize() -} - -func StringifyKey(pkValues map[string]any, pkCols []string) (string, error) { - if len(pkCols) == 0 { - return "", nil - } - if len(pkValues) != len(pkCols) { - return "", fmt.Errorf("mismatch between pk value count (%d) and pk column count (%d)", len(pkValues), len(pkCols)) - } - - if len(pkCols) == 1 { - val, ok := pkValues[pkCols[0]] - if !ok { - return "", fmt.Errorf("pk column '%s' not found in pk values map", pkCols[0]) - } - return fmt.Sprintf("%v", val), nil - } - - sortedPkCols := make([]string, len(pkCols)) - copy(sortedPkCols, pkCols) - sort.Strings(sortedPkCols) - - var parts []string - for _, col := range sortedPkCols { - val, ok := pkValues[col] - if !ok { - return "", fmt.Errorf("pk column '%s' not found in pk values map", col) - } - parts = append(parts, fmt.Sprintf("%v", val)) - } - return strings.Join(parts, "||"), nil -} - -func isKnownScalarType(colType string) bool { - knownPrefixes := []string{ - "character", "text", - "integer", "bigint", "smallint", - "numeric", "decimal", "real", "double precision", - "boolean", - "bytea", - "json", "jsonb", - "uuid", - "timestamp", "date", "time", - } - for _, prefix := range knownPrefixes { - if strings.HasPrefix(colType, prefix) { - return true - } - } - return false -} - func (t *TableDiffTask) fetchRows(ctx context.Context, nodeName string, r Range) ([]map[string]any, error) { pool, ok := t.Pools[nodeName] if !ok { @@ -197,8 +144,8 @@ func (t *TableDiffTask) fetchRows(ctx context.Context, nodeName string, r Range) return nil, fmt.Errorf("primary key 
not defined for table %s.%s", t.Schema, t.Table) } - quotedSchema := sanitise(t.Schema) - quotedTable := sanitise(t.Table) + quotedSchema := pgx.Identifier{t.Schema}.Sanitize() + quotedTable := pgx.Identifier{t.Table}.Sanitize() quotedSchemaTable := fmt.Sprintf("%s.%s", quotedSchema, quotedTable) var colTypes map[string]string @@ -231,10 +178,10 @@ func (t *TableDiffTask) fetchRows(ctx context.Context, nodeName string, r Range) for _, colName := range t.Cols { colType := colTypes[colName] - quotedColName := sanitise(colName) + quotedColName := pgx.Identifier{colName}.Sanitize() // We cast user defined types and arrays to TEXT to avoid scan errors with unknown OIDs - if strings.HasSuffix(colType, "[]") || !isKnownScalarType(colType) { + if strings.HasSuffix(colType, "[]") || !utils.IsKnownScalarType(colType) { selectCols = append(selectCols, fmt.Sprintf("%s::TEXT AS %s", quotedColName, quotedColName)) } else { selectCols = append(selectCols, quotedColName) @@ -245,7 +192,7 @@ func (t *TableDiffTask) fetchRows(ctx context.Context, nodeName string, r Range) quotedKeyCols := make([]string, len(t.Key)) for i, k := range t.Key { - quotedKeyCols[i] = sanitise(k) + quotedKeyCols[i] = pgx.Identifier{k}.Sanitize() } orderByClause := "" @@ -428,7 +375,7 @@ func (t *TableDiffTask) compareBlocks( for _, pkCol := range t.Key { pkVal[pkCol] = row[pkCol] } - pkStr, err := StringifyKey(pkVal, t.Key) + pkStr, err := utils.StringifyKey(pkVal, t.Key) if err != nil { return nil, fmt.Errorf("failed to stringify n1 pkey: %w", err) } @@ -441,7 +388,7 @@ func (t *TableDiffTask) compareBlocks( for _, pkCol := range t.Key { pkVal[pkCol] = row[pkCol] } - pkStr, err := StringifyKey(pkVal, t.Key) + pkStr, err := utils.StringifyKey(pkVal, t.Key) if err != nil { return nil, fmt.Errorf("failed to stringify n2 pkey: %w", err) } @@ -521,7 +468,7 @@ func (t *TableDiffTask) Validate() error { return fmt.Errorf("table-diff currently supports only csv, json and html output formats") } - nodeList, err := 
ParseNodes(t.Nodes) + nodeList, err := utils.ParseNodes(t.Nodes) if err != nil { return fmt.Errorf("nodes should be a comma-separated list of nodenames. E.g., nodes=\"n1,n2\". Error: %w", err) } @@ -536,7 +483,7 @@ func (t *TableDiffTask) Validate() error { return fmt.Errorf("table-diff needs at least two nodes to compare") } - err = readClusterInfo(t) + err = utils.ReadClusterInfo(t) if err != nil { return fmt.Errorf("error loading cluster information: %w", err) } @@ -566,7 +513,7 @@ func (t *TableDiffTask) Validate() error { for _, nodeMap := range t.ClusterNodes { if len(nodeList) > 0 { nameVal, _ := nodeMap["Name"].(string) - if !Contains(nodeList, nameVal) { + if !utils.Contains(nodeList, nameVal) { continue } } @@ -631,7 +578,7 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error { port = "5432" } - if !Contains(t.NodeList, hostname) { + if !utils.Contains(t.NodeList, hostname) { continue } @@ -641,7 +588,7 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error { } defer conn.Close() - currCols, err := GetColumns(conn, schema, table) + currCols, err := utils.GetColumns(conn, schema, table) if err != nil { return fmt.Errorf("failed to get columns for table %s.%s on node %s: %w", schema, table, hostname, err) } @@ -649,7 +596,7 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error { return fmt.Errorf("table '%s.%s' not found on %s, or the current user does not have adequate privileges", schema, table, hostname) } - currKey, err := GetPrimaryKey(conn, schema, table) + currKey, err := utils.GetPrimaryKey(conn, schema, table) if err != nil { return fmt.Errorf("failed to get primary key for table %s.%s on node %s: %w", schema, table, hostname, err) } @@ -669,7 +616,7 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error { cols = currCols key = currKey - colTypes, err := GetColumnTypes(conn, table) + colTypes, err := utils.GetColumnTypes(conn, table) if err != nil { return fmt.Errorf("failed to get column types for table %s on 
node %s: %w", table, hostname, err) } @@ -681,7 +628,7 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error { } t.ColTypes[colTypesKey] = colTypes - authorized, missingPrivileges, err := CheckUserPrivileges(conn, user, schema, table, requiredPrivileges) + authorized, missingPrivileges, err := utils.CheckUserPrivileges(conn, user, schema, table, requiredPrivileges) if err != nil { return fmt.Errorf("failed to check user privileges on node %s: %w", hostname, err) } @@ -704,9 +651,9 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error { if t.TableFilter != "" { viewName := fmt.Sprintf("%s_%s_filtered", t.TaskID, table) - sanitisedViewName := sanitise(viewName) - sanitisedSchema := sanitise(schema) - sanitisedTable := sanitise(table) + sanitisedViewName := pgx.Identifier{viewName}.Sanitize() + sanitisedSchema := pgx.Identifier{schema}.Sanitize() + sanitisedTable := pgx.Identifier{table}.Sanitize() viewSQL := fmt.Sprintf("CREATE MATERIALIZED VIEW IF NOT EXISTS %s AS SELECT * FROM %s.%s WHERE %s", sanitisedViewName, sanitisedSchema, sanitisedTable, t.TableFilter) @@ -901,8 +848,8 @@ func (t *TableDiffTask) ExecuteTask() error { continue } } else { - sanitisedSchema := sanitise(t.Schema) - sanitisedTable := sanitise(t.Table) + sanitisedSchema := pgx.Identifier{t.Schema}.Sanitize() + sanitisedTable := pgx.Identifier{t.Table}.Sanitize() countQuerySQL := fmt.Sprintf("SELECT COUNT(*) FROM %s.%s", sanitisedSchema, sanitisedTable) logger.Debug("[%s] Executing count query for filtered table: %s", name, countQuerySQL) err = pool.QueryRow(ctx, countQuerySQL).Scan(&count) @@ -1122,7 +1069,7 @@ func (t *TableDiffTask) ExecuteTask() error { if r1.hash != r2.hash { logger.Debug("%s Mismatch in initial range %d (%v-%v) for %s vs %s. Hashes: %s... / %s... 
narrowing down diffs...", - CrossMark, rangeIdx, currentRange.Start, currentRange.End, node1, node2, safeCut(r1.hash, 8), safeCut(r2.hash, 8)) + utils.CrossMark, rangeIdx, currentRange.Start, currentRange.End, node1, node2, utils.SafeCut(r1.hash, 8), utils.SafeCut(r2.hash, 8)) mismatchedTasks = append(mismatchedTasks, RecursiveDiffTask{ Node1Name: node1, Node2Name: node2, @@ -1131,7 +1078,7 @@ func (t *TableDiffTask) ExecuteTask() error { }) } else { - logger.Debug("%s Match in initial range %d (%v-%v) for %s vs %s", CheckMark, rangeIdx, currentRange.Start, currentRange.End, node1, node2) + logger.Debug("%s Match in initial range %d (%v-%v) for %s vs %s", utils.CheckMark, rangeIdx, currentRange.Start, currentRange.End, node1, node2) } } } @@ -1190,7 +1137,7 @@ func (t *TableDiffTask) ExecuteTask() error { logger.Info("ERROR writing diff output to file %s: %v", outputFileName, err) return fmt.Errorf("failed to write diffs file: %w", err) } - logger.Warn("%s TABLES DO NOT MATCH", CrossMark) + logger.Warn("%s TABLES DO NOT MATCH", utils.CrossMark) for key, diffCount := range t.DiffResult.Summary.DiffRowsCount { logger.Warn("Found %d differences between %s", diffCount, key) @@ -1199,7 +1146,7 @@ func (t *TableDiffTask) ExecuteTask() error { logger.Info("Diff report written to %s", outputFileName) } else { - logger.Info("%s TABLES MATCH", CheckMark) + logger.Info("%s TABLES MATCH", utils.CheckMark) } return nil @@ -1297,11 +1244,11 @@ func (t *TableDiffTask) generateSubRanges( quotedKeyCols := make([]string, len(t.Key)) for i, k := range t.Key { - quotedKeyCols[i] = sanitise(k) + quotedKeyCols[i] = pgx.Identifier{k}.Sanitize() } pkColsStr := strings.Join(quotedKeyCols, ", ") pkTupleStr := fmt.Sprintf("ROW(%s)", pkColsStr) - schemaTable := fmt.Sprintf("%s.%s", sanitise(t.Schema), sanitise(t.Table)) + schemaTable := fmt.Sprintf("%s.%s", pgx.Identifier{t.Schema}.Sanitize(), pgx.Identifier{t.Table}.Sanitize()) var conditions []string args := []any{} @@ -1435,23 +1382,6 @@ 
func (t *TableDiffTask) generateSubRanges( return []Range{parentRange}, nil } -func addSpockMetadata(row map[string]any) map[string]any { - if row == nil { - return nil - } - metadata := make(map[string]any) - if commitTs, ok := row["commit_ts"]; ok { - metadata["commit_ts"] = commitTs - delete(row, "commit_ts") - } - if nodeOrigin, ok := row["node_origin"]; ok { - metadata["node_origin"] = nodeOrigin - delete(row, "node_origin") - } - row["_spock_metadata_"] = metadata - return row -} - func (t *TableDiffTask) recursiveDiff( ctx context.Context, task RecursiveDiffTask, @@ -1504,16 +1434,16 @@ func (t *TableDiffTask) recursiveDiff( } for _, row := range diffInfo.Node1OnlyRows { - t.DiffResult.NodeDiffs[pairKey].Rows[node1Name] = append(t.DiffResult.NodeDiffs[pairKey].Rows[node1Name], addSpockMetadata(row)) + t.DiffResult.NodeDiffs[pairKey].Rows[node1Name] = append(t.DiffResult.NodeDiffs[pairKey].Rows[node1Name], utils.AddSpockMetadata(row)) currentDiffRowsForPair++ } for _, row := range diffInfo.Node2OnlyRows { - t.DiffResult.NodeDiffs[pairKey].Rows[node2Name] = append(t.DiffResult.NodeDiffs[pairKey].Rows[node2Name], addSpockMetadata(row)) + t.DiffResult.NodeDiffs[pairKey].Rows[node2Name] = append(t.DiffResult.NodeDiffs[pairKey].Rows[node2Name], utils.AddSpockMetadata(row)) currentDiffRowsForPair++ } for _, modRow := range diffInfo.ModifiedRows { - t.DiffResult.NodeDiffs[pairKey].Rows[node1Name] = append(t.DiffResult.NodeDiffs[pairKey].Rows[node1Name], addSpockMetadata(modRow.Node1Data)) - t.DiffResult.NodeDiffs[pairKey].Rows[node2Name] = append(t.DiffResult.NodeDiffs[pairKey].Rows[node2Name], addSpockMetadata(modRow.Node2Data)) + t.DiffResult.NodeDiffs[pairKey].Rows[node1Name] = append(t.DiffResult.NodeDiffs[pairKey].Rows[node1Name], utils.AddSpockMetadata(modRow.Node1Data)) + t.DiffResult.NodeDiffs[pairKey].Rows[node2Name] = append(t.DiffResult.NodeDiffs[pairKey].Rows[node2Name], utils.AddSpockMetadata(modRow.Node2Data)) currentDiffRowsForPair++ } @@ -1593,7 
+1523,7 @@ func (t *TableDiffTask) recursiveDiff( if res1.hash != res2.hash { logger.Debug("%s Mismatch in sub-range %v-%v for %s (%s...) vs %s (%s...). Recursing.", - CrossMark, sr.Start, sr.End, node1Name, safeCut(res1.hash, 8), node2Name, safeCut(res2.hash, 8)) + utils.CrossMark, sr.Start, sr.End, node1Name, utils.SafeCut(res1.hash, 8), node2Name, utils.SafeCut(res2.hash, 8)) wg.Add(1) go t.recursiveDiff(ctx, RecursiveDiffTask{ Node1Name: node1Name, @@ -1602,100 +1532,11 @@ func (t *TableDiffTask) recursiveDiff( CurrentEstimatedBlockSize: newEstimatedBlockSize, }, wg) } else { - logger.Debug("%s Match in sub-range %v-%v for %s vs %s.", CheckMark, sr.Start, sr.End, node1Name, node2Name) + logger.Debug("%s Match in sub-range %v-%v for %s vs %s.", utils.CheckMark, sr.Start, sr.End, node1Name, node2Name) } } } -func safeCut(s string, n int) string { - if len(s) < n { - return s - } - return s[:n] -} - -func (t *TableDiffTask) getPkeyOffsets(ctx context.Context, pool *pgxpool.Pool) ([]Range, error) { - if len(t.Key) == 0 { - return nil, fmt.Errorf("primary key not defined for table %s.%s", t.Schema, t.Table) - } - - schemaIdent := sanitise(t.Schema) - tableIdent := sanitise(t.Table) - - quotedKeyCols := make([]string, len(t.Key)) - for i, k := range t.Key { - quotedKeyCols[i] = sanitise(k) - } - // Using this for both select and order by to keep things simple - pkStr := strings.Join(quotedKeyCols, ", ") - - querySQL := fmt.Sprintf("SELECT %s FROM %s.%s ORDER BY %s", pkStr, schemaIdent, tableIdent, pkStr) - - pgRows, err := pool.Query(ctx, querySQL) - if err != nil { - return nil, fmt.Errorf("failed to query primary keys for direct offset generation from %s.%s: %w", t.Schema, t.Table, err) - } - defer pgRows.Close() - - var allPks []any - numPKCols := len(t.Key) - scanDest := make([]any, numPKCols) - scanDestPtrs := make([]any, numPKCols) - for i := range scanDest { - scanDestPtrs[i] = &scanDest[i] - } - - for pgRows.Next() { - if err := pgRows.Scan(scanDestPtrs...); 
err != nil { - return nil, fmt.Errorf("failed to scan primary key value from %s.%s: %w", t.Schema, t.Table, err) - } - if numPKCols == 1 { - allPks = append(allPks, scanDest[0]) - } else { - /* PKs are composite, so create a new slice for each PK to avoid allPks - * elements pointing to the same underlying scanDest array. - */ - currentPkComposite := make([]any, numPKCols) - copy(currentPkComposite, scanDest) - allPks = append(allPks, currentPkComposite) - } - } - - if err := pgRows.Err(); err != nil { - return nil, fmt.Errorf("error iterating over primary key rows from %s.%s: %w", t.Schema, t.Table, err) - } - - if len(allPks) == 0 { - logger.Info("[%s.%s] No primary key values found, returning empty ranges for direct generation.", t.Schema, t.Table) - return []Range{}, nil - } - - var ranges []Range - // As always the case, our first range needs to be (NULL, first_pkey) - ranges = append(ranges, Range{Start: nil, End: allPks[0]}) - - currentPkIndex := 0 - for currentPkIndex < len(allPks) { - currentBlockStartPkey := allPks[currentPkIndex] - - nextPKeyIndexForRangeEnd := currentPkIndex + t.BlockSize - - if nextPKeyIndexForRangeEnd < len(allPks) { - rangeEndValue := allPks[nextPKeyIndexForRangeEnd] - ranges = append(ranges, Range{Start: currentBlockStartPkey, End: rangeEndValue}) - currentPkIndex = nextPKeyIndexForRangeEnd - } else { - if !(currentPkIndex == 0 && len(allPks) <= t.BlockSize && len(allPks) == 1) { - ranges = append(ranges, Range{Start: currentBlockStartPkey, End: nil}) - } - break - } - } - - logger.Debug("[%s.%s] Generated %d ranges without sampling from %d pkeys with block_size %d.", t.Schema, t.Table, len(ranges), len(allPks), t.BlockSize) - return ranges, nil -} - func (t *TableDiffTask) AddPrimaryKeyToDiffSummary() { if t.DiffResult.Summary.PrimaryKey == nil { t.DiffResult.Summary.PrimaryKey = t.Key diff --git a/internal/core/table_repair.go b/internal/core/table_repair.go index 362d9c5..04529dd 100644 --- a/internal/core/table_repair.go +++ 
b/internal/core/table_repair.go @@ -24,7 +24,8 @@ import ( "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "github.com/pgedge/ace/internal/auth" - "github.com/pgedge/ace/internal/logger" + utils "github.com/pgedge/ace/pkg/common" + "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/types" ) @@ -158,13 +159,13 @@ func (t *TableRepairTask) ValidateAndPrepare() error { // Reading nodelist is unnecessary since the diff file will contain that info. // TODO: Remove this once checks are handled correctly in readClusterInfo - nodeList, err := ParseNodes(t.Nodes) + nodeList, err := utils.ParseNodes(t.Nodes) if err != nil { return fmt.Errorf("nodes should be a comma-separated list of nodenames. E.g., nodes=\"n1,n2\". Error: %w", err) } t.NodeList = nodeList - if err := readClusterInfo(t); err != nil { + if err := utils.ReadClusterInfo(t); err != nil { return fmt.Errorf("failed to read cluster info: %w", err) } @@ -219,7 +220,7 @@ func (t *TableRepairTask) ValidateAndPrepare() error { for _, nodeMap := range t.ClusterNodes { if len(t.NodeList) > 0 { nameVal, _ := nodeMap["Name"].(string) - if !Contains(t.NodeList, nameVal) { + if !utils.Contains(t.NodeList, nameVal) { continue } } @@ -253,13 +254,13 @@ func (t *TableRepairTask) ValidateAndPrepare() error { } t.Pools[nodeName] = connPool - cols, err := GetColumns(connPool, t.Schema, t.Table) + cols, err := utils.GetColumns(connPool, t.Schema, t.Table) if err != nil { return fmt.Errorf("failed to get columns for %s.%s on node %s: %w", t.Schema, t.Table, nodeName, err) } t.Cols = cols - pKey, err := GetPrimaryKey(connPool, t.Schema, t.Table) + pKey, err := utils.GetPrimaryKey(connPool, t.Schema, t.Table) if err != nil { return fmt.Errorf("failed to get primary key for %s.%s on node %s: %w", t.Schema, t.Table, nodeName, err) } @@ -271,7 +272,7 @@ func (t *TableRepairTask) ValidateAndPrepare() error { publicIP, _ := nodeInfo["PublicIP"].(string) port, _ := nodeInfo["Port"].(string) - colTypes, err := 
GetColumnTypes(connPool, t.Table) + colTypes, err := utils.GetColumnTypes(connPool, t.Table) if err != nil { return fmt.Errorf("failed to get column types for %s on node %s: %w", t.Table, nodeName, err) } @@ -285,7 +286,7 @@ func (t *TableRepairTask) ValidateAndPrepare() error { dbUser = t.Database.DBUser } - authorized, missingPrivsMap, err := CheckUserPrivileges(connPool, dbUser, t.Schema, t.Table, requiredPrivileges) + authorized, missingPrivsMap, err := utils.CheckUserPrivileges(connPool, dbUser, t.Schema, t.Table, requiredPrivileges) if err != nil { return fmt.Errorf("failed to check user privileges on node %s: %w", nodeName, err) } @@ -702,7 +703,7 @@ func (t *TableRepairTask) runBidirectionalRepair() error { node1RowsByPKey := make(map[string]map[string]any) for _, row := range node1Rows { - pkeyStr, err := stringifyPKey(row, t.Key, t.SimplePrimaryKey) + pkeyStr, err := utils.StringifyKey(row, t.Key) if err != nil { repairErrors = append(repairErrors, fmt.Sprintf("stringify pkey failed for %s: %v", node1Name, err)) continue @@ -712,7 +713,7 @@ func (t *TableRepairTask) runBidirectionalRepair() error { node2RowsByPKey := make(map[string]map[string]any) for _, row := range node2Rows { - pkeyStr, err := stringifyPKey(row, t.Key, t.SimplePrimaryKey) + pkeyStr, err := utils.StringifyKey(row, t.Key) if err != nil { repairErrors = append(repairErrors, fmt.Sprintf("stringify pkey failed for %s: %v", node2Name, err)) continue @@ -723,14 +724,14 @@ func (t *TableRepairTask) runBidirectionalRepair() error { insertsForNode1 := make(map[string]map[string]any) for pkey, row := range node2RowsByPKey { if _, exists := node1RowsByPKey[pkey]; !exists { - insertsForNode1[pkey] = stripSpockMetadata(row) + insertsForNode1[pkey] = utils.StripSpockMetadata(row) } } insertsForNode2 := make(map[string]map[string]any) for pkey, row := range node1RowsByPKey { if _, exists := node2RowsByPKey[pkey]; !exists { - insertsForNode2[pkey] = stripSpockMetadata(row) + insertsForNode2[pkey] = 
utils.StripSpockMetadata(row) } } @@ -1015,7 +1016,7 @@ func executeUpserts(tx pgx.Tx, task *TableRepairTask, upserts map[string]map[str return 0, fmt.Errorf("type for column %s not found in target node's colTypes", colName) } - convertedVal, err := convertToPgxType(val, pgType) + convertedVal, err := utils.ConvertToPgxType(val, pgType) if err != nil { return 0, fmt.Errorf("error converting value for column %s (value: %v, type: %s): %w", colName, val, pgType, err) } @@ -1133,7 +1134,7 @@ func getDryRunOutput(task *TableRepairTask) (string, error) { node1RowsByPKey := make(map[string]map[string]any) for _, row := range node1Rows { - pkeyStr, err := stringifyPKey(row, task.Key, task.SimplePrimaryKey) + pkeyStr, err := utils.StringifyKey(row, task.Key) if err != nil { return "", fmt.Errorf("error stringifying pkey for row on %s: %w", node1Name, err) } @@ -1142,7 +1143,7 @@ func getDryRunOutput(task *TableRepairTask) (string, error) { node2RowsByPKey := make(map[string]map[string]any) for _, row := range node2Rows { - pkeyStr, err := stringifyPKey(row, task.Key, task.SimplePrimaryKey) + pkeyStr, err := utils.StringifyKey(row, task.Key) if err != nil { return "", fmt.Errorf("error stringifying pkey for row on %s: %w", node2Name, err) } @@ -1273,13 +1274,6 @@ func getDryRunOutput(task *TableRepairTask) (string, error) { return sb.String(), nil } -func stripSpockMetadata(row map[string]any) map[string]any { - if row != nil { - delete(row, "_spock_metadata_") - } - return row -} - func calculateRepairSets(task *TableRepairTask) (map[string]map[string]map[string]any, map[string]map[string]map[string]any, error) { fullRowsToUpsert := make(map[string]map[string]map[string]any) // nodeName -> string(pkey) -> rowData fullRowsToDelete := make(map[string]map[string]map[string]any) // nodeName -> string(pkey) -> rowData @@ -1318,22 +1312,22 @@ func calculateRepairSets(task *TableRepairTask) (map[string]map[string]map[strin sourceRowsByPKey := make(map[string]map[string]any) for _, 
row := range sourceRows { - pkeyStr, err := stringifyPKey(row, task.Key, task.SimplePrimaryKey) + pkeyStr, err := utils.StringifyKey(row, task.Key) if err != nil { return nil, nil, fmt.Errorf("error stringifying pkey for source row on %s: %w", task.SourceOfTruth, err) } - cleanRow := stripSpockMetadata(row) + cleanRow := utils.StripSpockMetadata(row) sourceRowsByPKey[pkeyStr] = cleanRow fullRowsToUpsert[targetNode][pkeyStr] = cleanRow } targetRowsByPKey := make(map[string]map[string]any) for _, row := range targetRows { - pkeyStr, err := stringifyPKey(row, task.Key, task.SimplePrimaryKey) + pkeyStr, err := utils.StringifyKey(row, task.Key) if err != nil { return nil, nil, fmt.Errorf("error stringifying pkey for target row on %s: %w", targetNode, err) } - cleanRow := stripSpockMetadata(row) + cleanRow := utils.StripSpockMetadata(row) targetRowsByPKey[pkeyStr] = cleanRow if _, existsInSource := sourceRowsByPKey[pkeyStr]; !existsInSource { fullRowsToDelete[targetNode][pkeyStr] = cleanRow diff --git a/internal/core/table_rerun.go b/internal/core/table_rerun.go index b9298f8..1d5453c 100644 --- a/internal/core/table_rerun.go +++ b/internal/core/table_rerun.go @@ -27,7 +27,8 @@ import ( "github.com/jackc/pgx/v4/pgxpool" "github.com/pgedge/ace/db/queries" "github.com/pgedge/ace/internal/auth" - "github.com/pgedge/ace/internal/logger" + utils "github.com/pgedge/ace/pkg/common" + "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/types" ) @@ -39,7 +40,7 @@ func (t *TableDiffTask) ExecuteRerunTask() error { } logger.Info("Successfully loaded and validated diff file: %s", t.DiffFilePath) - if err := readClusterInfo(t); err != nil { + if err := utils.ReadClusterInfo(t); err != nil { return fmt.Errorf("error loading cluster information for rerun: %w", err) } @@ -63,7 +64,7 @@ func (t *TableDiffTask) ExecuteRerunTask() error { pools := make(map[string]*pgxpool.Pool) for _, nodeInfo := range t.ClusterNodes { name := nodeInfo["Name"].(string) - if !Contains(t.NodeList, 
name) { + if !utils.Contains(t.NodeList, name) { continue } pool, err := auth.GetClusterNodeConnection(nodeInfo, t.ClientRole) @@ -148,7 +149,7 @@ func (t *TableDiffTask) ExecuteRerunTask() error { for _, count := range newDiffResult.Summary.DiffRowsCount { totalPersistentDiffs += count } - logger.Warn("%s Found %d persistent differences. Writing new report to %s", CrossMark, totalPersistentDiffs, outputFileName) + logger.Warn("%s Found %d persistent differences. Writing new report to %s", utils.CrossMark, totalPersistentDiffs, outputFileName) jsonData, mErr := json.MarshalIndent(newDiffResult, "", " ") if mErr != nil { return fmt.Errorf("failed to marshal new diff report: %w", mErr) @@ -157,7 +158,7 @@ func (t *TableDiffTask) ExecuteRerunTask() error { return fmt.Errorf("failed to write new diff report: %w", wErr) } } else { - logger.Info("%s All previously reported differences have been resolved.", CheckMark) + logger.Info("%s All previously reported differences have been resolved.", utils.CheckMark) } return nil @@ -213,7 +214,7 @@ func (t *TableDiffTask) collectPkeysFromDiff() (map[string]map[string]any, error return nil, fmt.Errorf("primary key column '%s' not found in a diff row", pkCol) } } - pkStr, err := StringifyKey(pkVal, t.Key) + pkStr, err := utils.StringifyKey(pkVal, t.Key) if err != nil { return nil, fmt.Errorf("failed to stringify key: %w", err) } @@ -226,6 +227,7 @@ func (t *TableDiffTask) collectPkeysFromDiff() (map[string]map[string]any, error // fetchRowsByPkeys efficiently fetches a list of rows from a node by their primary keys. // It uses a temporary table and a JOIN for high performance with large numbers of keys. +// TODO: Can this be separated out into a common function that can be used by other tasks? 
func fetchRowsByPkeys(ctx context.Context, pool *pgxpool.Pool, t *TableDiffTask, pkeyVals [][]any) (map[string]map[string]any, error) { if len(pkeyVals) == 0 { return make(map[string]map[string]any), nil @@ -301,7 +303,7 @@ func fetchRowsByPkeys(ctx context.Context, pool *pgxpool.Pool, t *TableDiffTask, for _, pkCol := range t.Key { pkMap[pkCol] = rowData[pkCol] } - pkStr, err := StringifyKey(pkMap, t.Key) + pkStr, err := utils.StringifyKey(pkMap, t.Key) if err != nil { return nil, fmt.Errorf("failed to stringify fetched row key: %w", err) } @@ -346,7 +348,7 @@ func (t *TableDiffTask) reCompareDiffs(fetchedRowsByNode map[string]map[string]m for _, pkCol := range t.Key { pkMap[pkCol] = row[pkCol] } - pkStr, _ := StringifyKey(pkMap, t.Key) + pkStr, _ := utils.StringifyKey(pkMap, t.Key) originalNode1Rows[pkStr] = row allPkeysForPair[pkStr] = true } @@ -355,7 +357,7 @@ func (t *TableDiffTask) reCompareDiffs(fetchedRowsByNode map[string]map[string]m for _, pkCol := range t.Key { pkMap[pkCol] = row[pkCol] } - pkStr, _ := StringifyKey(pkMap, t.Key) + pkStr, _ := utils.StringifyKey(pkMap, t.Key) originalNode2Rows[pkStr] = row allPkeysForPair[pkStr] = true } @@ -386,10 +388,10 @@ func (t *TableDiffTask) reCompareDiffs(fetchedRowsByNode map[string]map[string]m if isDifferent { persistentDiffCount++ if nowOnNode1 { - newDiffsForPair.Rows[node1] = append(newDiffsForPair.Rows[node1], addSpockMetadata(newRow1)) + newDiffsForPair.Rows[node1] = append(newDiffsForPair.Rows[node1], utils.AddSpockMetadata(newRow1)) } if nowOnNode2 { - newDiffsForPair.Rows[node2] = append(newDiffsForPair.Rows[node2], addSpockMetadata(newRow2)) + newDiffsForPair.Rows[node2] = append(newDiffsForPair.Rows[node2], utils.AddSpockMetadata(newRow2)) } } } diff --git a/internal/core/utils.go b/pkg/common/utils.go similarity index 81% rename from internal/core/utils.go rename to pkg/common/utils.go index 97213f0..3958591 100644 --- a/internal/core/utils.go +++ b/pkg/common/utils.go @@ -9,7 +9,7 @@ // 
///////////////////////////////////////////////////////////////////////////// -package core +package common import ( "context" @@ -26,7 +26,7 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgx/v4/pgxpool" "github.com/pgedge/ace/db/queries" - "github.com/pgedge/ace/internal/logger" + "github.com/pgedge/ace/pkg/logger" "github.com/pgedge/ace/pkg/types" ) @@ -227,7 +227,7 @@ func Contains(slice []string, value string) bool { return slices.Contains(slice, value) } -func readClusterInfo(t ClusterConfigProvider) error { +func ReadClusterInfo(t ClusterConfigProvider) error { configPath := fmt.Sprintf("%s.json", t.GetClusterName()) if _, err := os.Stat(configPath); os.IsNotExist(err) { return fmt.Errorf("cluster configuration file not found for %s", t.GetClusterName()) @@ -319,8 +319,8 @@ func readClusterInfo(t ClusterConfigProvider) error { return nil } -// convertToPgxType converts a value from a JSON unmarshal to a type that pgx can handle -func convertToPgxType(val interface{}, pgType string) (interface{}, error) { +// ConvertToPgxType converts a value from a JSON unmarshal to a type that pgx can handle +func ConvertToPgxType(val any, pgType string) (any, error) { if val == nil { return nil, nil } @@ -461,32 +461,112 @@ func convertToPgxType(val interface{}, pgType string) (interface{}, error) { // TODO: Array handling! } -// stringifyPKey converts a primary key (simple or composite) from a row map into a consistent string representation. 
// SafeCut returns s truncated to at most n bytes. It never panics: a
// non-positive n yields the empty string and a short s is returned
// unchanged. Note the cut is byte-based, so it may split a multi-byte
// UTF-8 rune at the boundary.
func SafeCut(s string, n int) string {
	if n <= 0 {
		// Guard: the original s[:n] would panic for negative n.
		return ""
	}
	if len(s) <= n {
		return s
	}
	return s[:n]
}

// IsKnownScalarType reports whether colType names a scalar Postgres
// type this package knows how to handle. The check is a prefix match so
// parameterized forms such as "character varying(64)" or
// "timestamp with time zone" are recognized too.
func IsKnownScalarType(colType string) bool {
	knownPrefixes := []string{
		"character", "text",
		"integer", "bigint", "smallint",
		"numeric", "decimal", "real", "double precision",
		"boolean",
		"bytea",
		"json", "jsonb",
		"uuid",
		"timestamp", "date", "time",
	}
	for _, prefix := range knownPrefixes {
		if strings.HasPrefix(colType, prefix) {
			return true
		}
	}
	return false
}

// StringifyKey converts the primary-key columns of row into a stable
// string representation. Composite keys are joined with "||" after
// sorting the column names, so the same logical key always yields the
// same string regardless of the order in which pKeyCols is supplied.
// An empty pKeyCols yields "" with no error. Returns an error if any
// key column is absent from row.
func StringifyKey(row map[string]any, pKeyCols []string) (string, error) {
	if len(pKeyCols) == 0 {
		return "", nil
	}

	// Fast path: simple (single-column) primary key.
	if len(pKeyCols) == 1 {
		val, ok := row[pKeyCols[0]]
		if !ok {
			return "", fmt.Errorf("pk column '%s' not found in row", pKeyCols[0])
		}
		return fmt.Sprintf("%v", val), nil
	}

	// Composite key: sort a copy of the column names so the caller's
	// slice is not reordered as a side effect.
	sortedPkCols := make([]string, len(pKeyCols))
	copy(sortedPkCols, pKeyCols)
	sort.Strings(sortedPkCols)

	parts := make([]string, 0, len(sortedPkCols))
	for _, col := range sortedPkCols {
		val, ok := row[col]
		if !ok {
			return "", fmt.Errorf("pk column '%s' not found in row", col)
		}
		parts = append(parts, fmt.Sprintf("%v", val))
	}
	return strings.Join(parts, "||"), nil
}

// AddSpockMetadata moves the replication bookkeeping columns
// ("commit_ts" and "node_origin") out of the row proper and into a
// nested "_spock_metadata_" map. The row is modified in place and
// returned for convenience; a nil row is returned unchanged. The
// metadata entry is set even when neither column is present, matching
// the shape downstream consumers expect.
func AddSpockMetadata(row map[string]any) map[string]any {
	if row == nil {
		return nil
	}
	metadata := make(map[string]any)
	if commitTs, ok := row["commit_ts"]; ok {
		metadata["commit_ts"] = commitTs
		delete(row, "commit_ts")
	}
	if nodeOrigin, ok := row["node_origin"]; ok {
		metadata["node_origin"] = nodeOrigin
		delete(row, "node_origin")
	}
	row["_spock_metadata_"] = metadata
	return row
}

// StripSpockMetadata removes the "_spock_metadata_" entry added by
// AddSpockMetadata. The row is modified in place and returned; a nil
// row is returned unchanged.
func StripSpockMetadata(row map[string]any) map[string]any {
	if row != nil {
		delete(row, "_spock_metadata_")
	}
	return row
}

// DiffStringSlices compares two string slices as sets and returns the
// elements of a not present in b (missing) and the elements of b not
// present in a (extra). Both results are sorted. Unlike the previous
// implementation, the input slices are NOT reordered or otherwise
// mutated.
func DiffStringSlices(a, b []string) (missing, extra []string) {
	aSet := make(map[string]struct{}, len(a))
	for _, s := range a {
		aSet[s] = struct{}{}
	}
	bSet := make(map[string]struct{}, len(b))
	for _, s := range b {
		bSet[s] = struct{}{}
	}

	for _, s := range a {
		if _, found := bSet[s]; !found {
			missing = append(missing, s)
		}
	}
	for _, s := range b {
		if _, found := aSet[s]; !found {
			extra = append(extra, s)
		}
	}

	// Sort the results (not the inputs) so output order stays
	// deterministic without the side effect of mutating the caller's
	// slices.
	sort.Strings(missing)
	sort.Strings(extra)
	return missing, extra
}