diff --git a/internal/core/table_diff.go b/internal/core/table_diff.go index da5321a..6546994 100644 --- a/internal/core/table_diff.go +++ b/internal/core/table_diff.go @@ -11,6 +11,7 @@ import ( "reflect" "runtime" "sort" + "strconv" "strings" "sync" "time" @@ -721,10 +722,9 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error { logger.Info("Table %s is comparable across nodes", t.QualifiedTableName) - // // TODO: add this back later - // if err := t.CheckColumnSize(); err != nil { - // return err - // } + if err := t.CheckColumnSize(); err != nil { + return err + } if t.DiffFilePath != "" { if err := CheckDiffFileFormat(t.DiffFilePath, t); err != nil { @@ -739,64 +739,59 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error { return nil } -// func (t *TableDiffTask) CheckColumnSize() error { -// for hostPort, types := range t.ColTypes { -// parts := strings.Split(hostPort, ":") -// if len(parts) != 2 { -// continue -// } - -// host, portStr := parts[0], parts[1] -// port, _ := strconv.Atoi(portStr) - -// var pool *pgxpool.Pool -// for _, nodeInfo := range t.ClusterNodes { -// nodeHost, _ := nodeInfo["public_ip"].(string) -// nodePort, _ := nodeInfo["port"].(float64) -// if nodePort == 0 { -// nodePort = 5432 -// } - -// if nodeHost == host && int(nodePort) == port { -// conn, err := auth.GetClusterNodeConnection(nodeInfo, t.ClientRole) -// if err != nil { -// return fmt.Errorf("failed to connect to node %s:%d: %w", host, port, err) -// } -// defer conn.Close() -// pool = conn -// break -// } -// } - -// if pool == nil { -// continue -// } - -// for colName, colType := range types { -// if !strings.Contains(colType, "bytea") { -// continue -// } - -// var avgSize int64 -// q := queries.NewQuerier(pool) -// err := q.CheckColumnSize(context.Background(), queries.CheckColumnSizeParams{ -// Schema: t.Schema, -// Table: t.Table, -// Column: colName, -// }) -// if err != nil { -// return fmt.Errorf("failed to check size of bytea column %s: %w", colName, err) -// } - -// if avgSize > 1000000 { -// return fmt.Errorf("refusing to perform table-diff. Data in column %s of table %s.%s is larger than 1 MB", -// colName, t.Schema, t.Table) -// } -// } -// } - -// return nil -// } +func (t *TableDiffTask) CheckColumnSize() error { + for hostPort, types := range t.ColTypes { + parts := strings.Split(hostPort, ":") + if len(parts) != 2 { + continue + } + + host, portStr := parts[0], parts[1] + port, _ := strconv.Atoi(portStr) + + var pool *pgxpool.Pool + for _, nodeInfo := range t.ClusterNodes { + nodeHost, _ := nodeInfo["PublicIP"].(string) + nodePort, _ := nodeInfo["Port"].(float64) + if nodePort == 0 { + nodePort = 5432 + } + + if nodeHost == host && int(nodePort) == port { + conn, err := auth.GetClusterNodeConnection(nodeInfo, t.ClientRole) + if err != nil { + return fmt.Errorf("failed to connect to node %s:%d: %w", host, port, err) + } + defer conn.Close() + pool = conn + break + } + } + + if pool == nil { + continue + } + + for colName, colType := range types { + if !strings.Contains(colType, "bytea") { + continue + } + + avgSize, err := helpers.AvgColumnSize(context.Background(), pool, t.Schema, t.Table, colName) + logger.Debug("Column %s of table %s.%s has average size %d", colName, t.Schema, t.Table, avgSize) + if err != nil { + return fmt.Errorf("failed to check size of bytea column %s: %w", colName, err) + } + + if avgSize > 1000000 { + return fmt.Errorf("refusing to perform table-diff. Data in column %s of table %s.%s is larger than 1 MB", + colName, t.Schema, t.Table) + } + } + } + + return nil +} func (t *TableDiffTask) ExecuteTask(debugMode bool) error { startTime := time.Now() @@ -1181,7 +1176,6 @@ func (t *TableDiffTask) hashRange( return hash, nil } -// TODO: This could be optimised further func (t *TableDiffTask) generateSubRanges( ctx context.Context, node string, @@ -1199,21 +1193,65 @@ func (t *TableDiffTask) generateSubRanges( return nil, fmt.Errorf("no pool for node %s", node) } - pkCol := sanitise(t.Key[0]) + quotedKeyCols := make([]string, len(t.Key)) + for i, k := range t.Key { + quotedKeyCols[i] = sanitise(k) + } + pkColsStr := strings.Join(quotedKeyCols, ", ") + pkTupleStr := fmt.Sprintf("ROW(%s)", pkColsStr) schemaTable := fmt.Sprintf("%s.%s", sanitise(t.Schema), sanitise(t.Table)) var conditions []string args := []any{} paramIdx := 1 + if parentRange.Start != nil { - conditions = append(conditions, fmt.Sprintf("%s >= $%d", pkCol, paramIdx)) - args = append(args, parentRange.Start) - paramIdx++ + startVal := parentRange.Start + if len(t.Key) == 1 { + conditions = append(conditions, fmt.Sprintf("%s >= $%d", quotedKeyCols[0], paramIdx)) + args = append(args, startVal) + paramIdx++ + } else { + startVals, ok := startVal.([]any) + if !ok { + return nil, fmt.Errorf("generateSubRanges: parentRange.Start is not a valid composite key") + } + placeholders := make([]string, len(t.Key)) + for i := 0; i < len(t.Key); i++ { + placeholders[i] = fmt.Sprintf("$%d", paramIdx+i) + } + placeholderTupleStr := fmt.Sprintf("ROW(%s)", strings.Join(placeholders, ", ")) + conditions = append(conditions, fmt.Sprintf("%s >= %s", pkTupleStr, placeholderTupleStr)) + args = append(args, startVals...) + paramIdx += len(t.Key) + } } + if parentRange.End != nil { - conditions = append(conditions, fmt.Sprintf("%s <= $%d", pkCol, paramIdx)) - args = append(args, parentRange.End) + endVal := parentRange.End + if len(t.Key) == 1 { + // In hashRange, the end is exclusive (<), but here for counting and splitting + // we use inclusive (<=) to match fetchRows. This is acceptable because + // we are splitting a mismatched range, and slight overlap is okay. + conditions = append(conditions, fmt.Sprintf("%s <= $%d", quotedKeyCols[0], paramIdx)) + args = append(args, endVal) + paramIdx++ + } else { + endVals, ok := endVal.([]any) + if !ok { + return nil, fmt.Errorf("generateSubRanges: parentRange.End is not a valid composite key") + } + placeholders := make([]string, len(t.Key)) + for i := 0; i < len(t.Key); i++ { + placeholders[i] = fmt.Sprintf("$%d", paramIdx+i) + } + placeholderTupleStr := fmt.Sprintf("ROW(%s)", strings.Join(placeholders, ", ")) + conditions = append(conditions, fmt.Sprintf("%s <= %s", pkTupleStr, placeholderTupleStr)) + args = append(args, endVals...) + paramIdx += len(t.Key) + } } + whereClause := "" if len(conditions) > 0 { whereClause = "WHERE " + strings.Join(conditions, " AND ") @@ -1238,17 +1276,32 @@ func (t *TableDiffTask) generateSubRanges( sqlOffset = 0 } - orderByPK := sanitise(t.Key[0]) + orderByPK := strings.Join(quotedKeyCols, ", ") medianQueryArgs := make([]any, len(args)) copy(medianQueryArgs, args) medianQueryArgs = append(medianQueryArgs, sqlOffset) medianQuery := fmt.Sprintf("SELECT %s FROM %s %s ORDER BY %s LIMIT 1 OFFSET $%d", - pkCol, schemaTable, whereClause, orderByPK, paramIdx) + pkColsStr, schemaTable, whereClause, orderByPK, paramIdx) var medianPKVal any - err = pool.QueryRow(ctx, medianQuery, medianQueryArgs...).Scan(&medianPKVal) + numPKCols := len(t.Key) + + if numPKCols == 1 { + err = pool.QueryRow(ctx, medianQuery, medianQueryArgs...).Scan(&medianPKVal) + } else { + scanDest := make([]any, numPKCols) + scanDestPtrs := make([]any, numPKCols) + for i := range scanDest { + scanDestPtrs[i] = &scanDest[i] + } + err = pool.QueryRow(ctx, medianQuery, medianQueryArgs...).Scan(scanDestPtrs...) + if err == nil { + medianPKVal = append([]any{}, scanDest...) + } + } + if err != nil { logger.Debug("[%s] Failed to find median PK for range %v-%v: %v. SQL: %s, Args: %v", node, parentRange.Start, parentRange.End, err, medianQuery, medianQueryArgs) return []Range{parentRange}, nil