Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 127 additions & 39 deletions internal/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ func SetupCLI() *cli.App {
Name: "concurrency-factor",
Aliases: []string{"c"},
Usage: "Concurrency factor",
Value: 2,
Value: 1,
},
&cli.IntFlag{
Name: "compare-unit-size",
Aliases: []string{"s"},
Usage: "Compare unit size",
Usage: "Max size of the smallest block to use when diffs are present",
Value: 10000,
},
&cli.StringFlag{
Expand All @@ -43,17 +43,12 @@ func SetupCLI() *cli.App {
&cli.StringFlag{
Name: "nodes",
Aliases: []string{"n"},
Usage: "Nodes to include in the diff",
Usage: "Nodes to include in the diff (default: all)",
Value: "all",
},
&cli.StringFlag{
Name: "batch-size",
Usage: "Size of each batch",
Value: "10",
},
&cli.StringFlag{
Name: "table-filter",
Usage: "Filter expression for tables",
Usage: "Where clause expression to use while diffing tables",
Value: "",
},
&cli.BoolFlag{
Expand All @@ -73,6 +68,69 @@ func SetupCLI() *cli.App {
Value: false,
},
}

tr_flags := []cli.Flag{
&cli.StringFlag{
Name: "dbname",
Aliases: []string{"d"},
Usage: "Name of the database",
Value: "",
},
&cli.StringFlag{
Name: "diff-file",
Aliases: []string{"f"},
Usage: "Path to the diff file (required)",
Required: true,
},
&cli.StringFlag{
Name: "source-of-truth",
Aliases: []string{"s"},
Usage: "Name of the node to be considered the source of truth",
},
&cli.StringFlag{
Name: "nodes",
Aliases: []string{"n"},
Usage: "Nodes to include for cluster info (default: all)",
Value: "all",
},
&cli.BoolFlag{
Name: "quiet",
Usage: "Whether to suppress output",
Value: false,
},
&cli.BoolFlag{
Name: "debug",
Aliases: []string{"v"},
Usage: "Enable debug logging",
Value: false,
},
&cli.BoolFlag{
Name: "dry-run",
Usage: "Show what would be done without executing",
Value: false,
},
&cli.BoolFlag{
Name: "insert-only",
Usage: "Only perform inserts, no updates or deletes",
Value: false,
},
&cli.BoolFlag{
Name: "upsert-only",
Usage: "Only perform upserts (insert or update), no deletes",
Value: false,
},
&cli.BoolFlag{
Name: "fire-triggers",
Usage: "Whether to fire triggers during repairs",
Value: false,
},
&cli.BoolFlag{
Name: "bidirectional",
Usage: "Whether to perform repairs in both directions. Can be used only with the insert-only option",
Value: false,
},
}

app := &cli.App{
Name: "ace",
Usage: "Advanced Command-line Executor for database operations",
Expand All @@ -85,62 +143,57 @@ func SetupCLI() *cli.App {
"and detecting data inconsistencies",
Action: func(ctx *cli.Context) error {
if ctx.Args().Len() < 2 {
return fmt.Errorf("missing required arguments: needs <cluster> and <table>")
return fmt.Errorf("missing required arguments for table-diff: needs <cluster> and <table>")
}
return TableDiffCLI(ctx)
},
Flags: td_flags,
},
{
Name: "table-repair",
Usage: "Repair table inconsistencies based on a diff file",
ArgsUsage: "<cluster> <table>",
Flags: tr_flags,
Action: func(ctx *cli.Context) error {
if ctx.Args().Len() < 2 {
return fmt.Errorf("missing required arguments for table-repair: needs <cluster> and <table>")
}
return TableRepairCLI(ctx)
},
},
},
}

return app
}

