Skip to content
This repository has been archived by the owner on Apr 2, 2024. It is now read-only.

Commit

Permalink
Fix DB URI issues for URIs with missing ? parameter settings separator
Browse files Browse the repository at this point in the history
  • Loading branch information
antekresic committed Jun 2, 2021
1 parent d30b6ac commit ebe7ad8
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 48 deletions.
28 changes: 8 additions & 20 deletions pkg/pgclient/client.go
Expand Up @@ -69,38 +69,26 @@ func NewClient(cfg *Config, mt tenancy.Authorizer, schemaLocker LockFunc, readOn
}

func getPgConfig(cfg *Config) (*pgxpool.Config, int, error) {
connectionStr, err := cfg.GetConnectionStr()
if err != nil {
return nil, 0, err
}

minConnections, maxConnections, numCopiers, err := cfg.GetNumConnections()
if err != nil {
log.Error("msg", "configuring number of connections", "err", util.MaskPassword(err.Error()))
return nil, numCopiers, err
}

var (
pgConfig *pgxpool.Config
connectionArgsFmt string
)
if cfg.DbUri == defaultDBUri {
connectionArgsFmt = "%s pool_max_conns=%d pool_min_conns=%d statement_cache_capacity=%d"
} else {
connectionArgsFmt = "%s&pool_max_conns=%d&pool_min_conns=%d&statement_cache_capacity=%d"
}

// Using the PGX default of 512 for statement cache capacity.
statementCacheCapacity := 512
connectionStringWithArgs := fmt.Sprintf(connectionArgsFmt, connectionStr, maxConnections, minConnections, statementCacheCapacity)
pgConfig, err = pgxpool.ParseConfig(connectionStringWithArgs)
connectionStr := cfg.GetConnectionStr()
pgConfig, err := pgxpool.ParseConfig(connectionStr)
if err != nil {
log.Error("msg", "configuring connection", "err", util.MaskPassword(err.Error()))
return nil, numCopiers, err
}

// Configure the number of connections and statement cache capacity.
pgConfig.MinConns = int32(minConnections)
pgConfig.MaxConns = int32(maxConnections)

var statementCacheLog string
if cfg.EnableStatementsCache {
// Using the PGX default of 512 for statement cache capacity.
statementCacheCapacity := 512
pgConfig.AfterRelease = observeStatementCacheState
statementCacheEnabled.Set(1)
statementCacheCap.Set(float64(statementCacheCapacity))
Expand Down
53 changes: 38 additions & 15 deletions pkg/pgclient/config.go
Expand Up @@ -89,24 +89,50 @@ func ParseFlags(fs *flag.FlagSet, cfg *Config) *Config {
}

func Validate(cfg *Config, lcfg limits.Config) error {
if err := cfg.validateConnectionSettings(); err != nil {
return err
}
return cache.Validate(&cfg.CacheConfig, lcfg)
}

// GetConnectionStr returns a Postgres connection string
func (cfg *Config) GetConnectionStr() (string, error) {
// if DBURI is default build the connStr with DB flags
// else as DBURI isn't default check if db flags are default if we notice DBURI + DB flags not default give an error
// Now as DBURI isn't default and DB flags are default build a connStr for DBURI.
// validateConnectionSettings checks that we are not using both a DB URI and
// DB configuration flags
func (cfg Config) validateConnectionSettings() error {
// If we are using DB URI, nothing to check.
if cfg.DbUri == defaultDBUri {
return fmt.Sprintf("application_name=%s host=%v port=%v user=%v dbname=%v password='%v' sslmode=%v connect_timeout=%d",
cfg.AppName, cfg.Host, cfg.Port, cfg.User, cfg.Database, cfg.Password, cfg.SslMode, int(cfg.DbConnectionTimeout.Seconds())), nil
} else if cfg.AppName != DefaultApp || cfg.Database != defaultDBName || cfg.Host != defaultDBHost || cfg.Port != defaultDBPort ||
cfg.User != defaultDBUser || cfg.Password != defaultDBPassword || cfg.SslMode != defaultSSLMode ||
return nil
}

// If using DB URI, check if any DB flags are supplied.
if cfg.AppName != DefaultApp ||
cfg.Database != defaultDBName ||
cfg.Host != defaultDBHost ||
cfg.Port != defaultDBPort ||
cfg.User != defaultDBUser ||
cfg.Password != defaultDBPassword ||
cfg.SslMode != defaultSSLMode ||
cfg.DbConnectionTimeout != defaultConnectionTime {
return "", excessDBFlagsError
return excessDBFlagsError
}

return cfg.DbUri, nil
return nil
}

// GetConnectionStr returns a Postgres connection string
func (cfg *Config) GetConnectionStr() string {
// If DB URI is not supplied, generate one from DB flags.
if cfg.DbUri == defaultDBUri {
return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?application_name=%s&sslmode=%v&connect_timeout=%d",
cfg.User,
cfg.Password,
cfg.Host,
cfg.Port,
cfg.Database,
cfg.AppName,
cfg.SslMode,
int(cfg.DbConnectionTimeout.Seconds()))
}
return cfg.DbUri
}

func (cfg *Config) GetNumConnections() (min int, max int, numCopiers int, err error) {
Expand All @@ -117,10 +143,7 @@ func (cfg *Config) GetNumConnections() (min int, max int, numCopiers int, err er
perProc := cfg.WriteConnectionsPerProc
max = cfg.MaxConnections
if max < 1 {
connStr, err := cfg.GetConnectionStr()
if err != nil {
return 0, 0, 0, err
}
connStr := cfg.GetConnectionStr()
conn, err := pgx.Connect(context.Background(), connStr)
if err != nil {
return 0, 0, 0, err
Expand Down
40 changes: 35 additions & 5 deletions pkg/pgclient/config_test.go
Expand Up @@ -58,7 +58,7 @@ func TestConfig_GetConnectionStr(t *testing.T) {
UsesHA: false,
DbUri: "",
},
want: fmt.Sprintf("application_name=%s host=localhost port=5433 user=postgres dbname=timescale1 password='Timescale123' sslmode=require connect_timeout=120", DefaultApp),
want: fmt.Sprintf("postgresql://postgres:Timescale123@localhost:5433/timescale1?application_name=%s&sslmode=require&connect_timeout=120", DefaultApp),
wantErr: false,
err: nil,
},
Expand Down Expand Up @@ -132,7 +132,7 @@ func TestConfig_GetConnectionStr(t *testing.T) {
UsesHA: false,
DbUri: "",
},
want: fmt.Sprintf("application_name=%s host=localhost port=5432 user=postgres dbname=timescale password='' sslmode=require connect_timeout=3600", DefaultApp),
want: fmt.Sprintf("postgresql://postgres:@localhost:5432/timescale?application_name=%s&sslmode=require&connect_timeout=3600", DefaultApp),
wantErr: false,
err: nil,
},
Expand Down Expand Up @@ -162,6 +162,32 @@ func TestConfig_GetConnectionStr(t *testing.T) {
wantErr: false,
err: nil,
},
{
name: "Testcase with db-uri that has no question mark in it",
fields: fields{
App: DefaultApp,
Host: "localhost",
Port: 5432,
User: "postgres",
Password: "",
Database: "timescale",
SslMode: "require",
DbConnectRetries: 0,
DbConnectionTimeout: defaultConnectionTime,
AsyncAcks: false,
ReportInterval: 0,
LabelsCacheSize: 0,
MetricsCacheSize: 0,
SeriesCacheSize: 0,
WriteConnectionsPerProc: 1,
MaxConnections: 0,
UsesHA: false,
DbUri: "postgres://postgres:password@localhost:5432/postgres",
},
want: "postgres://postgres:password@localhost:5432/postgres",
wantErr: false,
err: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -187,13 +213,17 @@ func TestConfig_GetConnectionStr(t *testing.T) {
UsesHA: tt.fields.UsesHA,
DbUri: tt.fields.DbUri,
}
got, err := cfg.GetConnectionStr()
err := cfg.validateConnectionSettings()
if (err != nil) != tt.wantErr || err != tt.err {
t.Errorf("GetConnectionStr() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("validateConnectionSettings() error = %v, wantErr %v", err, tt.wantErr)
}

if tt.wantErr {
return
}
got := cfg.GetConnectionStr()
if got != tt.want {
t.Errorf("GetConnectionStr() got = %v, want %v", got, tt.want)
t.Errorf("GetConnectionStr() \ngot %v \nwant %v", got, tt.want)
}
})
}
Expand Down
10 changes: 2 additions & 8 deletions pkg/runner/client.go
Expand Up @@ -35,10 +35,7 @@ func CreateClient(cfg *Config, promMetrics *api.Metrics) (*pgclient.Client, erro
// that upgrading TimescaleDB will not break existing connectors.
// (upgrading the DB will force-close all existing connections, so we may
// add a reconnect check that the DB has an appropriate version)
connStr, err := cfg.PgmodelCfg.GetConnectionStr()
if err != nil {
return nil, err
}
connStr := cfg.PgmodelCfg.GetConnectionStr()
extOptions := extension.ExtensionMigrateOptions{
Install: cfg.InstallExtensions,
Upgrade: cfg.UpgradeExtensions,
Expand Down Expand Up @@ -213,10 +210,7 @@ func initElector(cfg *Config, metrics *api.Metrics) (*util.Elector, error) {
return nil, fmt.Errorf("Prometheus timeout configuration must be set when using PG advisory lock")
}

connStr, err := cfg.PgmodelCfg.GetConnectionStr()
if err != nil {
return nil, err
}
connStr := cfg.PgmodelCfg.GetConnectionStr()
lock, err := util.NewPgLeaderLock(cfg.HaGroupLockID, connStr, getSchemaLease)
if err != nil {
return nil, fmt.Errorf("Error creating advisory lock\nhaGroupLockId: %d\nerr: %s\n", cfg.HaGroupLockID, err)
Expand Down

0 comments on commit ebe7ad8

Please sign in to comment.