Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions db/helpers/sqlhelpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func SanitiseIdentifier(input string) error {
return nil
}

func AvgColumnSize(ctx context.Context, db DBQuerier, schema, table, column string) (int64, error) {
func MaxColumnSize(ctx context.Context, db DBQuerier, schema, table, column string) (int64, error) {
if err := SanitiseIdentifier(schema); err != nil {
return 0, err
}
Expand All @@ -115,22 +115,22 @@ func AvgColumnSize(ctx context.Context, db DBQuerier, schema, table, column stri
colIdent := fmt.Sprintf(`"%s"`, column)

query := fmt.Sprintf(
`SELECT COALESCE(AVG(pg_column_size(%s)), 0) FROM %s.%s`,
`SELECT COALESCE(MAX(octet_length(%s))::bigint, 0) FROM %s.%s`,
colIdent, schemaIdent, tableIdent,
)

var avgSize int64
if err := db.QueryRow(ctx, query).Scan(&avgSize); err != nil {
var maxSize int64
if err := db.QueryRow(ctx, query).Scan(&maxSize); err != nil {
return 0, fmt.Errorf(
"AvgColumnSize query failed for %s.%s.%s: %w",
"MaxColumnSize query failed for %s.%s.%s: %w",
schema,
table,
column,
err,
)
}

return avgSize, nil
return maxSize, nil
}

func GeneratePkeyOffsetsQuery(
Expand Down
44 changes: 24 additions & 20 deletions internal/core/table_diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,17 +573,9 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error {
hostIP, _ := nodeInfo["PublicIP"].(string)
user, _ := nodeInfo["DBUser"].(string)

var port float64
portVal, ok := nodeInfo["Port"]
if ok {
switch v := portVal.(type) {
case string:
port, _ = strconv.ParseFloat(v, 64)
case float64:
port = v
}
} else {
port = 5432
port, ok := nodeInfo["Port"].(string)
if !ok {
port = "5432"
}

if !Contains(t.NodeList, hostname) {
Expand Down Expand Up @@ -629,7 +621,7 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error {
return fmt.Errorf("failed to get column types for table %s on node %s: %w", table, hostname, err)
}

colTypesKey := fmt.Sprintf("%s:%d", hostIP, int(port))
colTypesKey := fmt.Sprintf("%s:%s", hostIP, port)

if t.ColTypes == nil {
t.ColTypes = make(map[string]map[string]string)
Expand All @@ -655,7 +647,7 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error {
user, strings.Join(missingPrivs, ", "), schema, table, hostname)
}

hostMap[hostIP+":"+fmt.Sprint(int(port))] = hostname
hostMap[hostIP+":"+port] = hostname

if t.TableFilter != "" {
viewName := fmt.Sprintf("%s_%s_filtered", t.TaskID, table)
Expand Down Expand Up @@ -755,18 +747,27 @@ func (t *TableDiffTask) CheckColumnSize() error {
var pool *pgxpool.Pool
for _, nodeInfo := range t.ClusterNodes {
nodeHost, _ := nodeInfo["PublicIP"].(string)
nodePort, _ := nodeInfo["Port"].(float64)

var nodePort float64
if nodePortVal, ok := nodeInfo["Port"]; ok {
switch v := nodePortVal.(type) {
case string:
nodePort, _ = strconv.ParseFloat(v, 64)
case float64:
nodePort = v
}
}

if nodePort == 0 {
nodePort = 5432
}

if nodeHost == host && int(nodePort) == port {
conn, err := auth.GetClusterNodeConnection(nodeInfo, t.ClientRole)
var err error
pool, err = auth.GetClusterNodeConnection(nodeInfo, t.ClientRole)
if err != nil {
return fmt.Errorf("failed to connect to node %s:%d: %w", host, port, err)
}
defer conn.Close()
pool = conn
break
}
}
Expand All @@ -780,17 +781,20 @@ func (t *TableDiffTask) CheckColumnSize() error {
continue
}

avgSize, err := helpers.AvgColumnSize(context.Background(), pool, t.Schema, t.Table, colName)
logger.Debug("Column %s of table %s.%s has average size %d", colName, t.Schema, t.Table, avgSize)
maxSize, err := helpers.MaxColumnSize(context.Background(), pool, t.Schema, t.Table, colName)
logger.Debug("Column %s of table %s.%s has max size %d", colName, t.Schema, t.Table, maxSize)
if err != nil {
pool.Close()
return fmt.Errorf("failed to check size of bytea column %s: %w", colName, err)
}

if avgSize > 1000000 {
if maxSize > 1000000 {
pool.Close()
return fmt.Errorf("refusing to perform table-diff. Data in column %s of table %s.%s is larger than 1 MB",
colName, t.Schema, t.Table)
}
}
pool.Close()
}

return nil
Expand Down
91 changes: 91 additions & 0 deletions tests/integration/table_diff_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -900,3 +900,94 @@ func TestTableDiff_TableFiltering(t *testing.T) {

log.Println("TestTableDiff_TableFiltering completed.")
}

// TestTableDiff_ByteaColumnSizeCheck verifies the pre-flight bytea size guard
// run by TableDiffTask.RunChecks (via CheckColumnSize / MaxColumnSize):
// a table whose largest bytea payload is under 1 MB must pass the checks,
// while any row over 1 MB must abort with a "refusing to perform table-diff"
// error before the actual diff starts.
//
// NOTE(review): depends on package-level test fixtures defined elsewhere in
// this package (pgCluster with Node1Pool/Node2Pool, testSchema, serviceN1,
// serviceN2, newTestTableDiffTask) — presumably initialized in TestMain;
// confirm ordering if this test is run in isolation.
func TestTableDiff_ByteaColumnSizeCheck(t *testing.T) {
ctx := context.Background()
tableName := "bytea_size_test"
qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName)

// DDL for a minimal two-column table: integer PK plus the bytea column
// whose payload size drives the check under test.
createTableSQL := fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id INT PRIMARY KEY,
data BYTEA
);`, qualifiedTableName)

// Create table on both nodes and add cleanup to drop it
for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} {
_, err := pool.Exec(ctx, createTableSQL)
if err != nil {
t.Fatalf("Failed to create test table %s: %v", qualifiedTableName, err)
}
}
t.Cleanup(func() {
// Drop on both nodes; failures here are only logged so cleanup of one
// node does not mask the other.
for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} {
_, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", qualifiedTableName))
if err != nil {
t.Logf("Failed to drop test table %s: %v", qualifiedTableName, err)
}
}
})

// --- Test Case 1: Data < 1MB (should pass) ---
t.Run("DataUnder1MB", func(t *testing.T) {
// Truncate before run
for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} {
_, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s", qualifiedTableName))
if err != nil {
t.Fatalf("Failed to truncate table %s: %v", qualifiedTableName, err)
}
}

// 500 KB is comfortably below the 1 MB (1,000,000-byte) guard threshold.
// The same row is inserted on both nodes so whichever node the size
// check connects to sees identical data.
smallData := make([]byte, 500*1024) // 500 KB
_, err := pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, data) VALUES (1, $1)", qualifiedTableName), smallData)
if err != nil {
t.Fatalf("Failed to insert small data: %v", err)
}
_, err = pgCluster.Node2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, data) VALUES (1, $1)", qualifiedTableName), smallData)
if err != nil {
t.Fatalf("Failed to insert small data: %v", err)
}

nodesToCompare := []string{serviceN1, serviceN2}
tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare)

// skipValidation=false so the full pre-flight (including the bytea
// size check) runs; sub-1MB data must not trip it.
err = tdTask.RunChecks(false)
if err != nil {
t.Errorf("RunChecks should succeed for bytea data < 1MB, but got error: %v", err)
}
})

// --- Test Case 2: Data > 1MB (should fail) ---
t.Run("DataOver1MB", func(t *testing.T) {
// Truncate before run
for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} {
_, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s", qualifiedTableName))
if err != nil {
t.Fatalf("Failed to truncate table %s: %v", qualifiedTableName, err)
}
}
// 1 MiB + 1 byte: guaranteed over the guard's limit regardless of
// whether the implementation measures in MiB or decimal MB.
largeData := make([]byte, 1024*1024+1) // > 1 MB
_, err := pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, data) VALUES (1, $1)", qualifiedTableName), largeData)
if err != nil {
t.Fatalf("Failed to insert large data: %v", err)
}
_, err = pgCluster.Node2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, data) VALUES (1, $1)", qualifiedTableName), largeData)
if err != nil {
t.Fatalf("Failed to insert large data: %v", err)
}

nodesToCompare := []string{serviceN1, serviceN2}
tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare)

// Expect the guard to reject the diff and to surface both parts of
// the documented error message.
err = tdTask.RunChecks(false)
if err == nil {
t.Fatal("RunChecks should fail for bytea data > 1MB, but it succeeded")
}
if !strings.Contains(err.Error(), "refusing to perform table-diff") {
t.Errorf("Error message should contain 'refusing to perform table-diff', but it was: %s", err.Error())
}
if !strings.Contains(err.Error(), "is larger than 1 MB") {
t.Errorf("Error message should contain 'is larger than 1 MB', but it was: %s", err.Error())
}
})
}