func TableDiffCLI(ctx *cli.Context) error {
clusterName := ctx.Args().Get(0)
tableName := ctx.Args().Get(1)
dbName := ctx.String("dbname")
blockSizeStr := ctx.String("block-size")
compareUnitSize := ctx.Int("compare-unit-size")
debugMode := ctx.Bool("debug")
concurrencyFactor := ctx.Int("concurrency-factor")
outputFormat := ctx.String("output")
nodesParam := ctx.String("nodes")
batchSizeStr := ctx.String("batch-size")
tableFilter := ctx.String("table-filter")
quietMode := ctx.Bool("quiet")
overrideBlockSize := ctx.Bool("override-block-size")

blockSizeInt, err := strconv.ParseInt(blockSizeStr, 10, 64)
if err != nil {
return fmt.Errorf("invalid block size '%s': %v", blockSizeStr, err)
}

batchSizeInt, err := strconv.ParseInt(batchSizeStr, 10, 64)
if err != nil {
return fmt.Errorf("invalid batch size '%s': %v", batchSizeStr, err)
}
debugMode := ctx.Bool("debug")

task := core.NewTableDiffTask()
task.ClusterName = clusterName
task.TableName = tableName
task.DBName = dbName
task.ClusterName = ctx.Args().Get(0)
task.QualifiedTableName = ctx.Args().Get(1)
task.DBName = ctx.String("dbname")
task.BlockSize = int(blockSizeInt)
task.ConcurrencyFactor = concurrencyFactor
task.CompareUnitSize = compareUnitSize
task.Output = outputFormat
task.Nodes = nodesParam
task.BatchSize = int(batchSizeInt)
task.TableFilter = tableFilter
task.QuietMode = quietMode
task.OverrideBlockSize = overrideBlockSize
task.ConcurrencyFactor = ctx.Int("concurrency-factor")
task.CompareUnitSize = ctx.Int("compare-unit-size")
task.Output = ctx.String("output")
task.Nodes = ctx.String("nodes")
task.TableFilter = ctx.String("table-filter")
task.QuietMode = ctx.Bool("quiet")
task.OverrideBlockSize = ctx.Bool("override-block-size")

if err := task.Validate(); err != nil {
return fmt.Errorf("validation failed: %v", err)
}

if err := task.RunChecks(true); err != nil {
if err := task.RunChecks(true); err != nil { // Pass true to skip inner validation if Validate() was just called
return fmt.Errorf("checks failed: %v", err)
}

Expand All @@ -151,3 +204,38 @@ func TableDiffCLI(ctx *cli.Context) error {
fmt.Println("Table diff completed")
return nil
}

func TableRepairCLI(ctx *cli.Context) error {
task := core.NewTableRepairTask()
task.ClusterName = ctx.Args().Get(0)
task.QualifiedTableName = ctx.Args().Get(1)
task.DiffFilePath = ctx.String("diff-file")
task.DBName = ctx.String("dbname")
task.Nodes = ctx.String("nodes")
task.SourceOfTruth = ctx.String("source-of-truth")
task.QuietMode = ctx.Bool("quiet")

if ctx.Bool("debug") {
core.SetGlobalLogLevel(core.LevelDebug)
} else {
core.SetGlobalLogLevel(core.LevelInfo)
}

task.DryRun = ctx.Bool("dry-run")
task.InsertOnly = ctx.Bool("insert-only")
task.UpsertOnly = ctx.Bool("upsert-only")
task.FireTriggers = ctx.Bool("fire-triggers")
task.FixNulls = ctx.Bool("fix-nulls")
task.Bidirectional = ctx.Bool("bidirectional")

if err := task.ValidateAndPrepare(); err != nil {
return fmt.Errorf("validation failed: %v", err)
}

if err := task.Run(true); err != nil {
return fmt.Errorf("error during table repair: %v", err)
}

fmt.Println("Table repair process initiated.")
return nil
}
64 changes: 24 additions & 40 deletions internal/core/table_diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,13 @@ type Range struct {
type TableDiffTask struct {
types.Task
types.DerivedFields
TableName string
DBName string
Nodes string
QualifiedTableName string
DBName string
Nodes string

BlockSize int
ConcurrencyFactor int
Output string
BatchSize int
TableFilter string
QuietMode bool

Expand All @@ -76,37 +75,22 @@ type TableDiffTask struct {

CompareUnitSize int

DiffResult DiffOutput
DiffResult types.DiffOutput
diffMutex sync.Mutex
}

type DiffOutput struct {
NodeDiffs map[string]DiffByNodePair `json:"diffs"` // Key: "nodeA/nodeB" (sorted names)
Summary DiffSummary `json:"summary"`
}

// DiffByNodePair holds the differing rows for a pair of nodes.
// The keys in the DiffOutput.Diffs map will be "nodeX/nodeY",
// and the Node1/Node2 fields here will store rows corresponding to nodeX and nodeY respectively.
type DiffByNodePair struct {
Rows map[string][]map[string]any `json:"rows"` // Keyed by actual node name e.g. "n1", "n2"
}

type DiffSummary struct {
Schema string `json:"schema"`
Table string `json:"table"`
Nodes []string `json:"nodes"`
BlockSize int `json:"block_size"`
CompareUnitSize int `json:"compare_unit_size"`
ConcurrencyFactor int `json:"concurrency_factor"`
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
TimeTaken string `json:"time_taken"`
DiffRowsCount map[string]int `json:"diff_rows_count"` // Key: "nodeA/nodeB", Value: count of differing rows
TotalRowsChecked int64 `json:"total_rows_checked"` // Estimated
InitialRangesCount int `json:"initial_ranges_count"`
MismatchedRangesCount int `json:"mismatched_ranges_count"`
// Implement ClusterConfigProvider interface for TableDiffTask
func (t *TableDiffTask) GetClusterName() string { return t.ClusterName }
func (t *TableDiffTask) GetDBName() string { return t.DBName }
func (t *TableDiffTask) SetDBName(name string) { t.DBName = name }
func (t *TableDiffTask) GetNodes() string { return t.Nodes }
func (t *TableDiffTask) GetNodeList() []string { return t.NodeList }
func (t *TableDiffTask) SetNodeList(nl []string) { t.NodeList = nl }
func (t *TableDiffTask) SetDatabase(db types.Database) { t.Database = db }
func (t *TableDiffTask) GetClusterNodes() []map[string]any {
return t.ClusterNodes
}
func (t *TableDiffTask) SetClusterNodes(cn []map[string]any) { t.ClusterNodes = cn }

type NodePairDiff struct {
Node1OnlyRows []map[string]any
Expand Down Expand Up @@ -470,7 +454,7 @@ func (t *TableDiffTask) compareBlocks(
}

func (t *TableDiffTask) Validate() error {
if t.ClusterName == "" || t.TableName == "" {
if t.ClusterName == "" || t.QualifiedTableName == "" {
return fmt.Errorf("cluster_name and table_name are required arguments")
}

Expand Down Expand Up @@ -511,9 +495,9 @@ func (t *TableDiffTask) Validate() error {

logger.Info("Cluster %s exists", t.ClusterName)

parts := strings.Split(t.TableName, ".")
parts := strings.Split(t.QualifiedTableName, ".")
if len(parts) != 2 {
return fmt.Errorf("tableName %s must be of form 'schema.table_name'", t.TableName)
return fmt.Errorf("tableName %s must be of form 'schema.table_name'", t.QualifiedTableName)
}
schema, table := parts[0], parts[1]

Expand Down Expand Up @@ -735,7 +719,7 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error {
}
}

logger.Info("Table %s is comparable across nodes", t.TableName)
logger.Info("Table %s is comparable across nodes", t.QualifiedTableName)

// // TODO: add this back later
// if err := t.CheckColumnSize(); err != nil {
Expand Down Expand Up @@ -877,9 +861,9 @@ func (t *TableDiffTask) ExecuteTask(debugMode bool) error {
return fmt.Errorf("Unable to determine node with highest row count (or any row counts)")
}

t.DiffResult = DiffOutput{
NodeDiffs: make(map[string]DiffByNodePair),
Summary: DiffSummary{
t.DiffResult = types.DiffOutput{
NodeDiffs: make(map[string]types.DiffByNodePair),
Summary: types.DiffSummary{
Schema: t.Schema,
Table: t.Table,
Nodes: t.NodeList,
Expand Down Expand Up @@ -1092,7 +1076,7 @@ func (t *TableDiffTask) ExecuteTask(debugMode bool) error {
diffWg.Wait()
// t.DiffResult.Summary.MismatchedRangesCount = int(mismatchedRangesCountAtomic)

logger.Info("Table diff comparison completed for %s", t.TableName)
logger.Info("Table diff comparison completed for %s", t.QualifiedTableName)

endTime := time.Now()
t.DiffResult.Summary.EndTime = endTime.Format(time.RFC3339)
Expand Down Expand Up @@ -1338,7 +1322,7 @@ func (t *TableDiffTask) recursiveDiff(
t.diffMutex.Lock()

if _, ok := t.DiffResult.NodeDiffs[pairKey]; !ok {
t.DiffResult.NodeDiffs[pairKey] = DiffByNodePair{
t.DiffResult.NodeDiffs[pairKey] = types.DiffByNodePair{
Rows: make(map[string][]map[string]any),
}
}
Expand Down
Loading