diff --git a/pkg/pgclient/client.go b/pkg/pgclient/client.go index 5ca4a3fb14..f7e4dca033 100644 --- a/pkg/pgclient/client.go +++ b/pkg/pgclient/client.go @@ -69,38 +69,26 @@ func NewClient(cfg *Config, mt tenancy.Authorizer, schemaLocker LockFunc, readOn } func getPgConfig(cfg *Config) (*pgxpool.Config, int, error) { - connectionStr, err := cfg.GetConnectionStr() - if err != nil { - return nil, 0, err - } - minConnections, maxConnections, numCopiers, err := cfg.GetNumConnections() if err != nil { log.Error("msg", "configuring number of connections", "err", util.MaskPassword(err.Error())) return nil, numCopiers, err } - - var ( - pgConfig *pgxpool.Config - connectionArgsFmt string - ) - if cfg.DbUri == defaultDBUri { - connectionArgsFmt = "%s pool_max_conns=%d pool_min_conns=%d statement_cache_capacity=%d" - } else { - connectionArgsFmt = "%s&pool_max_conns=%d&pool_min_conns=%d&statement_cache_capacity=%d" - } - - // Using the PGX default of 512 for statement cache capacity. - statementCacheCapacity := 512 - connectionStringWithArgs := fmt.Sprintf(connectionArgsFmt, connectionStr, maxConnections, minConnections, statementCacheCapacity) - pgConfig, err = pgxpool.ParseConfig(connectionStringWithArgs) + connectionStr := cfg.GetConnectionStr() + pgConfig, err := pgxpool.ParseConfig(connectionStr) if err != nil { log.Error("msg", "configuring connection", "err", util.MaskPassword(err.Error())) return nil, numCopiers, err } + // Configure the number of connections and statement cache capacity. + pgConfig.MinConns = int32(minConnections) + pgConfig.MaxConns = int32(maxConnections) + var statementCacheLog string if cfg.EnableStatementsCache { + // Using the PGX default of 512 for statement cache capacity. + statementCacheCapacity := 512 pgConfig.AfterRelease = observeStatementCacheState statementCacheEnabled.Set(1) statementCacheCap.Set(float64(statementCacheCapacity)) diff --git a/pkg/pgclient/config.go b/pkg/pgclient/config.go index 75ac71e404..f0a155ac4d 100644 --- a/pkg/pgclient/config.go +++ b/pkg/pgclient/config.go @@ -89,24 +89,50 @@ func ParseFlags(fs *flag.FlagSet, cfg *Config) *Config { } func Validate(cfg *Config, lcfg limits.Config) error { + if err := cfg.validateConnectionSettings(); err != nil { + return err + } return cache.Validate(&cfg.CacheConfig, lcfg) } -// GetConnectionStr returns a Postgres connection string -func (cfg *Config) GetConnectionStr() (string, error) { - // if DBURI is default build the connStr with DB flags - // else as DBURI isn't default check if db flags are default if we notice DBURI + DB flags not default give an error - // Now as DBURI isn't default and DB flags are default build a connStr for DBURI. +// validateConnectionSettings checks that we are not using both a DB URI and +// DB configuration flags +func (cfg Config) validateConnectionSettings() error { + // If we are using DB URI, nothing to check. if cfg.DbUri == defaultDBUri { - return fmt.Sprintf("application_name=%s host=%v port=%v user=%v dbname=%v password='%v' sslmode=%v connect_timeout=%d", - cfg.AppName, cfg.Host, cfg.Port, cfg.User, cfg.Database, cfg.Password, cfg.SslMode, int(cfg.DbConnectionTimeout.Seconds())), nil - } else if cfg.AppName != DefaultApp || cfg.Database != defaultDBName || cfg.Host != defaultDBHost || cfg.Port != defaultDBPort || - cfg.User != defaultDBUser || cfg.Password != defaultDBPassword || cfg.SslMode != defaultSSLMode || + return nil + } + + // If using DB URI, check if any DB flags are supplied. + if cfg.AppName != DefaultApp || + cfg.Database != defaultDBName || + cfg.Host != defaultDBHost || + cfg.Port != defaultDBPort || + cfg.User != defaultDBUser || + cfg.Password != defaultDBPassword || + cfg.SslMode != defaultSSLMode || cfg.DbConnectionTimeout != defaultConnectionTime { - return "", excessDBFlagsError + return excessDBFlagsError } - return cfg.DbUri, nil + return nil +} + +// GetConnectionStr returns a Postgres connection string +func (cfg *Config) GetConnectionStr() string { + // If DB URI is not supplied, generate one from DB flags. + if cfg.DbUri == defaultDBUri { + return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?application_name=%s&sslmode=%v&connect_timeout=%d", + cfg.User, + cfg.Password, + cfg.Host, + cfg.Port, + cfg.Database, + cfg.AppName, + cfg.SslMode, + int(cfg.DbConnectionTimeout.Seconds())) + } + return cfg.DbUri } func (cfg *Config) GetNumConnections() (min int, max int, numCopiers int, err error) { @@ -117,10 +143,7 @@ func (cfg *Config) GetNumConnections() (min int, max int, numCopiers int, err er perProc := cfg.WriteConnectionsPerProc max = cfg.MaxConnections if max < 1 { - connStr, err := cfg.GetConnectionStr() - if err != nil { - return 0, 0, 0, err - } + connStr := cfg.GetConnectionStr() conn, err := pgx.Connect(context.Background(), connStr) if err != nil { return 0, 0, 0, err diff --git a/pkg/pgclient/config_test.go b/pkg/pgclient/config_test.go index 9f7d354d71..e9f0fa5cbd 100644 --- a/pkg/pgclient/config_test.go +++ b/pkg/pgclient/config_test.go @@ -58,7 +58,7 @@ func TestConfig_GetConnectionStr(t *testing.T) { UsesHA: false, DbUri: "", }, - want: fmt.Sprintf("application_name=%s host=localhost port=5433 user=postgres dbname=timescale1 password='Timescale123' sslmode=require connect_timeout=120", DefaultApp), + want: fmt.Sprintf("postgresql://postgres:Timescale123@localhost:5433/timescale1?application_name=%s&sslmode=require&connect_timeout=120", DefaultApp), wantErr: false, err: nil, }, @@ -132,7 +132,7 @@ func TestConfig_GetConnectionStr(t *testing.T) { UsesHA: false, DbUri: "", }, - want: fmt.Sprintf("application_name=%s host=localhost port=5432 user=postgres dbname=timescale password='' sslmode=require connect_timeout=3600", DefaultApp), + want: fmt.Sprintf("postgresql://postgres:@localhost:5432/timescale?application_name=%s&sslmode=require&connect_timeout=3600", DefaultApp), wantErr: false, err: nil, }, @@ -162,6 +162,32 @@ func TestConfig_GetConnectionStr(t *testing.T) { wantErr: false, err: nil, }, + { + name: "Testcase with db-uri that has no question mark in it", + fields: fields{ + App: DefaultApp, + Host: "localhost", + Port: 5432, + User: "postgres", + Password: "", + Database: "timescale", + SslMode: "require", + DbConnectRetries: 0, + DbConnectionTimeout: defaultConnectionTime, + AsyncAcks: false, + ReportInterval: 0, + LabelsCacheSize: 0, + MetricsCacheSize: 0, + SeriesCacheSize: 0, + WriteConnectionsPerProc: 1, + MaxConnections: 0, + UsesHA: false, + DbUri: "postgres://postgres:password@localhost:5432/postgres", + }, + want: "postgres://postgres:password@localhost:5432/postgres", + wantErr: false, + err: nil, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -187,13 +213,17 @@ func TestConfig_GetConnectionStr(t *testing.T) { UsesHA: tt.fields.UsesHA, DbUri: tt.fields.DbUri, } - got, err := cfg.GetConnectionStr() + err := cfg.validateConnectionSettings() if (err != nil) != tt.wantErr || err != tt.err { - t.Errorf("GetConnectionStr() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("validateConnectionSettings() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { return } + got := cfg.GetConnectionStr() if got != tt.want { - t.Errorf("GetConnectionStr() got = %v, want %v", got, tt.want) + t.Errorf("GetConnectionStr() \ngot %v \nwant %v", got, tt.want) } }) } diff --git a/pkg/runner/client.go b/pkg/runner/client.go index f4415e6fcd..f8250d3b64 100644 --- a/pkg/runner/client.go +++ b/pkg/runner/client.go @@ -35,10 +35,7 @@ func CreateClient(cfg *Config, promMetrics *api.Metrics) (*pgclient.Client, erro // that upgrading TimescaleDB will not break existing connectors. // (upgrading the DB will force-close all existing connections, so we may // add a reconnect check that the DB has an appropriate version) - connStr, err := cfg.PgmodelCfg.GetConnectionStr() - if err != nil { - return nil, err - } + connStr := cfg.PgmodelCfg.GetConnectionStr() extOptions := extension.ExtensionMigrateOptions{ Install: cfg.InstallExtensions, Upgrade: cfg.UpgradeExtensions, @@ -213,10 +210,7 @@ func initElector(cfg *Config, metrics *api.Metrics) (*util.Elector, error) { return nil, fmt.Errorf("Prometheus timeout configuration must be set when using PG advisory lock") } - connStr, err := cfg.PgmodelCfg.GetConnectionStr() - if err != nil { - return nil, err - } + connStr := cfg.PgmodelCfg.GetConnectionStr() lock, err := util.NewPgLeaderLock(cfg.HaGroupLockID, connStr, getSchemaLease) if err != nil { return nil, fmt.Errorf("Error creating advisory lock\nhaGroupLockId: %d\nerr: %s\n", cfg.HaGroupLockID, err)