Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 98 additions & 36 deletions db/helpers/sqlhelpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ func AvgColumnSize(ctx context.Context, pool *pgxpool.Pool, schema, table, colum
return 0, err
}

schemaIdent := fmt.Sprintf("\"%s\"", schema)
tableIdent := fmt.Sprintf("\"%s\"", table)
colIdent := fmt.Sprintf("\"%s\"", column)
schemaIdent := fmt.Sprintf(`"%s"`, schema)
tableIdent := fmt.Sprintf(`"%s"`, table)
colIdent := fmt.Sprintf(`"%s"`, column)

query := fmt.Sprintf(
`SELECT COALESCE(AVG(pg_column_size(%s)), 0) FROM %s.%s`,
Expand All @@ -129,44 +129,68 @@ func GeneratePkeyOffsetsQuery(
samplePercent float64,
ntileCount int,
) (string, error) {
for _, ident := range append(keyColumns, schema, table) {
for _, ident := range append([]string{schema, table}, keyColumns...) {
if err := SanitiseIdentifier(ident); err != nil {
return "", fmt.Errorf("invalid identifier %q: %w", ident, err)
}
}
schemaIdent := fmt.Sprintf(`"%s"`, schema)
tableIdent := fmt.Sprintf(`"%s"`, table)

keyColsSelect := strings.Join(keyColumns, ",\n ")
keyColsOrder := strings.Join(keyColumns, ", ")
quotedKeyColsOriginal := make([]string, len(keyColumns))
for i, c := range keyColumns {
quotedKeyColsOriginal[i] = fmt.Sprintf(`"%s"`, c)
}

keyColsSelect := strings.Join(quotedKeyColsOriginal, ",\n ")
keyColsOrder := strings.Join(quotedKeyColsOriginal, ", ")

var descs []string
for _, c := range keyColumns {
descs = append(descs, fmt.Sprintf("%s DESC", c))
descs = append(descs, fmt.Sprintf(`"%s" DESC`, c))
}
keyColsOrderDesc := strings.Join(descs, ", ")

var firstSelects, lastSelects, firstTuples []string
for _, c := range keyColumns {
quotedCol := fmt.Sprintf(`"%s"`, c)
firstSelects = append(firstSelects,
fmt.Sprintf("(SELECT %s FROM first_row) AS %s", c, c))
fmt.Sprintf(`(SELECT %s FROM first_row) AS %s`, quotedCol, quotedCol))
lastSelects = append(lastSelects,
fmt.Sprintf("(SELECT %s FROM last_row) AS %s", c, c))
fmt.Sprintf(`(SELECT %s FROM last_row) AS %s`, quotedCol, quotedCol))
firstTuples = append(firstTuples,
fmt.Sprintf("(SELECT %s FROM first_row)", c))
fmt.Sprintf(`(SELECT %s FROM first_row)`, quotedCol))
}

var rangeStarts, rangeEnds, rangeOutputs []string
var rangeStarts, rangeEnds []string
for _, c := range keyColumns {
rangeStarts = append(rangeStarts, fmt.Sprintf("%s AS range_start_%s", c, c))
quotedCol := fmt.Sprintf(`"%s"`, c)
aliasStart := fmt.Sprintf(`range_start_%s`, c)
quotedAliasStart := fmt.Sprintf(`"%s"`, aliasStart)

aliasEnd := fmt.Sprintf(`range_end_%s`, c)
quotedAliasEnd := fmt.Sprintf(`"%s"`, aliasEnd)

rangeStarts = append(rangeStarts, fmt.Sprintf(`%s AS %s`, quotedCol, quotedAliasStart))
rangeEnds = append(rangeEnds, fmt.Sprintf(
"LEAD(%s) OVER (ORDER BY seq, %s) AS range_end_%s",
c, keyColsOrder, c,
`LEAD(%s) OVER (ORDER BY seq, %s) AS %s`,
quotedCol, keyColsOrder, quotedAliasEnd,
))
rangeOutputs = append(rangeOutputs,
fmt.Sprintf("range_start_%s,\n range_end_%s", c, c))
}

var startComponentCols []string
var endComponentCols []string
for _, c := range keyColumns {
aliasStart := fmt.Sprintf(`range_start_%s`, c)
quotedAliasStart := fmt.Sprintf(`"%s"`, aliasStart)
startComponentCols = append(startComponentCols, quotedAliasStart)

aliasEnd := fmt.Sprintf(`range_end_%s`, c)
quotedAliasEnd := fmt.Sprintf(`"%s"`, aliasEnd)
endComponentCols = append(endComponentCols, quotedAliasEnd)
}
selectOutputCols := append(startComponentCols, endComponentCols...)

