Skip to content

Commit

Permalink
Merge pull request #29 from sev-2/fix/v1.0.0-beta.2
Browse files Browse the repository at this point in the history
Fix/v1.0.0 beta.2
  • Loading branch information
toopay authored Jun 7, 2024
2 parents fdb1bc7 + f4b470c commit 48a466c
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 22 deletions.
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ github.com/fasthttp/router v1.4.22 h1:qwWcYBbndVDwts4dKaz+A2ehsnbKilmiP6pUhXBfYK
github.com/fasthttp/router v1.4.22/go.mod h1:KeMvHLqhlB9vyDWD5TSvTccl9qeWrjSSiTJrJALHKV0=
github.com/fasthttp/websocket v1.5.8 h1:k5DpirKkftIF/w1R8ZzjSgARJrs54Je9YJK37DL/Ah8=
github.com/fasthttp/websocket v1.5.8/go.mod h1:d08g8WaT6nnyvg9uMm8K9zMYyDjfKyj3170AtPRuVU0=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM=
github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
Expand Down
36 changes: 24 additions & 12 deletions pkg/generator/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,7 @@ func ExtractRpcFunction(fn *objects.Function) (result ExtractRpcDataResult, err

// normalize aliases
if e := RpcNormalizeTableAliases(mapScannedTable); e != nil {
if e != nil {
err = e
return
}
err = e
return
}

Expand Down Expand Up @@ -349,13 +346,19 @@ func ExtractRpcParam(fn *objects.Function) (params []raiden.RpcParam, usePrefix
}

// loop for create rpc param and add to params variable
paramsUsePrefix := []string{}
paramsInCount := 0
for i := range fn.Args {
fa := fn.Args[i]
if fa.Mode != "in" {
continue
}

usePrefix = strings.HasPrefix(fa.Name, raiden.DefaultRpcParamPrefix)
paramsInCount++
if strings.HasPrefix(strings.ToLower(fa.Name), raiden.DefaultRpcParamPrefix) {
paramsUsePrefix = append(paramsUsePrefix, fa.Name)
}

fieldName := strings.ReplaceAll(fa.Name, raiden.DefaultRpcParamPrefix, "")
p := raiden.RpcParam{
Name: fieldName,
Expand All @@ -379,6 +382,8 @@ func ExtractRpcParam(fn *objects.Function) (params []raiden.RpcParam, usePrefix
params = append(params, p)
}

usePrefix = len(paramsUsePrefix) == paramsInCount

return
}

Expand All @@ -393,7 +398,9 @@ func ExtractRpcTable(def string) (string, map[string]*RpcScannedTable, error) {
var foundTable = &RpcScannedTable{}

for _, f := range dFields {
f = strings.TrimRight(f, ";")
k := strings.ToUpper(f)

switch lastField {
case postgres.From:
if postgres.IsReservedKeyword(k) {
Expand All @@ -420,9 +427,14 @@ func ExtractRpcTable(def string) (string, map[string]*RpcScannedTable, error) {
}
foundTable.Alias = f
}
case postgres.Join:
if f == postgres.On {
lastField = f
case postgres.Inner, postgres.Outer, postgres.Left, postgres.Right:
if k == postgres.Join {
lastField += " " + postgres.Join
continue
}
case postgres.Join, postgres.InnerJoin, postgres.OuterJoin, postgres.LeftJoin, postgres.RightJoin:
if k == postgres.On {
lastField = k
continue
}

Expand Down Expand Up @@ -530,12 +542,12 @@ func bindModelToDefinition(def string, mapTable map[string]*RpcScannedTable, par

for i := range params {
p := params[i]
key := p.Name
findKey, replaceKey := p.Name, p.Name
if useParamPrefix {
key += raiden.DefaultRpcParamPrefix + key
findKey = raiden.DefaultRpcParamPrefix + findKey
}
pattern := fmt.Sprintf(`\b%s\b`, regexp.QuoteMeta(key))
def = regexp.MustCompile(pattern).ReplaceAllString(def, ":"+key)
pattern := fmt.Sprintf(`\b%s\b`, regexp.QuoteMeta(findKey))
def = regexp.MustCompile(pattern).ReplaceAllString(def, ":"+replaceKey)
}
return def
}
Expand Down
106 changes: 104 additions & 2 deletions pkg/generator/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ func TestExtractRpcData(t *testing.T) {
END;
$$ LANGUAGE plpgsql;
`,
ReturnType: "TABLE(id integer, created_at timestamp without time zone, sc_name character varying, c_name character varying)",
// ReturnType: "SETOF submission",
ReturnType: "TABLE(id integer, created_at timestamp without time zone, sc_name character varying, c_name character varying)",
Behavior: "VOLATILE",
SecurityDefiner: false,
}
Expand Down Expand Up @@ -130,6 +129,21 @@ func TestExtractRpcTable(t *testing.T) {
assert.Equal(t, 1, len(scouter.Relation))
}

func TestExtractRpcSingleTable(t *testing.T) {
definition := `
begin
return query select * from todo;
end;`

_, mapTable, err := generator.ExtractRpcTable(definition)
assert.NoError(t, err)

table, isTableExist := mapTable["todo"]
assert.True(t, isTableExist)
assert.NotNil(t, table)
assert.Equal(t, "todo", table.Name)
}

func TestNormalizeTableAlias(t *testing.T) {
mapAlias := map[string]*generator.RpcScannedTable{
"submission": {
Expand All @@ -145,5 +159,93 @@ func TestNormalizeTableAlias(t *testing.T) {
err := generator.RpcNormalizeTableAliases(mapAlias)
assert.NoError(t, err)
assert.Equal(t, "sc", mapAlias["scouter"].Alias)
}

func TestExtractRpcWithPrefix(t *testing.T) {
fn := objects.Function{
Schema: "public",
Name: "get_submissions",
Language: "plpgsql",
Definition: `begin return query
select s.id, s.created_at, sc.name as sc_name, c.name as c_name
from submission s
inner join scouter sc on s.scouter_id = sc.scouter_id
inner join candidate c on s.candidate_id = c.candidate_id
where sc.name = in_scouter_name and c.name = in_candidate_name ; end;
`,
CompleteStatement: `
CREATE OR REPLACE FUNCTION public.get_submissions(in_scouter_name character varying, in_candidate_name character varying)\n
RETURNS TABLE(id integer, created_at timestamp without time zone, sc_name character varying, c_name character varying)\n
LANGUAGE plpgsql\n
AS $function$
begin return query
select s.id, s.created_at, sc.name as sc_name, c.name as c_name from submission s
inner join scouter sc on s.scouter_id = sc.scouter_id
inner join candidate c on s.candidate_id = c.candidate_id
where sc.name = in_scouter_name and c.name = in_candidate_name ; end;
$function$\n
`,
Args: []objects.FunctionArg{
{
Mode: "in",
Name: "in_scouter_name",
TypeId: 1043,
HasDefault: false,
},
{
Mode: "in",
Name: "in_candidate_name",
TypeId: 1043,
HasDefault: false,
},
{
Mode: "table",
Name: "id",
TypeId: 23,
HasDefault: false,
},
{
Mode: "table",
Name: "created_at",
TypeId: 23,
HasDefault: false,
},
{
Mode: "table",
Name: "sc_name",
TypeId: 23,
HasDefault: false,
},
{
Mode: "table",
Name: "c_name",
TypeId: 23,
HasDefault: false,
},
},
ArgumentTypes: "in_scouter_name character varying, in_candidate_name character varying",
IdentityArgumentTypes: "in_scouter_name character varying, in_candidate_name character varying",
ReturnTypeID: 2249,
ReturnType: "TABLE(id integer, created_at timestamp without time zone, sc_name character varying, c_name character varying)",
ReturnTypeRelationID: 0,
IsSetReturningFunction: true,
Behavior: string(raiden.RpcBehaviorVolatile),
SecurityDefiner: false,
ConfigParams: nil,
}

result, err := generator.ExtractRpcFunction(&fn)
assert.NoError(t, err)

assert.Equal(t, fn.Name, result.Rpc.Name)
assert.Equal(t, raiden.DefaultRpcSchema, result.Rpc.Schema)
assert.Equal(t, raiden.RpcBehaviorVolatile, result.Rpc.Behavior)

assert.Equal(t, raiden.RpcSecurityTypeInvoker, result.Rpc.SecurityType)
assert.Equal(t, raiden.RpcReturnDataTypeTable, result.Rpc.ReturnType)
assert.Equal(t, fn.ReturnType, result.OriginalReturnType)
assert.Equal(t, 3, len(result.MapScannedTable))

expectedDefinition := "begin return query select s.id, s.created_at, sc.name as sc_name, c.name as c_name from :s s inner join :sc sc on s.scouter_id = sc.scouter_id inner join :c c on s.candidate_id = c.candidate_id where sc.name = :scouter_name and c.name = :candidate_name ; end;"
assert.Equal(t, expectedDefinition, result.Rpc.Definition)
}
1 change: 0 additions & 1 deletion pkg/resource/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ func Apply(flags *Flags, config *raiden.Config) error {
}

ApplyLogger.Info("finish build migrate data")

if !flags.DryRun {
migrateErr := Migrate(config, &localState, flags.ProjectPath, &migrateData)
if len(migrateErr) > 0 {
Expand Down
6 changes: 3 additions & 3 deletions pkg/resource/rpc/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ func Compare(source []objects.Function, target []objects.Function) error {
}

func CompareList(sourceFn []objects.Function, targetFn []objects.Function) (diffResult []CompareDiffResult, err error) {
mapTargetFn := make(map[int]objects.Function)
mapTargetFn := make(map[string]objects.Function)
for i := range targetFn {
f := targetFn[i]
mapTargetFn[f.ID] = f
mapTargetFn[f.Name] = f
}

for i := range sourceFn {
s := sourceFn[i]

t, isExist := mapTargetFn[s.ID]
t, isExist := mapTargetFn[s.Name]
if !isExist {
continue
}
Expand Down
18 changes: 18 additions & 0 deletions pkg/state/rpc.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package state

import (
"errors"
"fmt"
"regexp"
"strings"

"github.com/sev-2/raiden"
"github.com/sev-2/raiden/pkg/supabase/objects"
)
Expand Down Expand Up @@ -55,6 +60,19 @@ func BindRpcFunction(rpc raiden.Rpc, fn *objects.Function) (err error) {
fn.Name = rpc.GetName()
fn.Schema = rpc.GetSchema()
fn.CompleteStatement = rpc.GetCompleteStmt()

// validate definition query
matches := regexp.MustCompile(`:\w+`).FindAllString(fn.CompleteStatement, -1)
if len(matches) > 0 {
var errMsg string
if len(matches) > 1 {
errMsg = fmt.Sprintf("rpc %q is invalid, There are %q keys that are not mapped with any parameters or models.", rpc.GetName(), strings.Join(matches, ","))
} else {
errMsg = fmt.Sprintf("rpc %q is invalid, There is %q key that is not mapped with any parameters or models.", rpc.GetName(), matches[0])
}
return errors.New(errMsg)
}

return
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/supabase/client/net/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ func SendRequest(method string, url string, body []byte, timeout time.Duration,
return nil, err
}

Logger.Trace("net.response", "body", string(body))
if resInterceptor != nil {
if err := resInterceptor(resp); err != nil {
return body, err
Expand All @@ -114,6 +113,7 @@ func SendRequest(method string, url string, body []byte, timeout time.Duration,
return nil, err
}

Logger.Trace("net.response", "body", string(body))
return body, nil
}

Expand Down
8 changes: 8 additions & 0 deletions pkg/supabase/drivers/cloud/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"

"github.com/sev-2/raiden"
"github.com/sev-2/raiden/pkg/logger"
Expand Down Expand Up @@ -54,3 +55,10 @@ func ExecuteQuery[T any](baseUrl, projectId, query string, reqInterceptor net.Re

return net.Post[T](url, pByte, net.DefaultTimeout, reqInterceptor, resInterceptor)
}

func cleanupQueryParam(q string) string {
cleanQuery := strings.ReplaceAll(q, "\n", " ")
cleanQuery = strings.ReplaceAll(cleanQuery, "\t", " ")
cleanQuery = strings.ReplaceAll(cleanQuery, " ", "")
return cleanQuery
}
3 changes: 3 additions & 0 deletions pkg/supabase/drivers/cloud/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ func CreateFunction(cfg *raiden.Config, fn objects.Function) (objects.Function,
return objects.Function{}, nil
}

sql = cleanupQueryParam(sql)

_, err = ExecuteQuery[any](cfg.SupabaseApiUrl, cfg.ProjectId, sql, DefaultAuthInterceptor(cfg.AccessToken), nil)
if err != nil {
return objects.Function{}, fmt.Errorf("create new function %s error : %s", fn.Name, err)
Expand Down Expand Up @@ -78,6 +80,7 @@ func UpdateFunction(cfg *raiden.Config, fn objects.Function) error {
if err != nil {
return err
}
updateSql = cleanupQueryParam(updateSql)
_, err = ExecuteQuery[any](cfg.SupabaseApiUrl, cfg.ProjectId, updateSql, DefaultAuthInterceptor(cfg.AccessToken), nil)
if err != nil {
return fmt.Errorf("update function %s error : %s", fn.Name, err)
Expand Down
8 changes: 5 additions & 3 deletions rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ func buildRpcReturnTable(returnReflectType reflect.Type, rpc Rpc) (q string, err
return
}

pq, ep := p.ToQuery(rpc.UseParamPrefix())
pq, ep := p.ToQuery(false)
if ep != nil {
return q, ep
}
Expand Down Expand Up @@ -706,10 +706,12 @@ func buildRpcDefinition(rpc Rpc) string {
for i := range params {
p := params[i]
key := p.Name
replaceKey := key
if rpc.UseParamPrefix() {
key += DefaultRpcParamPrefix + key
replaceKey = DefaultRpcParamPrefix + key
}
definition = utils.MatchReplacer(definition, ":"+key, key)

definition = utils.MatchReplacer(definition, ":"+key, replaceKey)
}

return definition
Expand Down

0 comments on commit 48a466c

Please sign in to comment.