Skip to content

Commit

Permalink
Add support for user functions (#608)
Browse files Browse the repository at this point in the history
* Add initial support for functions
* Show functions definitions
* Fix client tests
* Fix schema objects search
* Perform partial matching for functions
* Add function test
* Make sure to close client connections so that database could be dropped in tests
* Fix lint
* Allow to copy the view/functions definitions
* Nits
  • Loading branch information
sosedoff committed Dec 7, 2022
1 parent bbe9a97 commit 38051b9
Show file tree
Hide file tree
Showing 11 changed files with 292 additions and 92 deletions.
2 changes: 1 addition & 1 deletion Makefile
Expand Up @@ -28,7 +28,7 @@ usage:
@echo ""

test:
go test -race -cover ./pkg/...
go test -v -race -cover ./pkg/...

test-all:
@./script/test_all.sh
Expand Down
26 changes: 20 additions & 6 deletions pkg/api/api.go
Expand Up @@ -342,13 +342,21 @@ func GetSchemas(c *gin.Context) {

// GetTable renders table information
func GetTable(c *gin.Context) {
var res *client.Result
var err error
var (
res *client.Result
err error
)

if c.Request.FormValue("type") == client.ObjTypeMaterializedView {
res, err = DB(c).MaterializedView(c.Params.ByName("table"))
} else {
res, err = DB(c).Table(c.Params.ByName("table"))
db := DB(c)
tableName := c.Params.ByName("table")

switch c.Request.FormValue("type") {
case client.ObjTypeMaterializedView:
res, err = db.MaterializedView(tableName)
case client.ObjTypeFunction:
res, err = db.Function(tableName)
default:
res, err = db.Table(tableName)
}

serveResult(c, res, err)
Expand Down Expand Up @@ -541,3 +549,9 @@ func DataExport(c *gin.Context) {
badRequest(c, err)
}
}

// GetFunction renders function information
func GetFunction(c *gin.Context) {
res, err := DB(c).Function(c.Param("id"))
serveResult(c, res, err)
}
1 change: 1 addition & 0 deletions pkg/api/routes.go
Expand Up @@ -42,6 +42,7 @@ func SetupRoutes(router *gin.Engine) {
api.GET("/tables/:table/info", GetTableInfo)
api.GET("/tables/:table/indexes", GetTableIndexes)
api.GET("/tables/:table/constraints", GetTableConstraints)
api.GET("/functions/:id", GetFunction)
api.GET("/query", RunQuery)
api.POST("/query", RunQuery)
api.GET("/explain", ExplainQuery)
Expand Down
4 changes: 4 additions & 0 deletions pkg/client/client.go
Expand Up @@ -197,6 +197,10 @@ func (client *Client) MaterializedView(name string) (*Result, error) {
return client.query(statements.MaterializedView, name)
}

func (client *Client) Function(id string) (*Result, error) {
return client.query(statements.Function, id)
}

func (client *Client) TableRows(table string, opts RowsOptions) (*Result, error) {
schema, table := getSchemaAndTable(table)
sql := fmt.Sprintf(`SELECT * FROM "%s"."%s"`, schema, table)
Expand Down
141 changes: 106 additions & 35 deletions pkg/client/client_test.go
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"os/exec"
"runtime"
"sort"
"testing"
"time"

Expand All @@ -32,6 +33,26 @@ func mapKeys(data map[string]*Objects) []string {
return result
}

func objectNames(data []Object) []string {
names := make([]string, len(data))
for i, obj := range data {
names[i] = obj.Name
}

sort.Strings(names)
return names
}

// assertMatches is a helper method to check if src slice contains any elements of expected slice
func assertMatches(t *testing.T, expected, src []string) {
assert.NotEqual(t, 0, len(expected))
assert.NotEqual(t, 0, len(src))

for _, val := range expected {
assert.Contains(t, src, val)
}
}

func pgVersion() (int, int) {
var major, minor int
fmt.Sscanf(os.Getenv("PGVERSION"), "%d.%d", &major, &minor)
Expand Down Expand Up @@ -118,12 +139,12 @@ func setupClient() {

func teardownClient() {
if testClient != nil {
testClient.db.Close()
testClient.Close()
}
}

func teardown() {
_, err := exec.Command(
output, err := exec.Command(
testCommands["dropdb"],
"-U", serverUser,
"-h", serverHost,
Expand All @@ -133,31 +154,28 @@ func teardown() {

if err != nil {
fmt.Println("Teardown error:", err)
fmt.Printf("%s\n", output)
}
}

func testNewClientFromUrl(t *testing.T) {
url := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase)
client, err := NewFromUrl(url, nil)

if err != nil {
defer client.Close()
}

assert.Equal(t, nil, err)
assert.Equal(t, url, client.ConnectionString)
}
func testNewClientFromURL(t *testing.T) {
t.Run("postgres prefix", func(t *testing.T) {
url := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase)
client, err := NewFromUrl(url, nil)

func testNewClientFromUrl2(t *testing.T) {
url := fmt.Sprintf("postgresql://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase)
client, err := NewFromUrl(url, nil)
assert.Equal(t, nil, err)
assert.Equal(t, url, client.ConnectionString)
assert.NoError(t, client.Close())
})

if err != nil {
defer client.Close()
}
t.Run("postgresql prefix", func(t *testing.T) {
url := fmt.Sprintf("postgresql://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase)
client, err := NewFromUrl(url, nil)

assert.Equal(t, nil, err)
assert.Equal(t, url, client.ConnectionString)
assert.Equal(t, nil, err)
assert.Equal(t, url, client.ConnectionString)
assert.NoError(t, client.Close())
})
}

func testClientIdleTime(t *testing.T) {
Expand Down Expand Up @@ -202,16 +220,13 @@ func testActivity(t *testing.T) {

res, err := testClient.Activity()
assert.NoError(t, err)
for _, val := range expected {
assert.Contains(t, res.Columns, val)
}
assertMatches(t, expected, res.Columns)
}

func testDatabases(t *testing.T) {
res, err := testClient.Databases()
assert.NoError(t, err)
assert.Contains(t, res, "booktown")
assert.Contains(t, res, "postgres")
assertMatches(t, []string{"booktown", "postgres"}, res)
}

func testObjects(t *testing.T) {
Expand Down Expand Up @@ -245,16 +260,44 @@ func testObjects(t *testing.T) {
"text_sorting",
}

functions := []string{
"add_shipment",
"add_two_loop",
"books_by_subject",
"compound_word",
"count_by_two",
"double_price",
"extract_all_titles",
"extract_all_titles2",
"extract_title",
"first",
"get_author",
"get_author",
"get_customer_id",
"get_customer_name",
"html_linebreaks",
"in_stock",
"isbn_to_title",
"mixed",
"raise_test",
"ship_item",
"stock_amount",
"test",
"title",
"triple_price",
}

assert.NoError(t, err)
assert.Equal(t, []string{"schema", "name", "type", "owner", "comment"}, res.Columns)
assert.Equal(t, []string{"oid", "schema", "name", "type", "owner", "comment"}, res.Columns)
assert.Equal(t, []string{"public"}, mapKeys(objects))
assert.Equal(t, tables, objects["public"].Tables)
assert.Equal(t, []string{"recent_shipments", "stock_view"}, objects["public"].Views)
assert.Equal(t, []string{"author_ids", "book_ids", "shipments_ship_id_seq", "subject_ids"}, objects["public"].Sequences)
assert.Equal(t, tables, objectNames(objects["public"].Tables))
assertMatches(t, functions, objectNames(objects["public"].Functions))
assert.Equal(t, []string{"recent_shipments", "stock_view"}, objectNames(objects["public"].Views))
assert.Equal(t, []string{"author_ids", "book_ids", "shipments_ship_id_seq", "subject_ids"}, objectNames(objects["public"].Sequences))

major, minor := pgVersion()
if minor == 0 || minor >= 3 {
assert.Equal(t, []string{"m_stock_view"}, objects["public"].MaterializedViews)
assert.Equal(t, []string{"m_stock_view"}, objectNames(objects["public"].MaterializedViews))
} else {
t.Logf("Skipping materialized view on %d.%d\n", major, minor)
}
Expand Down Expand Up @@ -428,6 +471,33 @@ func testTableRowsOrderEscape(t *testing.T) {
assert.Nil(t, rows)
}

func testFunctions(t *testing.T) {
funcName := "get_customer_name"
funcID := ""

res, err := testClient.Objects()
assert.NoError(t, err)

for _, row := range res.Rows {
if row[2] == funcName {
funcID = row[0].(string)
break
}
}

res, err = testClient.Function("12345")
assert.NoError(t, err)
assertMatches(t, []string{"oid", "proname", "functiondef"}, res.Columns)
assert.Equal(t, 0, len(res.Rows))

res, err = testClient.Function(funcID)
assert.NoError(t, err)
assertMatches(t, []string{"oid", "proname", "functiondef"}, res.Columns)
assert.Equal(t, 1, len(res.Rows))
assert.Equal(t, funcName, res.Rows[0][1])
assert.Contains(t, res.Rows[0][len(res.Columns)-1], "SELECT INTO customer_fname, customer_lname")
}

func testResult(t *testing.T) {
t.Run("json", func(t *testing.T) {
result, err := testClient.Query("SELECT * FROM books LIMIT 1")
Expand Down Expand Up @@ -466,8 +536,8 @@ func testHistory(t *testing.T) {
t.Run("unique queries", func(t *testing.T) {
url := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase)

client, err := NewFromUrl(url, nil)
assert.NoError(t, err)
client, _ := NewFromUrl(url, nil)
defer client.Close()

for i := 0; i < 3; i++ {
_, err := client.Query("SELECT * FROM books WHERE id = 1")
Expand All @@ -487,6 +557,7 @@ func testReadOnlyMode(t *testing.T) {

url := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase)
client, _ := NewFromUrl(url, nil)
defer client.Close()

err := client.SetReadOnlyMode()
assert.NoError(t, err)
Expand Down Expand Up @@ -522,8 +593,7 @@ func TestAll(t *testing.T) {
setup()
setupClient()

testNewClientFromUrl(t)
testNewClientFromUrl2(t)
testNewClientFromURL(t)
testClientIdleTime(t)
testTest(t)
testInfo(t)
Expand All @@ -544,6 +614,7 @@ func TestAll(t *testing.T) {
testQueryError(t)
testQueryInvalidTable(t)
testTableRowsOrderEscape(t)
testFunctions(t)
testResult(t)
testHistory(t)
testReadOnlyMode(t)
Expand Down
43 changes: 28 additions & 15 deletions pkg/client/result.go
Expand Up @@ -18,6 +18,7 @@ const (
ObjTypeView = "view"
ObjTypeMaterializedView = "materialized_view"
ObjTypeSequence = "sequence"
ObjTypeFunction = "function"
)

type (
Expand All @@ -36,11 +37,17 @@ type (
Rows []Row `json:"rows"`
}

Object struct {
OID string `json:"oid"`
Name string `json:"name"`
}

Objects struct {
Tables []string `json:"table"`
Views []string `json:"view"`
MaterializedViews []string `json:"materialized_view"`
Sequences []string `json:"sequence"`
Tables []Object `json:"table"`
Views []Object `json:"view"`
MaterializedViews []Object `json:"materialized_view"`
Functions []Object `json:"function"`
Sequences []Object `json:"sequence"`
}
)

Expand Down Expand Up @@ -154,28 +161,34 @@ func ObjectsFromResult(res *Result) map[string]*Objects {
objects := map[string]*Objects{}

for _, row := range res.Rows {
schema := row[0].(string)
name := row[1].(string)
objectType := row[2].(string)
oid := row[0].(string)
schema := row[1].(string)
name := row[2].(string)
objectType := row[3].(string)

if objects[schema] == nil {
objects[schema] = &Objects{
Tables: []string{},
Views: []string{},
MaterializedViews: []string{},
Sequences: []string{},
Tables: []Object{},
Views: []Object{},
MaterializedViews: []Object{},
Functions: []Object{},
Sequences: []Object{},
}
}

obj := Object{OID: oid, Name: name}

switch objectType {
case ObjTypeTable:
objects[schema].Tables = append(objects[schema].Tables, name)
objects[schema].Tables = append(objects[schema].Tables, obj)
case ObjTypeView:
objects[schema].Views = append(objects[schema].Views, name)
objects[schema].Views = append(objects[schema].Views, obj)
case ObjTypeMaterializedView:
objects[schema].MaterializedViews = append(objects[schema].MaterializedViews, name)
objects[schema].MaterializedViews = append(objects[schema].MaterializedViews, obj)
case ObjTypeFunction:
objects[schema].Functions = append(objects[schema].Functions, obj)
case ObjTypeSequence:
objects[schema].Sequences = append(objects[schema].Sequences, name)
objects[schema].Sequences = append(objects[schema].Sequences, obj)
}
}

Expand Down
3 changes: 3 additions & 0 deletions pkg/statements/sql.go
Expand Up @@ -38,6 +38,9 @@ var (
//go:embed sql/objects.sql
Objects string

//go:embed sql/function.sql
Function string

// Activity queries for specific PG versions
Activity = map[string]string{
"default": "SELECT * FROM pg_stat_activity WHERE datname = current_database()",
Expand Down
7 changes: 7 additions & 0 deletions pkg/statements/sql/function.sql
@@ -0,0 +1,7 @@
SELECT
p.*,
pg_get_functiondef(oid) AS functiondef
FROM
pg_catalog.pg_proc p
WHERE
oid = $1::oid

0 comments on commit 38051b9

Please sign in to comment.