data := map[string]any{
"SchemaIdent": schemaIdent,
"TableIdent": tableIdent,
Expand All @@ -181,7 +205,7 @@ func GeneratePkeyOffsetsQuery(
"FirstRowTupleSelects": strings.Join(firstTuples, ",\n "),
"RangeStartColumns": strings.Join(rangeStarts, ",\n "),
"RangeEndColumns": strings.Join(rangeEnds, ",\n "),
"RangeOutputColumns": strings.Join(rangeOutputs, ",\n "),
"RangeOutputColumns": strings.Join(selectOutputCols, ",\n "),
}

var buf bytes.Buffer
Expand All @@ -191,35 +215,73 @@ func GeneratePkeyOffsetsQuery(
return buf.String(), nil
}

func BlockHashSQL(schema, table string, cols []string, primaryKey string) (string, error) {
func BlockHashSQL(schema, table string, cols []string, primaryKeyCols []string) (string, error) {
if err := SanitiseIdentifier(schema); err != nil {
return "", err
}
if err := SanitiseIdentifier(table); err != nil {
return "", err
}
if err := SanitiseIdentifier(primaryKey); err != nil {
return "", err
}
for _, col := range cols {
if err := SanitiseIdentifier(col); err != nil {
return "", err
for _, pkCol := range primaryKeyCols {
if err := SanitiseIdentifier(pkCol); err != nil {
return "", fmt.Errorf("invalid primary key column identifier %q: %w", pkCol, err)
}
}
schemaIdent := fmt.Sprintf("\"%s\"", schema)
tableIdent := fmt.Sprintf("\"%s\"", table)
primaryIdent := fmt.Sprintf("\"%s\"", primaryKey)
var colIdents []string
for _, col := range cols {
colIdents = append(colIdents, fmt.Sprintf("\"%s\"", col))

schemaIdent := fmt.Sprintf(`"%s"`, schema)
tableIdent := fmt.Sprintf(`"%s"`, table)
tableAlias := "_tbl_"

quotedPKColIdents := make([]string, len(primaryKeyCols))
for i, pkCol := range primaryKeyCols {
quotedPKColIdents[i] = fmt.Sprintf(`"%s"`, pkCol)
}
pkOrderByStr := strings.Join(quotedPKColIdents, ", ")

pkComparisonExpression := ""
if len(primaryKeyCols) == 1 {
pkComparisonExpression = quotedPKColIdents[0]
} else {
pkComparisonExpression = fmt.Sprintf("ROW(%s)", strings.Join(quotedPKColIdents, ", "))
}

startPlaceholders := make([]string, len(primaryKeyCols))
for i := range primaryKeyCols {
startPlaceholders[i] = fmt.Sprintf("$%d", 2+i)
}
colsList := strings.Join(colIdents, ", ")
startValueExpression := ""
if len(primaryKeyCols) == 1 {
startValueExpression = startPlaceholders[0]
} else {
startValueExpression = fmt.Sprintf("ROW(%s)", strings.Join(startPlaceholders, ", "))
}

skipMaxCheckPlaceholderIndex := 2 + len(primaryKeyCols)

endPlaceholders := make([]string, len(primaryKeyCols))
for i := range primaryKeyCols {
endPlaceholders[i] = fmt.Sprintf("$%d", skipMaxCheckPlaceholderIndex+1+i)
}
endValueExpression := ""
if len(primaryKeyCols) == 1 {
endValueExpression = endPlaceholders[0]
} else {
endValueExpression = fmt.Sprintf("ROW(%s)", strings.Join(endPlaceholders, ", "))
}

query := fmt.Sprintf(
`SELECT encode(digest(COALESCE(string_agg(concat_ws('|', %s),'|' ORDER BY %s),'EMPTY_BLOCK'),'sha1'),'hex')
FROM %s.%s
WHERE ($1::boolean OR %s >= $2)
AND ($3::boolean OR %s < $4)`,
colsList, primaryIdent, schemaIdent, tableIdent, primaryIdent, primaryIdent,
`SELECT encode(digest(COALESCE(string_agg(%s::text, '|' ORDER BY %s), '[EMPTY_BLOCK]'), 'sha1'), 'hex')
FROM %s.%s AS %s
WHERE ($1::boolean OR %s >= %s)
AND ($%d::boolean OR %s < %s)`,
tableAlias,
pkOrderByStr,
schemaIdent, tableIdent, tableAlias,
pkComparisonExpression,
startValueExpression,
skipMaxCheckPlaceholderIndex,
pkComparisonExpression,
endValueExpression,
)
return query, nil
}
Loading