Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion internal/db/advisors/advisors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"os"
"strings"

"github.com/go-errors/errors"
"github.com/jackc/pgconn"
Expand Down Expand Up @@ -109,7 +110,12 @@ func queryLints(ctx context.Context, conn *pgx.Conn) ([]Lint, error) {
}
}()

rows, err := tx.Query(ctx, lintsSQL)
setupSQL, querySQL := splitLintsSQL()
if _, err := tx.Exec(ctx, setupSQL); err != nil {
return nil, errors.Errorf("failed to prepare lint session: %w", err)
}

rows, err := tx.Query(ctx, querySQL)
if err != nil {
return nil, errors.Errorf("failed to query lints: %w", err)
}
Expand Down Expand Up @@ -145,6 +151,14 @@ func queryLints(ctx context.Context, conn *pgx.Conn) ([]Lint, error) {
return lints, nil
}

func splitLintsSQL() (string, string) {
setupSQL, querySQL, found := strings.Cut(lintsSQL, ";\n\n")
if !found {
return "", lintsSQL
}
return setupSQL, querySQL
}

func fetchSecurityAdvisors(ctx context.Context, projectRef string) ([]Lint, error) {
resp, err := utils.GetSupabase().V1GetSecurityAdvisorsWithResponse(ctx, projectRef, &api.V1GetSecurityAdvisorsParams{})
if err != nil {
Expand Down
43 changes: 37 additions & 6 deletions internal/db/advisors/advisors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@ func TestQueryLints(t *testing.T) {
t.Run("parses lint results from local database", func(t *testing.T) {
utils.Config.Hostname = "127.0.0.1"
utils.Config.Db.Port = 5432
setupSQL, querySQL := splitLintsSQL()
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("begin").Reply("BEGIN").
Query(lintsSQL).
Query(setupSQL).
Reply("SET").
Query(querySQL).
Reply("SELECT 1",
[]any{
"rls_disabled_in_public",
Expand All @@ -59,10 +62,13 @@ func TestQueryLints(t *testing.T) {
})

t.Run("handles empty results", func(t *testing.T) {
setupSQL, querySQL := splitLintsSQL()
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("begin").Reply("BEGIN").
Query(lintsSQL).
Query(setupSQL).
Reply("SET").
Query(querySQL).
Reply("SELECT 0").
Query("rollback").Reply("ROLLBACK")
// Run test
Expand All @@ -72,16 +78,32 @@ func TestQueryLints(t *testing.T) {
})

t.Run("handles query error", func(t *testing.T) {
setupSQL, querySQL := splitLintsSQL()
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("begin").Reply("BEGIN").
Query(lintsSQL).
Query(setupSQL).
Reply("SET").
Query(querySQL).
ReplyError("42601", "syntax error").
Query("rollback").Reply("ROLLBACK")
// Run test
_, err := queryLints(context.Background(), conn.MockClient(t))
assert.Error(t, err)
})

t.Run("handles setup error", func(t *testing.T) {
setupSQL, _ := splitLintsSQL()
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("begin").Reply("BEGIN").
Query(setupSQL).
ReplyError("42601", "syntax error").
Query("rollback").Reply("ROLLBACK")
// Run test
_, err := queryLints(context.Background(), conn.MockClient(t))
assert.ErrorContains(t, err, "failed to prepare lint session")
})
}

func TestFilterLints(t *testing.T) {
Expand Down Expand Up @@ -313,11 +335,14 @@ func TestRunLocalWithDbUrl(t *testing.T) {
t.Run("runs advisors against custom db-url", func(t *testing.T) {
utils.Config.Hostname = "127.0.0.1"
utils.Config.Db.Port = 5432
setupSQL, querySQL := splitLintsSQL()

conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("begin").Reply("BEGIN").
Query(lintsSQL).
Query(setupSQL).
Reply("SET").
Query(querySQL).
Reply("SELECT 1",
[]any{
"rls_disabled_in_public",
Expand All @@ -339,10 +364,13 @@ func TestRunLocalWithDbUrl(t *testing.T) {
})

t.Run("returns no issues for empty results", func(t *testing.T) {
setupSQL, querySQL := splitLintsSQL()
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("begin").Reply("BEGIN").
Query(lintsSQL).
Query(setupSQL).
Reply("SET").
Query(querySQL).
Reply("SELECT 0").
Query("rollback").Reply("ROLLBACK")

Expand All @@ -351,10 +379,13 @@ func TestRunLocalWithDbUrl(t *testing.T) {
})

t.Run("fails on error level when fail-on is set", func(t *testing.T) {
setupSQL, querySQL := splitLintsSQL()
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query("begin").Reply("BEGIN").
Query(lintsSQL).
Query(setupSQL).
Reply("SET").
Query(querySQL).
Reply("SELECT 1",
[]any{
"rls_disabled_in_public",
Expand Down
Loading