-
Notifications
You must be signed in to change notification settings - Fork 263
/
create_connection.go
224 lines (193 loc) · 6.89 KB
/
create_connection.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
package db_local
import (
"context"
"fmt"
"log"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/spf13/viper"
"github.com/turbot/steampipe/pkg/constants"
"github.com/turbot/steampipe/pkg/constants/runtime"
"github.com/turbot/steampipe/pkg/db/db_common"
"github.com/turbot/steampipe/pkg/statushooks"
"github.com/turbot/steampipe/pkg/utils"
"github.com/turbot/steampipe/sperr"
)
// getLocalSteampipeConnectionString builds a space separated "key=value"
// connection string for the locally running Steampipe service.
//
// opts may be nil, in which case defaults are used for everything.
// Empty fields of opts are filled in place:
//   - Username falls back to constants.DatabaseUser
//   - DatabaseName falls back to the database recorded in the service state,
//     and then to "postgres"
//
// The host is the first listen address recorded in the service state and the
// SSL parameters returned by dsnSSLParams are merged in.
//
// An error is returned if the service state cannot be loaded, if the service
// is not running, or if the state does not contain any listen address.
func getLocalSteampipeConnectionString(opts *CreateDbOptions) (string, error) {
	if opts == nil {
		opts = &CreateDbOptions{}
	}
	utils.LogTime("db.createDbClient start")
	defer utils.LogTime("db.createDbClient end")

	// load the db status
	info, err := GetState()
	if err != nil {
		return "", err
	}
	if info == nil {
		return "", fmt.Errorf("steampipe service is not running")
	}
	// guard the info.Listen[0] access below - indexing an empty list would panic
	if len(info.Listen) == 0 {
		return "", fmt.Errorf("steampipe service has no listen address")
	}

	// if no username is passed, use constants.DatabaseUser
	if len(opts.Username) == 0 {
		opts.Username = constants.DatabaseUser
	}
	// if no database name is passed, deduce it from the db status
	if len(opts.DatabaseName) == 0 {
		opts.DatabaseName = info.Database
	}
	// if we still don't have it, fallback to default "postgres"
	if len(opts.DatabaseName) == 0 {
		opts.DatabaseName = "postgres"
	}

	psqlInfoMap := map[string]string{
		// Connect to the database using the first listen address, which is usually localhost
		"host":   info.Listen[0],
		"port":   fmt.Sprintf("%d", info.Port),
		"user":   opts.Username,
		"dbname": opts.DatabaseName,
	}
	psqlInfoMap = utils.MergeMaps(psqlInfoMap, dsnSSLParams())
	log.Println("[TRACE] SQLInfoMap >>>", psqlInfoMap)

	// flatten the map into "key=value" pairs
	// (map iteration order is random, so the order of the pairs is not stable)
	psqlInfo := make([]string, 0, len(psqlInfoMap))
	for k, v := range psqlInfoMap {
		psqlInfo = append(psqlInfo, fmt.Sprintf("%s=%s", k, v))
	}
	log.Println("[TRACE] PSQLInfo >>>", psqlInfo)

	return strings.Join(psqlInfo, " "), nil
}
// CreateDbOptions holds the optional settings used when connecting to the
// local Steampipe database. Both fields may be left empty.
type CreateDbOptions struct {
	// DatabaseName is the name of the database to connect to
	DatabaseName string
	// Username is the name of the database user to connect as
	Username string
}
// CreateLocalDbConnection connects and returns a connection to the given database using
// the provided username
// if the database is not provided (empty), it connects to the default database in the service
// that was created during installation.
// NOTE: no session data callback is used - no session data will be present
//
// The connection is only returned once it has answered a ping; if the ping
// fails the connection is closed before the error is returned.
// The caller owns the returned connection and is responsible for closing it.
func CreateLocalDbConnection(ctx context.Context, opts *CreateDbOptions) (*pgx.Conn, error) {
	utils.LogTime("db.CreateLocalDbConnection start")
	defer utils.LogTime("db.CreateLocalDbConnection end")

	psqlInfo, err := getLocalSteampipeConnectionString(opts)
	if err != nil {
		return nil, err
	}

	connConfig, err := pgx.ParseConfig(psqlInfo)
	if err != nil {
		return nil, err
	}

	// set an app name so that we can track database connections from this Steampipe execution
	// this is used to determine whether the database can safely be closed
	connConfig.Config.RuntimeParams = map[string]string{
		"application_name": runtime.PgClientAppName,
	}

	err = db_common.AddRootCertToConfig(&connConfig.Config, getRootCertLocation())
	if err != nil {
		return nil, err
	}

	conn, err := pgx.ConnectConfig(ctx, connConfig)
	if err != nil {
		return nil, err
	}

	if err := db_common.WaitForConnectionPing(ctx, conn); err != nil {
		// do not leak the connection when the ping fails
		// the ping error is the one worth reporting, so the close error is dropped
		_ = conn.Close(ctx)
		return nil, err
	}

	return conn, nil
}
// CreateConnectionPool creates a connection pool to the local Steampipe
// service, using the database and username given in opts (see
// getLocalSteampipeConnectionString for the defaults applied to empty values).
//
// The pool holds at most maxConnections connections and keeps no minimum
// number of connections open. Connections are dropped after being idle for
// 1 minute and are recycled after 10 minutes.
//
// Before returning, the pool is passed to db_common.WaitForPool with a timeout
// of constants.ArgDatabaseStartTimeout seconds; if that fails, the pool is
// closed and the error is returned.
// The caller owns the returned pool and is responsible for closing it.
func CreateConnectionPool(ctx context.Context, opts *CreateDbOptions, maxConnections int) (*pgxpool.Pool, error) {
	utils.LogTime("db_client.establishConnectionPool start")
	defer utils.LogTime("db_client.establishConnectionPool end")

	psqlInfo, err := getLocalSteampipeConnectionString(opts)
	if err != nil {
		return nil, err
	}

	connConfig, err := pgxpool.ParseConfig(psqlInfo)
	if err != nil {
		return nil, err
	}

	const (
		connMaxIdleTime = 1 * time.Minute
		connMaxLifetime = 10 * time.Minute
	)

	connConfig.MinConns = 0
	connConfig.MaxConns = int32(maxConnections)
	connConfig.MaxConnLifetime = connMaxLifetime
	connConfig.MaxConnIdleTime = connMaxIdleTime

	// set an app name so that we can track database connections from this Steampipe execution
	// this is used to determine whether the database can safely be closed
	connConfig.ConnConfig.Config.RuntimeParams = map[string]string{
		"application_name": runtime.PgClientAppName,
	}

	// this returns connection pool
	// NOTE: the pool is deliberately created with a background context, as in
	// the original code - only the wait below is tied to the caller's context
	dbPool, err := pgxpool.NewWithConfig(context.Background(), connConfig)
	if err != nil {
		return nil, err
	}

	err = db_common.WaitForPool(
		ctx,
		dbPool,
		db_common.WithRetryInterval(constants.DBConnectionRetryBackoff),
		db_common.WithTimeout(time.Duration(viper.GetInt(constants.ArgDatabaseStartTimeout))*time.Second),
	)
	if err != nil {
		// the pool is not handed to the caller, so close it here
		// otherwise the pool and any connection it opened are leaked
		dbPool.Close()
		return nil, err
	}

	return dbPool, nil
}
// createMaintenanceClient connects to the postgres server using the
// maintenance database (postgres) and superuser
// this is used in a couple of places
//  1. During installation to setup the DBMS with foreign_server, extension et.al.
//  2. During service start and stop to query the DBMS for parameters (connected clients, database name etc.)
//
// this is called immediately after the service process is started and hence
// all special handling related to service startup failures SHOULD be handled here
//
// The connection is made to localhost on the given port with SSL disabled.
// Connecting and the first ping are each bounded by the database start timeout
// (constants.ArgDatabaseStartTimeout, in seconds); waiting for recovery is only
// bounded by ctx. If the ping or the recovery wait fails, the connection is
// closed before the error is returned.
func createMaintenanceClient(ctx context.Context, port int) (*pgx.Conn, error) {
	utils.LogTime("db_local.createMaintenanceClient start")
	defer utils.LogTime("db_local.createMaintenanceClient end")

	connStr := fmt.Sprintf("host=localhost port=%d user=%s dbname=postgres sslmode=disable", port, constants.DatabaseSuperUser)

	// the start timeout is configured as a number of seconds
	// read it once so that every step below uses exactly the same value
	startTimeout := time.Duration(viper.GetInt(constants.ArgDatabaseStartTimeout)) * time.Second

	timeoutCtx, cancel := context.WithTimeout(ctx, startTimeout)
	defer cancel()

	statushooks.SetStatus(ctx, "Waiting for connection")
	conn, err := db_common.WaitForConnection(
		timeoutCtx,
		connStr,
		db_common.WithRetryInterval(constants.DBConnectionRetryBackoff),
		db_common.WithTimeout(startTimeout),
	)
	if err != nil {
		log.Println("[TRACE] could not connect to service")
		return nil, sperr.Wrap(err, sperr.WithMessage("connection setup failed"))
	}

	// wait for db to start accepting queries on this connection
	err = db_common.WaitForConnectionPing(
		timeoutCtx,
		conn,
		db_common.WithRetryInterval(constants.DBConnectionRetryBackoff),
		db_common.WithTimeout(startTimeout),
	)
	if err != nil {
		// close with the parent context - timeoutCtx may already have expired
		conn.Close(ctx)
		log.Println("[TRACE] Ping timed out")
		return nil, sperr.Wrap(err, sperr.WithMessage("connection setup failed"))
	}

	// wait for recovery to complete
	// the database may enter recovery mode if it detects that
	// it wasn't shutdown gracefully.
	// For large databases, this can take long
	// We want to wait for a LONG time for this to complete
	// Use the context that was given - since that is tied to os.Signal
	// and can be interrupted
	err = db_common.WaitForRecovery(
		ctx,
		conn,
		db_common.WithRetryInterval(constants.DBRecoveryRetryBackoff),
	)
	if err != nil {
		conn.Close(ctx)
		log.Println("[TRACE] WaitForRecovery timed out")
		return nil, sperr.Wrap(err, sperr.WithMessage("could not complete recovery"))
	}

	return conn, nil
}