/
postgres.go
665 lines (588 loc) · 23.5 KB
/
postgres.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
/*
Copyright (c) YugabyteDB, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package srcdb
import (
"context"
"database/sql"
"fmt"
"net/url"
"os/exec"
"path/filepath"
"regexp"
"strings"
"github.com/google/uuid"
"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v5/pgconn"
"github.com/mcuadros/go-version"
"github.com/samber/lo"
log "github.com/sirupsen/logrus"
"github.com/yugabyte/yb-voyager/yb-voyager/src/datafile"
"github.com/yugabyte/yb-voyager/yb-voyager/src/utils"
"github.com/yugabyte/yb-voyager/yb-voyager/src/utils/sqlname"
)
// PG_COMMAND_VERSION is the minimum version of the PostgreSQL client tools
// (pg_dump, pg_restore) accepted by GetAbsPathOfPGCommand.
const PG_COMMAND_VERSION string = "14.0"

// FETCH_COLUMN_SEQUENCES_QUERY_TEMPLATE lists, for a single table, each column
// whose default expression depends on a sequence, together with that
// sequence's name and schema. Format verbs, in order: schema name, table name
// (unquoted, embedded inside SQL string literals).
//
// relkind 'r' is an ordinary table and 'p' a partitioned table (PostgreSQL
// pg_class.relkind codes are case-sensitive).
const FETCH_COLUMN_SEQUENCES_QUERY_TEMPLATE = `SELECT
a.attname AS column_name,
COALESCE(seq.relname, '') AS sequence_name,
COALESCE(ns.nspname, '') AS schema_name
FROM pg_class AS t
JOIN pg_attribute AS a ON a.attrelid = t.oid
JOIN pg_namespace AS tn ON tn.oid = t.relnamespace
LEFT JOIN pg_attrdef AS ad ON ad.adrelid = t.oid AND ad.adnum = a.attnum
LEFT JOIN pg_depend AS d ON d.objid = ad.oid
LEFT JOIN pg_class AS seq ON seq.oid = d.refobjid
LEFT JOIN pg_namespace AS ns ON ns.oid = seq.relnamespace
WHERE
tn.nspname = '%s' -- schema name
AND t.relname = '%s' -- table name
AND a.attnum > 0
AND NOT a.attisdropped
AND t.relkind IN ('r', 'p')
AND seq.relkind = 'S';`
// PostgreSQL implements the source-database operations for a PostgreSQL
// source.
type PostgreSQL struct {
	// source holds the connection parameters and migration settings.
	source *Source
	// db is the shared connection opened by Connect; nil until then.
	db *pgx.Conn
}

// newPostgreSQL returns a PostgreSQL bound to s. No connection is opened.
func newPostgreSQL(s *Source) *PostgreSQL {
	pg := new(PostgreSQL)
	pg.source = s
	return pg
}
// Connect opens the shared source-database connection and stores it in pg.db
// (pg.db receives whatever pgx.Connect returned, even on failure). The
// connection error, if any, is returned unchanged.
func (pg *PostgreSQL) Connect() error {
	var err error
	pg.db, err = pgx.Connect(context.Background(), pg.getConnectionUri())
	return err
}
// Disconnect closes the shared connection opened by Connect. When there is no
// connection it only logs that fact. A failure to close is logged, never
// returned.
func (pg *PostgreSQL) Disconnect() {
	if pg.db != nil {
		if err := pg.db.Close(context.Background()); err != nil {
			log.Infof("Failed to close connection to the source database: %s", err)
		}
		return
	}
	log.Infof("No connection to the source database to close")
}
// CheckRequiredToolsAreInstalled verifies the external tools needed for a
// PostgreSQL source via checkTools; only the "strings" utility is checked.
func (pg *PostgreSQL) CheckRequiredToolsAreInstalled() {
	const requiredTool = "strings"
	checkTools(requiredTool)
}
// GetTableRowCount returns the exact number of rows in tableName using
// count(*). tableName is spliced into the SQL as-is, so it must already be a
// valid (qualified/quoted as needed) table reference.
//
// A dedicated connection is opened (and closed on return) so that several of
// these potentially long queries can run in parallel without hitting a "conn
// busy" error on the shared pg.db. Failures are reported via utils.ErrExit.
func (pg *PostgreSQL) GetTableRowCount(tableName string) int64 {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, pg.getConnectionUri())
	if err != nil {
		utils.ErrExit("Failed to connect to the source database for table row count: %s", err)
	}
	defer conn.Close(ctx)

	query := fmt.Sprintf("select count(*) from %s", tableName)
	log.Infof("Querying row count of table %q", tableName)
	var count int64
	if err = conn.QueryRow(ctx, query).Scan(&count); err != nil {
		utils.ErrExit("Failed to query %q for row count of %q: %s", query, tableName, err)
	}
	log.Infof("Table %q has %v rows.", tableName, count)
	return count
}
// GetTableApproxRowCount returns the planner's row estimate
// (pg_class.reltuples, cast to bigint) for tableName. A NULL estimate yields
// 0. Query failures are reported via utils.ErrExit.
//
// NOTE(review): on PostgreSQL 14+ reltuples is -1 for a table that has never
// been vacuumed or analyzed, so callers may receive -1 here.
func (pg *PostgreSQL) GetTableApproxRowCount(tableName *sqlname.SourceName) int64 {
	// NullInt64 tolerates a NULL reltuples; its zero Int64 is returned then.
	var approxRowCount sql.NullInt64
	query := fmt.Sprintf("SELECT reltuples::bigint FROM pg_class "+
		"where oid = '%s'::regclass", tableName.Qualified.MinQuoted)
	log.Infof("Querying '%s' approx row count of table %q", query, tableName.String())
	err := pg.db.QueryRow(context.Background(), query).Scan(&approxRowCount)
	if err != nil {
		utils.ErrExit("Failed to query %q for approx row count of %q: %s", query, tableName.String(), err)
	}
	// Log the number itself; %v of the NullInt64 struct would print "{n true}".
	log.Infof("Table %q has approx %v rows.", tableName.String(), approxRowCount.Int64)
	return approxRowCount.Int64
}
// GetVersion reads server_version from pg_settings, records it in
// pg.source.DBVersion and returns it. Query failures are reported via
// utils.ErrExit.
func (pg *PostgreSQL) GetVersion() string {
	const query = "SELECT setting from pg_settings where name = 'server_version'"
	var serverVersion string
	if err := pg.db.QueryRow(context.Background(), query).Scan(&serverVersion); err != nil {
		utils.ErrExit("run query %q on source: %s", query, err)
	}
	pg.source.DBVersion = serverVersion
	return serverVersion
}
// checkSchemasExists splits pg.source.Schema on "|", strips surrounding double
// quotes from quoted entries, and verifies that every resulting schema exists
// in the source database. It returns the unquoted schema list. Missing
// schemas and any query/scan/iteration error are reported via utils.ErrExit.
func (pg *PostgreSQL) checkSchemasExists() []string {
	list := strings.Split(pg.source.Schema, "|")
	trimmedList := make([]string, 0, len(list))
	for _, schema := range list {
		if utils.IsQuotedString(schema) {
			schema = strings.Trim(schema, `"`)
		}
		trimmedList = append(trimmedList, schema)
	}
	querySchemaList := "'" + strings.Join(trimmedList, "','") + "'"
	chkSchemaExistsQuery := fmt.Sprintf(`SELECT schema_name
FROM information_schema.schemata where schema_name IN (%s);`, querySchemaList)
	rows, err := pg.db.Query(context.Background(), chkSchemaExistsQuery)
	if err != nil {
		utils.ErrExit("error in querying(%q) source database for checking mentioned schema(s) present or not: %v\n", chkSchemaExistsQuery, err)
	}
	defer rows.Close()

	var listOfSchemaPresent []string
	for rows.Next() {
		var tableSchemaName string
		if err := rows.Scan(&tableSchemaName); err != nil {
			utils.ErrExit("error in scanning query rows for schema names: %v\n", err)
		}
		listOfSchemaPresent = append(listOfSchemaPresent, tableSchemaName)
	}
	// Without this check an iteration failure would be misreported below as
	// "schemas are not present".
	if err := rows.Err(); err != nil {
		utils.ErrExit("error in iterating over query rows for schema names: %v\n", err)
	}
	schemaNotPresent := utils.SetDifference(trimmedList, listOfSchemaPresent)
	if len(schemaNotPresent) > 0 {
		utils.ErrExit("Following schemas are not present in source database %v, please provide a valid schema list.\n", schemaNotPresent)
	}
	return trimmedList
}
// GetAllTableNames lists every base table in the configured schemas, which
// are first validated by checkSchemasExists. Each table name is wrapped in
// double quotes before being passed to sqlname.NewSourceName (presumably so
// it is treated as an exact, case-sensitive identifier — confirm in sqlname).
// Query, scan and iteration errors are reported via utils.ErrExit.
func (pg *PostgreSQL) GetAllTableNames() []*sqlname.SourceName {
	schemaList := pg.checkSchemasExists()
	querySchemaList := "'" + strings.Join(schemaList, "','") + "'"
	query := fmt.Sprintf(`SELECT table_schema, table_name
FROM information_schema.tables
WHERE table_type = 'BASE TABLE' AND
table_schema IN (%s);`, querySchemaList)
	rows, err := pg.db.Query(context.Background(), query)
	if err != nil {
		utils.ErrExit("error in querying(%q) source database for table names: %v\n", query, err)
	}
	defer rows.Close()

	var tableNames []*sqlname.SourceName
	for rows.Next() {
		var tableSchema, tableName string
		if err := rows.Scan(&tableSchema, &tableName); err != nil {
			utils.ErrExit("error in scanning query rows for table names: %v\n", err)
		}
		tableName = fmt.Sprintf("\"%s\"", tableName)
		tableNames = append(tableNames, sqlname.NewSourceName(tableSchema, tableName))
	}
	// A mid-stream failure would otherwise silently truncate the table list.
	if err := rows.Err(); err != nil {
		utils.ErrExit("error in iterating over query rows for table names: %v\n", err)
	}
	log.Infof("Query found %d tables in the source db: %v", len(tableNames), tableNames)
	return tableNames
}
// getConnectionUri returns the URI used to connect to the source database.
// A non-empty source.Uri is returned as-is. Otherwise a postgresql:// URI is
// built from user, password, host, port, database name and SSL settings, and
// stored in source.Uri so that later calls return the same value.
func (pg *PostgreSQL) getConnectionUri() string {
	src := pg.source
	if src.Uri == "" {
		u := url.URL{
			Scheme:   "postgresql",
			User:     url.UserPassword(src.User, src.Password),
			Host:     fmt.Sprintf("%s:%d", src.Host, src.Port),
			Path:     src.DBName,
			RawQuery: generateSSLQueryStringIfNotExists(src),
		}
		src.Uri = u.String()
	}
	return src.Uri
}
// getConnectionUriWithoutPassword builds a postgresql:// URI from the
// individual connection fields carrying only the user name, never the
// password. Unlike getConnectionUri it ignores and does not set source.Uri.
//
// NOTE(review): when source.Uri is set, generateSSLQueryStringIfNotExists
// returns "", so the resulting URI carries no SSL parameters.
func (pg *PostgreSQL) getConnectionUriWithoutPassword() string {
	src := pg.source
	u := url.URL{Scheme: "postgresql", User: url.User(src.User)}
	u.Host = fmt.Sprintf("%s:%d", src.Host, src.Port)
	u.Path = src.DBName
	u.RawQuery = generateSSLQueryStringIfNotExists(src)
	return u.String()
}
// ExportSchema validates the configured schemas (missing ones are reported
// via utils.ErrExit inside checkSchemasExists) and then extracts the schema
// into exportDir with pgdumpExtractSchema, using the password-less URI.
func (pg *PostgreSQL) ExportSchema(exportDir string) {
	pg.checkSchemasExists() // result unused; called for its validation only
	uri := pg.getConnectionUriWithoutPassword()
	pgdumpExtractSchema(pg.source, uri, exportDir)
}
// GetIndexesInfo is not implemented for PostgreSQL and always returns nil.
func (pg *PostgreSQL) GetIndexesInfo() []utils.IndexInfo {
	var none []utils.IndexInfo
	return none
}
// ExportData delegates the offline data export to pgdumpExportDataOffline,
// passing the password-less connection URI. tablesColumnList is not used by
// the PostgreSQL implementation.
func (pg *PostgreSQL) ExportData(ctx context.Context, exportDir string, tableList []*sqlname.SourceName, quitChan chan bool, exportDataStart, exportSuccessChan chan bool, tablesColumnList map[*sqlname.SourceName][]string, snapshotName string) {
	connUri := pg.getConnectionUriWithoutPassword()
	pgdumpExportDataOffline(ctx, pg.source, connUri, exportDir, tableList,
		quitChan, exportDataStart, exportSuccessChan, snapshotName)
}
// ExportDataPostProcessing renames the exported data files and then saves a
// datafile.Descriptor describing them: TEXT format, tab-delimited, no header,
// `\N` as the NULL marker, plus the per-table exported column lists read
// from data/toc.dat via getExportedColumnsMap.
func (pg *PostgreSQL) ExportDataPostProcessing(exportDir string, tablesProgressMetadata map[string]*utils.TableProgressMetadata) {
	renameDataFiles(tablesProgressMetadata)
	dataFiles := getExportedDataFileList(tablesProgressMetadata)
	exportedColumns := pg.getExportedColumnsMap(exportDir, tablesProgressMetadata)
	descriptor := datafile.Descriptor{
		FileFormat:                 datafile.TEXT,
		DataFileList:               dataFiles,
		Delimiter:                  "\t",
		HasHeader:                  false,
		ExportDir:                  exportDir,
		NullString:                 `\N`,
		TableNameToExportedColumns: exportedColumns,
	}
	descriptor.Save()
}
// getExportedColumnsMap returns, for every exported table, the column list
// recorded in the dump's TOC. The map key is the data file's base name with
// the "_data.sql" suffix removed.
func (pg *PostgreSQL) getExportedColumnsMap(
	exportDir string, tablesMetadata map[string]*utils.TableProgressMetadata) map[string][]string {
	columnsByTable := make(map[string][]string, len(tablesMetadata))
	for _, meta := range tablesMetadata {
		// TODO: Use tableMetadata.TableName instead of parsing the file name.
		// We need a new method in sqlname.SourceName that returns MaybeQuoted and MaybeQualified names.
		fileName := filepath.Base(meta.FinalFilePath)
		table := strings.TrimSuffix(fileName, "_data.sql")
		columnsByTable[table] = pg.getExportedColumnsListForTable(exportDir, table)
	}
	return columnsByTable
}
// getExportedColumnsListForTable returns the column list pg_dump recorded for
// tableName in <exportDir>/data/toc.dat. It takes the first line matching
// "COPY <table> (<cols>) FROM STDIN" (case-insensitive) and trims each column
// name. A name without a "." is looked up as "public.<tableName>". Returns nil
// when no line matches; a file read error is reported via utils.ErrExit.
func (pg *PostgreSQL) getExportedColumnsListForTable(exportDir, tableName string) []string {
	qualifiedName := tableName
	if !strings.Contains(tableName, ".") {
		// Happens only when the table is in the public schema.
		qualifiedName = "public." + tableName
	}
	// QuoteMeta makes regex metacharacters in the name ('.', '$', '(' ...)
	// match literally instead of altering the pattern or panicking.
	re := regexp.MustCompile(`(?i)COPY ` + regexp.QuoteMeta(qualifiedName) + `[\s]+\((.*)\) FROM STDIN`)

	var columnsList []string
	tocFilePath := filepath.Join(exportDir, "data", "toc.dat")
	err := utils.ForEachMatchingLineInFile(tocFilePath, re, func(matches []string) bool {
		columnsList = strings.Split(matches[1], ",")
		for i, column := range columnsList {
			columnsList[i] = strings.TrimSpace(column)
		}
		return false // stop reading file
	})
	if err != nil {
		utils.ErrExit("error in reading toc file: %v\n", err)
	}
	log.Infof("columns list for table %s: %v", tableName, columnsList)
	return columnsList
}
// GetAbsPathOfPGCommand returns the path of an executable named cmd (e.g.
// "pg_dump", "pg_restore") whose `--version` output reports a version >=
// PG_COMMAND_VERSION. Candidates are tried in the order returned by
// findAllExecutablesInPath and the first acceptable one wins.
//
// An error is returned if lookup fails, no candidate exists, a candidate's
// `--version` cannot be run or parsed, or no candidate is new enough.
func GetAbsPathOfPGCommand(cmd string) (string, error) {
	paths, err := findAllExecutablesInPath(cmd)
	if err != nil {
		return "", fmt.Errorf("error in finding executables: %w", err)
	}
	if len(paths) == 0 {
		return "", fmt.Errorf("the command %v is not installed", cmd)
	}
	for _, path := range paths {
		// Use a distinct name so the error below reports the command name,
		// not the *exec.Cmd value.
		stdout, err := exec.Command(path, "--version").Output()
		if err != nil {
			return "", fmt.Errorf("error in finding version of %v from path %v: %w", cmd, path, err)
		}
		// example output centos: pg_restore (PostgreSQL) 14.5
		// example output Ubuntu: pg_dump (PostgreSQL) 14.5 (Ubuntu 14.5-1.pgdg22.04+1)
		fields := strings.Fields(string(stdout))
		if len(fields) < 3 {
			// Guard the [2] index below against unexpected output.
			return "", fmt.Errorf("unexpected version output %q of %v from path %v", string(stdout), cmd, path)
		}
		if version.CompareSimple(fields[2], PG_COMMAND_VERSION) >= 0 {
			return path, nil
		}
	}
	return "", fmt.Errorf("could not find %v with version greater than or equal to %v", cmd, PG_COMMAND_VERSION)
}
// GetAllSequences returns all sequences in the configured schemas (validated
// by checkSchemasExists) as `schema."sequence"` strings; the sequence part is
// quoted to preserve case. Query, scan and iteration errors are reported via
// utils.ErrExit.
func (pg *PostgreSQL) GetAllSequences() []string {
	schemaList := pg.checkSchemasExists()
	querySchemaList := "'" + strings.Join(schemaList, "','") + "'"
	query := fmt.Sprintf(`SELECT sequence_schema, sequence_name FROM information_schema.sequences where sequence_schema IN (%s);`, querySchemaList)
	rows, err := pg.db.Query(context.Background(), query)
	if err != nil {
		utils.ErrExit("error in querying(%q) source database for sequence names: %v\n", query, err)
	}
	defer rows.Close()

	var sequenceNames []string
	for rows.Next() {
		var sequenceSchema, sequenceName string
		if err := rows.Scan(&sequenceSchema, &sequenceName); err != nil {
			utils.ErrExit("error in scanning query rows for sequence names: %v\n", err)
		}
		sequenceNames = append(sequenceNames, fmt.Sprintf(`%s."%s"`, sequenceSchema, sequenceName))
	}
	// A mid-stream failure would otherwise silently truncate the list.
	if err := rows.Err(); err != nil {
		utils.ErrExit("error in iterating over query rows for sequence names: %v\n", err)
	}
	return sequenceNames
}
// GetCharset returns the server encoding name of the database named by
// pg.source.DBName, as reported by pg_encoding_to_char. Query errors
// (including "no rows" for an unknown database) are returned wrapped.
func (pg *PostgreSQL) GetCharset() (string, error) {
	// Bind the database name instead of splicing it into the SQL text, so
	// names containing quotes cannot break (or alter) the query.
	const query = "SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = $1;"
	var encoding string
	if err := pg.db.QueryRow(context.Background(), query, pg.source.DBName).Scan(&encoding); err != nil {
		return "", fmt.Errorf("error in querying database encoding: %w", err)
	}
	return encoding, nil
}
// FilterUnsupportedTables does no filtering for PostgreSQL: tableList is
// returned unchanged as the first result and the second result is nil.
// useDebezium is ignored.
func (pg *PostgreSQL) FilterUnsupportedTables(tableList []*sqlname.SourceName, useDebezium bool) ([]*sqlname.SourceName, []*sqlname.SourceName) {
	kept := tableList
	return kept, nil
}
// FilterEmptyTables partitions tableList into (non-empty, empty) tables,
// preserving input order within each list. Emptiness is decided by fetching
// at most one row, which avoids a full count. Query errors other than "no
// rows" are reported via utils.ErrExit.
func (pg *PostgreSQL) FilterEmptyTables(tableList []*sqlname.SourceName) ([]*sqlname.SourceName, []*sqlname.SourceName) {
	var nonEmptyTableList, emptyTableList []*sqlname.SourceName
	for _, table := range tableList {
		query := fmt.Sprintf(`SELECT false FROM %s LIMIT 1;`, table.Qualified.MinQuoted)
		var empty bool
		err := pg.db.QueryRow(context.Background(), query).Scan(&empty)
		switch {
		case err == pgx.ErrNoRows:
			empty = true
		case err != nil:
			utils.ErrExit("error in querying table %v: %v", table, err)
		}
		if empty {
			emptyTableList = append(emptyTableList, table)
		} else {
			nonEmptyTableList = append(nonEmptyTableList, table)
		}
	}
	return nonEmptyTableList, emptyTableList
}
// GetTableColumns is not implemented for PostgreSQL; it always returns three
// nil slices.
func (pg *PostgreSQL) GetTableColumns(tableName *sqlname.SourceName) ([]string, []string, []string) {
	var first, second, third []string
	return first, second, third
}
// GetColumnsWithSupportedTypes is not implemented for PostgreSQL; it ignores
// its arguments and returns a nil map and a nil slice.
func (pg *PostgreSQL) GetColumnsWithSupportedTypes(tableList []*sqlname.SourceName, useDebezium bool, _ bool) (map[*sqlname.SourceName][]string, []string) {
	var (
		byTable map[*sqlname.SourceName][]string
		rest    []string
	)
	return byTable, rest
}
// ParentTableOfPartition returns the parent of table as rendered by a
// ::regclass cast, or "" when table has no parent in pg_inherits. Any other
// query error is reported via utils.ErrExit.
func (pg *PostgreSQL) ParentTableOfPartition(table *sqlname.SourceName) string {
	// MinQuoted keeps case-sensitive names intact for the ::regclass cast.
	query := fmt.Sprintf(`SELECT inhparent::pg_catalog.regclass
FROM pg_catalog.pg_class c JOIN pg_catalog.pg_inherits ON c.oid = inhrelid
WHERE c.oid = '%s'::regclass::oid`, table.Qualified.MinQuoted)
	var parentTable string
	err := pg.db.QueryRow(context.Background(), query).Scan(&parentTable)
	if err != nil && err != pgx.ErrNoRows {
		utils.ErrExit("Error in query=%s for parent tablename of table=%s: %v", query, table, err)
	}
	return parentTable
}
// GetColumnToSequenceMap maps "<qualified table>.<column>" (unquoted) to the
// `schema."sequence"` that the column's default draws from, for every table
// in tableList, using FETCH_COLUMN_SEQUENCES_QUERY_TEMPLATE. The sequence part
// is quoted because it may be case-sensitive; the import side uses it when
// restoring sequences. Query, scan and iteration errors are reported via
// utils.ErrExit.
//
// NOTE(review): an earlier comment claimed identity columns are covered too;
// the template joins through pg_attrdef, so confirm that against a real
// identity column.
func (pg *PostgreSQL) GetColumnToSequenceMap(tableList []*sqlname.SourceName) map[string]string {
	// Double single quotes so names can sit inside the template's '...' literals.
	escapeLiteral := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
	columnToSequenceMap := make(map[string]string)
	for _, table := range tableList {
		query := fmt.Sprintf(FETCH_COLUMN_SEQUENCES_QUERY_TEMPLATE,
			escapeLiteral(table.SchemaName.Unquoted), escapeLiteral(table.ObjectName.Unquoted))
		rows, err := pg.db.Query(context.Background(), query)
		if err != nil {
			log.Infof("Query to find column to sequence mapping: %s", query)
			utils.ErrExit("Error in querying for sequences in table=%s: %v", table, err)
		}
		for rows.Next() {
			var columnName, sequenceName, schemaName string
			if err := rows.Scan(&columnName, &sequenceName, &schemaName); err != nil {
				utils.ErrExit("Error in scanning for sequences in table=%s: %v", table, err)
			}
			qualifiedColumnName := fmt.Sprintf("%s.%s", table.Qualified.Unquoted, columnName)
			columnToSequenceMap[qualifiedColumnName] = fmt.Sprintf(`%s."%s"`, schemaName, sequenceName)
		}
		// Close per iteration (a defer would only fire after the whole loop)
		// and surface any error that ended the iteration early.
		rows.Close()
		if err := rows.Err(); err != nil {
			utils.ErrExit("Error in iterating over sequences in table=%s: %v", table, err)
		}
	}
	return columnToSequenceMap
}
// generateSSLQueryStringIfNotExists returns the URI query string carrying the
// SSL settings of s. It returns "" when s.Uri is set.
//
// A non-empty s.SSLQueryString is returned verbatim. Otherwise s.SSLMode must
// be one of disable, allow, prefer, require, verify-ca or verify-full; an
// invalid mode is reported via utils.ErrExit. For require/verify-ca/verify-full
// the non-empty sslcert, sslkey, sslrootcert and sslcrl paths are appended,
// query-escaped so paths containing spaces, '&', '#', etc. cannot corrupt the
// URI (URI consumers percent-decode these values).
func generateSSLQueryStringIfNotExists(s *Source) string {
	if s.Uri != "" {
		return ""
	}
	if s.SSLQueryString != "" {
		return s.SSLQueryString
	}
	switch s.SSLMode {
	case "disable", "allow", "prefer":
		return "sslmode=" + s.SSLMode
	case "require", "verify-ca", "verify-full":
		var b strings.Builder
		b.WriteString("sslmode=" + s.SSLMode)
		for _, param := range []struct{ key, value string }{
			{"sslcert", s.SSLCertPath},
			{"sslkey", s.SSLKey},
			{"sslrootcert", s.SSLRootCert},
			{"sslcrl", s.SSLCRL},
		} {
			if param.value != "" {
				b.WriteString("&" + param.key + "=" + url.QueryEscape(param.value))
			}
		}
		return b.String()
	default:
		utils.ErrExit("Invalid sslmode: %q", s.SSLMode)
		return ""
	}
}
// GetServers returns a one-element list holding the configured source host.
func (pg *PostgreSQL) GetServers() []string {
	host := pg.source.Host
	return []string{host}
}
// GetPartitions returns the direct children of tableName recorded in
// pg_inherits, as schema-qualified SourceNames. The result is never nil.
// Query, scan and iteration failures are reported via utils.ErrExit.
func (pg *PostgreSQL) GetPartitions(tableName *sqlname.SourceName) []*sqlname.SourceName {
	// Double single quotes so names can sit inside the SQL string literals.
	escapeLiteral := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
	partitions := make([]*sqlname.SourceName, 0)
	query := fmt.Sprintf(`SELECT
nmsp_child.nspname AS child_schema,
child.relname AS child
FROM pg_inherits
JOIN pg_class parent ON pg_inherits.inhparent = parent.oid
JOIN pg_class child ON pg_inherits.inhrelid = child.oid
JOIN pg_namespace nmsp_parent ON nmsp_parent.oid = parent.relnamespace
JOIN pg_namespace nmsp_child ON nmsp_child.oid = child.relnamespace
WHERE parent.relname='%s' AND nmsp_parent.nspname = '%s' `,
		escapeLiteral(tableName.ObjectName.Unquoted), escapeLiteral(tableName.SchemaName.Unquoted))
	rows, err := pg.db.Query(context.Background(), query)
	if err != nil {
		log.Errorf("failed to list partitions of table %s: query = [ %s ], error = %s", tableName, query, err)
		// The format previously had no verb for err, producing %!(EXTRA ...).
		utils.ErrExit("failed to find the partitions for table %s: %v", tableName, err)
	}
	defer rows.Close()
	for rows.Next() {
		var childSchema, childTable string
		if err := rows.Scan(&childSchema, &childTable); err != nil {
			utils.ErrExit("Error in scanning for child partitions of table=%s: %v", tableName, err)
		}
		partitions = append(partitions, sqlname.NewSourceName(childSchema, childTable))
	}
	if err := rows.Err(); err != nil {
		utils.ErrExit("Error in scanning for child partitions of table=%s: %v", tableName, err)
	}
	return partitions
}
// GetTableToUniqueKeyColumnsMap maps each table in tableList to its unique-key
// column names, using the ybQueryTmplForUniqCols template filled with the
// de-duplicated schema names and the table names (each comma-joined). Map keys
// are the bare table name for tables in "public" and "schema.table" otherwise.
// Query, scan and iteration errors are returned wrapped.
func (pg *PostgreSQL) GetTableToUniqueKeyColumnsMap(tableList []*sqlname.SourceName) (map[string][]string, error) {
	log.Infof("getting unique key columns for tables: %v", tableList)
	var schemas, tables []string
	for _, t := range tableList {
		schemas = append(schemas, t.SchemaName.Unquoted)
		tables = append(tables, t.ObjectName.Unquoted)
	}
	query := fmt.Sprintf(ybQueryTmplForUniqCols, strings.Join(lo.Uniq(schemas), ","), strings.Join(tables, ","))
	log.Infof("query to get unique key columns: %s", query)
	rows, err := pg.db.Query(context.Background(), query)
	if err != nil {
		return nil, fmt.Errorf("querying unique key columns: %w", err)
	}
	defer rows.Close()

	result := make(map[string][]string)
	for rows.Next() {
		var schemaName, tableName, colName string
		if err := rows.Scan(&schemaName, &tableName, &colName); err != nil {
			return nil, fmt.Errorf("scanning row for unique key column name: %w", err)
		}
		key := tableName
		if schemaName != "public" {
			key = schemaName + "." + tableName
		}
		result[key] = append(result[key], colName)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating over rows for unique key columns: %w", err)
	}
	log.Infof("unique key columns for tables: %v", result)
	return result, nil
}
// ClearMigrationState is not implemented for PostgreSQL: it only logs that
// fact and reports success. Both arguments are ignored.
func (pg *PostgreSQL) ClearMigrationState(migrationUUID uuid.UUID, exportDir string) error {
	const notImplemented = "ClearMigrationState not implemented yet for PostgreSQL"
	log.Infof(notImplemented)
	return nil
}
// GetReplicationConnection opens a new low-level connection to the source
// database with replication=database added to the connection string, as
// used by DropLogicalReplicationSlot. The caller must close it.
func (pg *PostgreSQL) GetReplicationConnection() (*pgconn.PgConn, error) {
	uri := pg.getConnectionUri()
	// Pick the right separator. Blindly appending "&..." to a URI without a
	// query string would be absorbed into the database name, and
	// keyword/value connection strings use space-separated pairs.
	switch {
	case !strings.Contains(uri, "://"):
		uri += " replication=database"
	case strings.Contains(uri, "?"):
		uri += "&replication=database"
	default:
		uri += "?replication=database"
	}
	return pgconn.Connect(context.Background(), uri)
}
// CreateLogicalReplicationSlot creates a logical replication slot named
// replicationSlotName using the "pgoutput" plugin on conn. When
// dropIfAlreadyExists is true, any existing slot of that name is dropped
// first. Errors are wrapped with %w so callers can inspect the cause.
func (pg *PostgreSQL) CreateLogicalReplicationSlot(conn *pgconn.PgConn, replicationSlotName string, dropIfAlreadyExists bool) (*pglogrepl.CreateReplicationSlotResult, error) {
	if dropIfAlreadyExists {
		log.Infof("dropping replication slot %s if already exists", replicationSlotName)
		if err := pg.DropLogicalReplicationSlot(conn, replicationSlotName); err != nil {
			return nil, err
		}
	}

	log.Infof("creating replication slot %s", replicationSlotName)
	res, err := pglogrepl.CreateReplicationSlot(context.Background(), conn, replicationSlotName, "pgoutput",
		pglogrepl.CreateReplicationSlotOptions{Mode: pglogrepl.LogicalReplication})
	if err != nil {
		return nil, fmt.Errorf("create replication slot: %w", err)
	}
	return &res, nil
}
// DropLogicalReplicationSlot drops replicationSlotName. If conn is nil a
// temporary replication connection is opened (failure to open it is reported
// via utils.ErrExit) and closed on return. A "does not exist" error from the
// server is ignored; any other error is returned wrapped with %w.
func (pg *PostgreSQL) DropLogicalReplicationSlot(conn *pgconn.PgConn, replicationSlotName string) error {
	var err error
	if conn == nil {
		conn, err = pg.GetReplicationConnection()
		if err != nil {
			utils.ErrExit("failed to create replication connection for dropping replication slot: %s", err)
		}
		defer conn.Close(context.Background())
	}
	log.Infof("dropping replication slot: %s", replicationSlotName)
	err = pglogrepl.DropReplicationSlot(context.Background(), conn, replicationSlotName, pglogrepl.DropReplicationSlotOptions{})
	// Ignore "does not exist" so dropping is safe to call unconditionally.
	if err != nil && !strings.Contains(err.Error(), "does not exist") {
		return fmt.Errorf("delete existing replication slot(%s): %w", replicationSlotName, err)
	}
	return nil
}
// CreatePublication creates publicationName on conn covering every table in
// tableList (fully qualified, quoted names). When dropIfAlreadyExists is true
// an existing publication of that name is dropped first via DropPublication.
// publicationName is spliced into the statement as-is, so it must already be
// a valid identifier. Errors are returned wrapped with %w.
func (pg *PostgreSQL) CreatePublication(conn *pgconn.PgConn, publicationName string, tableList []*sqlname.SourceName, dropIfAlreadyExists bool) error {
	if dropIfAlreadyExists {
		if err := pg.DropPublication(publicationName); err != nil {
			return fmt.Errorf("drop publication: %w", err)
		}
	}
	if len(tableList) == 0 {
		// "FOR TABLE" with an empty list is a syntax error; fail clearly instead.
		return fmt.Errorf("create publication %s: no tables given", publicationName)
	}
	tablelistQualifiedQuoted := make([]string, 0, len(tableList))
	for _, tableName := range tableList {
		tablelistQualifiedQuoted = append(tablelistQualifiedQuoted, tableName.Qualified.Quoted)
	}
	stmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s;", publicationName, strings.Join(tablelistQualifiedQuoted, ","))
	if _, err := conn.Exec(context.Background(), stmt).ReadAll(); err != nil {
		// Arguments were previously swapped (error printed as the statement).
		return fmt.Errorf("create publication with stmt %s: %w", stmt, err)
	}
	log.Infof("created publication with stmt %s", stmt)
	return nil
}
// DropPublication runs DROP PUBLICATION IF EXISTS for publicationName on a
// dedicated connection that is closed before returning. A connection failure
// is reported via utils.ErrExit; a failure of the DROP itself is returned.
func (pg *PostgreSQL) DropPublication(publicationName string) error {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, pg.getConnectionUri())
	if err != nil {
		utils.ErrExit("failed to connect to the source database for dropping publication: %s", err)
	}
	defer conn.Close(ctx)

	log.Infof("dropping publication: %s", publicationName)
	stmt := fmt.Sprintf("DROP PUBLICATION IF EXISTS %s", publicationName)
	res, err := conn.Exec(ctx, stmt)
	log.Infof("drop publication result: %v", res)
	if err == nil {
		return nil
	}
	return fmt.Errorf("drop publication(%s): %v", publicationName, err)
}
// PG_QUERY_TO_CHECK_IF_TABLE_HAS_PK counts primary-key constraints for every
// ordinary ('r') and partitioned ('p') table in the schemas substituted for
// %s (a comma-separated list of single-quoted schema names). The relkind
// filter keeps indexes, sequences, views and other pg_class entries from
// being reported as tables without a primary key.
var PG_QUERY_TO_CHECK_IF_TABLE_HAS_PK = `SELECT nspname AS schema_name, relname AS table_name, COUNT(conname) AS pk_count
FROM pg_class c
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
LEFT JOIN pg_constraint con ON con.conrelid = c.oid AND con.contype = 'p'
WHERE c.relkind IN ('r', 'p') AND n.nspname IN (%s)
GROUP BY n.nspname, c.relname;`

// GetNonPKTables returns the MinQuoted qualified names of tables in the
// configured schemas ("|"-separated pg.source.Schema) that have no primary
// key. Schema entries wrapped in double quotes are unquoted first, the same
// way checkSchemasExists does. Query, scan and iteration errors are returned
// wrapped.
func (pg *PostgreSQL) GetNonPKTables() ([]string, error) {
	var schemaList []string
	for _, schema := range strings.Split(pg.source.Schema, "|") {
		if utils.IsQuotedString(schema) {
			schema = strings.Trim(schema, `"`)
		}
		schemaList = append(schemaList, schema)
	}
	querySchemaList := "'" + strings.Join(schemaList, "','") + "'"
	query := fmt.Sprintf(PG_QUERY_TO_CHECK_IF_TABLE_HAS_PK, querySchemaList)
	rows, err := pg.db.Query(context.Background(), query)
	if err != nil {
		return nil, fmt.Errorf("error in querying(%q) source database for primary key: %w", query, err)
	}
	defer rows.Close()

	var nonPKTables []string
	for rows.Next() {
		var schemaName, tableName string
		var pkCount int
		if err := rows.Scan(&schemaName, &tableName, &pkCount); err != nil {
			return nil, fmt.Errorf("error in scanning query rows for primary key: %w", err)
		}
		if pkCount == 0 {
			table := sqlname.NewSourceName(schemaName, fmt.Sprintf(`"%s"`, tableName))
			nonPKTables = append(nonPKTables, table.Qualified.MinQuoted)
		}
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error in iterating over query rows for primary key: %w", err)
	}
	return nonPKTables, nil
}