diff --git a/db/helpers/sqlhelpers.go b/db/helpers/sqlhelpers.go index 203aa0b..3c7ffd4 100644 --- a/db/helpers/sqlhelpers.go +++ b/db/helpers/sqlhelpers.go @@ -99,7 +99,7 @@ func SanitiseIdentifier(input string) error { return nil } -func AvgColumnSize(ctx context.Context, db DBQuerier, schema, table, column string) (int64, error) { +func MaxColumnSize(ctx context.Context, db DBQuerier, schema, table, column string) (int64, error) { if err := SanitiseIdentifier(schema); err != nil { return 0, err } @@ -115,14 +115,14 @@ func AvgColumnSize(ctx context.Context, db DBQuerier, schema, table, column stri colIdent := fmt.Sprintf(`"%s"`, column) query := fmt.Sprintf( - `SELECT COALESCE(AVG(pg_column_size(%s)), 0) FROM %s.%s`, + `SELECT COALESCE(MAX(octet_length(%s))::bigint, 0) FROM %s.%s`, colIdent, schemaIdent, tableIdent, ) - var avgSize int64 - if err := db.QueryRow(ctx, query).Scan(&avgSize); err != nil { + var maxSize int64 + if err := db.QueryRow(ctx, query).Scan(&maxSize); err != nil { return 0, fmt.Errorf( - "AvgColumnSize query failed for %s.%s.%s: %w", + "MaxColumnSize query failed for %s.%s.%s: %w", schema, table, column, @@ -130,7 +130,7 @@ func AvgColumnSize(ctx context.Context, db DBQuerier, schema, table, column stri ) } - return avgSize, nil + return maxSize, nil } func GeneratePkeyOffsetsQuery( diff --git a/internal/core/table_diff.go b/internal/core/table_diff.go index 870f498..998fc33 100644 --- a/internal/core/table_diff.go +++ b/internal/core/table_diff.go @@ -573,17 +573,9 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error { hostIP, _ := nodeInfo["PublicIP"].(string) user, _ := nodeInfo["DBUser"].(string) - var port float64 - portVal, ok := nodeInfo["Port"] - if ok { - switch v := portVal.(type) { - case string: - port, _ = strconv.ParseFloat(v, 64) - case float64: - port = v - } - } else { - port = 5432 + port, ok := nodeInfo["Port"].(string) + if !ok { + port = "5432" } if !Contains(t.NodeList, hostname) { @@ -629,7 +621,7 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error { return fmt.Errorf("failed to get column types for table %s on node %s: %w", table, hostname, err) } - colTypesKey := fmt.Sprintf("%s:%d", hostIP, int(port)) + colTypesKey := fmt.Sprintf("%s:%s", hostIP, port) if t.ColTypes == nil { t.ColTypes = make(map[string]map[string]string) @@ -655,7 +647,7 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error { user, strings.Join(missingPrivs, ", "), schema, table, hostname) } - hostMap[hostIP+":"+fmt.Sprint(int(port))] = hostname + hostMap[hostIP+":"+port] = hostname if t.TableFilter != "" { viewName := fmt.Sprintf("%s_%s_filtered", t.TaskID, table) @@ -755,18 +747,27 @@ func (t *TableDiffTask) CheckColumnSize() error { var pool *pgxpool.Pool for _, nodeInfo := range t.ClusterNodes { nodeHost, _ := nodeInfo["PublicIP"].(string) - nodePort, _ := nodeInfo["Port"].(float64) + + var nodePort float64 + if nodePortVal, ok := nodeInfo["Port"]; ok { + switch v := nodePortVal.(type) { + case string: + nodePort, _ = strconv.ParseFloat(v, 64) + case float64: + nodePort = v + } + } + if nodePort == 0 { nodePort = 5432 } if nodeHost == host && int(nodePort) == port { - conn, err := auth.GetClusterNodeConnection(nodeInfo, t.ClientRole) + var err error + pool, err = auth.GetClusterNodeConnection(nodeInfo, t.ClientRole) if err != nil { return fmt.Errorf("failed to connect to node %s:%d: %w", host, port, err) } - defer conn.Close() - pool = conn break } } @@ -780,17 +781,20 @@ func (t *TableDiffTask) CheckColumnSize() error { continue } - avgSize, err := helpers.AvgColumnSize(context.Background(), pool, t.Schema, t.Table, colName) - logger.Debug("Column %s of table %s.%s has average size %d", colName, t.Schema, t.Table, avgSize) + maxSize, err := helpers.MaxColumnSize(context.Background(), pool, t.Schema, t.Table, colName) + logger.Debug("Column %s of table %s.%s has max size %d", colName, t.Schema, t.Table, maxSize) if err != nil { + pool.Close() return fmt.Errorf("failed to check size of bytea column %s: %w", colName, err) } - if avgSize > 1000000 { + if maxSize > 1000000 { + pool.Close() return fmt.Errorf("refusing to perform table-diff. Data in column %s of table %s.%s is larger than 1 MB", colName, t.Schema, t.Table) } } + pool.Close() } return nil diff --git a/tests/integration/table_diff_integration_test.go b/tests/integration/table_diff_integration_test.go index e5d177f..4dbfddf 100644 --- a/tests/integration/table_diff_integration_test.go +++ b/tests/integration/table_diff_integration_test.go @@ -900,3 +900,94 @@ func TestTableDiff_TableFiltering(t *testing.T) { log.Println("TestTableDiff_TableFiltering completed.") } + +func TestTableDiff_ByteaColumnSizeCheck(t *testing.T) { + ctx := context.Background() + tableName := "bytea_size_test" + qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + + createTableSQL := fmt.Sprintf(` +CREATE TABLE IF NOT EXISTS %s ( + id INT PRIMARY KEY, + data BYTEA +);`, qualifiedTableName) + + // Create table on both nodes and add cleanup to drop it + for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + _, err := pool.Exec(ctx, createTableSQL) + if err != nil { + t.Fatalf("Failed to create test table %s: %v", qualifiedTableName, err) + } + } + t.Cleanup(func() { + for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", qualifiedTableName)) + if err != nil { + t.Logf("Failed to drop test table %s: %v", qualifiedTableName, err) + } + } + }) + + // --- Test Case 1: Data < 1MB (should pass) --- + t.Run("DataUnder1MB", func(t *testing.T) { + // Truncate before run + for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s", qualifiedTableName)) + if err != nil { + t.Fatalf("Failed to truncate table %s: %v", qualifiedTableName, err) + } + } + + smallData := make([]byte, 500*1024) // 500 KB + _, err := pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, data) VALUES (1, $1)", qualifiedTableName), smallData) + if err != nil { + t.Fatalf("Failed to insert small data: %v", err) + } + _, err = pgCluster.Node2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, data) VALUES (1, $1)", qualifiedTableName), smallData) + if err != nil { + t.Fatalf("Failed to insert small data: %v", err) + } + + nodesToCompare := []string{serviceN1, serviceN2} + tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + + err = tdTask.RunChecks(false) + if err != nil { + t.Errorf("RunChecks should succeed for bytea data < 1MB, but got error: %v", err) + } + }) + + // --- Test Case 2: Data > 1MB (should fail) --- + t.Run("DataOver1MB", func(t *testing.T) { + // Truncate before run + for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s", qualifiedTableName)) + if err != nil { + t.Fatalf("Failed to truncate table %s: %v", qualifiedTableName, err) + } + } + largeData := make([]byte, 1024*1024+1) // > 1 MB + _, err := pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, data) VALUES (1, $1)", qualifiedTableName), largeData) + if err != nil { + t.Fatalf("Failed to insert large data: %v", err) + } + _, err = pgCluster.Node2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, data) VALUES (1, $1)", qualifiedTableName), largeData) + if err != nil { + t.Fatalf("Failed to insert large data: %v", err) + } + + nodesToCompare := []string{serviceN1, serviceN2} + tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + + err = tdTask.RunChecks(false) + if err == nil { + t.Fatal("RunChecks should fail for bytea data > 1MB, but it succeeded") + } + if !strings.Contains(err.Error(), "refusing to perform table-diff") { + t.Errorf("Error message should contain 'refusing to perform table-diff', but it was: %s", err.Error()) + } + if !strings.Contains(err.Error(), "is larger than 1 MB") { + t.Errorf("Error message should contain 'is larger than 1 MB', but it was: %s", err.Error()) + } + }) +}