/
db_client_search_path.go
122 lines (99 loc) · 3.92 KB
/
db_client_search_path.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package db_client
import (
"context"
"fmt"
"log"
"strings"
"github.com/jackc/pgx/v5"
"github.com/spf13/viper"
"github.com/turbot/go-kit/helpers"
"github.com/turbot/steampipe/pkg/constants"
"github.com/turbot/steampipe/pkg/db/db_common"
)
// SetRequiredSessionSearchPath implements Client.
//
// It reads the search path and search path prefix from config (dropping empty
// entries) and records the search path this client requires:
//   - the base path is the configured search path if one was given, otherwise
//     the cached user search path
//   - the prefix, if any, is added to the base path and the internal schema
//     suffix is ensured
//   - the result is stored as the custom search path only when a search path or
//     a prefix was configured; otherwise the custom search path is cleared
//
// This only records the required search path - the search path is actually
// applied when a database session is created.
// The function always returns nil.
func (c *DbClient) SetRequiredSessionSearchPath(ctx context.Context) error {
	// read both settings from config, stripping empty elements
	searchPath := helpers.RemoveFromStringSlice(viper.GetStringSlice(constants.ArgSearchPath), "")
	prefix := helpers.RemoveFromStringSlice(viper.GetStringSlice(constants.ArgSearchPathPrefix), "")

	// the prefix is always stored on the client
	c.searchPathPrefix = prefix

	// start from the user search path unless a search path was explicitly configured
	basePath := c.userSearchPath
	if len(searchPath) > 0 {
		basePath = searchPath
	}

	// apply the prefix, then make sure the internal schema suffix is present
	required := db_common.EnsureInternalSchemaSuffix(db_common.AddSearchPathPrefix(prefix, basePath))

	if len(searchPath) == 0 && len(prefix) == 0 {
		// nothing was configured - there is no custom search path
		c.customSearchPath = nil
	} else {
		c.customSearchPath = required
	}
	return nil
}
// LoadUserSearchPath acquires a connection from the management pool and uses it
// to refresh the cached user search path.
// The connection is released when the function returns.
// It returns any error from acquiring the connection or loading the search path.
func (c *DbClient) LoadUserSearchPath(ctx context.Context) error {
	poolConn, acquireErr := c.managementPool.Acquire(ctx)
	if acquireErr != nil {
		return acquireErr
	}
	// hand the connection back to the pool once the search path has been loaded
	defer poolConn.Release()

	return c.loadUserSearchPath(ctx, poolConn.Conn())
}
// loadUserSearchPath fetches the user search path over the given connection and
// caches it on the client.
// If the fetch fails, the cached value is left untouched and the error is returned.
func (c *DbClient) loadUserSearchPath(ctx context.Context, connection *pgx.Conn) error {
	searchPath, err := db_common.GetUserSearchPath(ctx, connection)
	if err == nil {
		// only overwrite the cached value on success
		c.userSearchPath = searchPath
	}
	return err
}
// GetRequiredSessionSearchPath implements Client.
// It returns the custom search path when one is set (non-nil), and falls back
// to the cached user search path otherwise.
func (c *DbClient) GetRequiredSessionSearchPath() []string {
	if c.customSearchPath == nil {
		// no custom search path - the user search path applies
		return c.userSearchPath
	}
	return c.customSearchPath
}
// GetCustomSearchPath returns the custom search path stored on the client.
// SetRequiredSessionSearchPath stores a value here when a search path or search
// path prefix is configured and sets it to nil otherwise, so the result may be nil.
func (c *DbClient) GetCustomSearchPath() []string {
return c.customSearchPath
}
// ensureSessionSearchPath ensures the search path of the given database session
// is the one this client requires.
//
// It first reloads the user search path over the session's connection, then
// resolves the required search path (the custom search path if present,
// otherwise the user search path). If the session's recorded search path already
// matches, nothing is done. Otherwise it executes `set search_path to ...` via
// db_common.ExecuteSystemClientCall and, on success, records the new path on
// session.SearchPath.
//
// It returns any error from reloading the user search path or from setting the
// search path; on error session.SearchPath is left unchanged.
func (c *DbClient) ensureSessionSearchPath(ctx context.Context, session *db_common.DatabaseSession) error {
	log.Printf("[TRACE] ensureSessionSearchPath")

	// update the stored value of user search path
	// this might have changed if a connection has been added/removed
	if err := c.loadUserSearchPath(ctx, session.Connection.Conn()); err != nil {
		return err
	}

	// get the required search path which is either a custom search path (if present) or the user search path
	requiredSearchPath := c.GetRequiredSessionSearchPath()
	requiredSearchPathString := strings.Join(requiredSearchPath, ",")

	// if the session search path is already the required search path there is nothing to do
	if strings.Join(session.SearchPath, ",") == requiredSearchPathString {
		log.Printf("[TRACE] session search path is already correct - nothing to do")
		return nil
	}

	// so we need to set the search path
	// NOTE: log the path which will actually be applied - this is the required search path,
	// which is NOT necessarily the custom search path (that is nil when falling back to the user search path)
	log.Printf("[TRACE] session search path will be updated to %s", requiredSearchPathString)

	err := db_common.ExecuteSystemClientCall(ctx, session.Connection.Conn(), func(ctx context.Context, tx pgx.Tx) error {
		_, err := tx.Exec(ctx, fmt.Sprintf("set search_path to %s", strings.Join(db_common.PgEscapeSearchPath(requiredSearchPath), ",")))
		return err
	})
	if err != nil {
		return err
	}

	// update the session search path property only once the database has accepted the change
	session.SearchPath = requiredSearchPath
	return nil
}