Skip to content

Commit

Permalink
Merge pull request #36 from sev-2/feature/enhance-model-validation
Browse files Browse the repository at this point in the history
Feature : support auto validate model in rest controller
  • Loading branch information
toopay authored Jul 3, 2024
2 parents 06a8e49 + 780c94b commit f381c7c
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 19 deletions.
37 changes: 34 additions & 3 deletions controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"regexp"
"strconv"
"strings"
"time"

"github.com/sev-2/raiden/pkg/logger"
"github.com/valyala/fasthttp"
Expand Down Expand Up @@ -158,6 +159,7 @@ func (*ControllerBase) AfterHead(ctx Context) error {
// ----- Rest Controller -----
type RestController struct {
Controller
Model any
TableName string
}

Expand Down Expand Up @@ -297,20 +299,41 @@ func (rc RestController) Options(ctx Context) error {

// Patch implements Controller.
func (rc RestController) Patch(ctx Context) error {
model := createObjectFromAnyData(rc.Model)
json.Unmarshal(ctx.RequestContext().Request.Body(), model)

if err := Validate(model); err != nil {
return err
}

return RestProxy(ctx, rc.TableName)
}

// Post implements Controller.
func (rc RestController) Post(ctx Context) error {
model := createObjectFromAnyData(rc.Model)
json.Unmarshal(ctx.RequestContext().Request.Body(), model)

if err := Validate(model); err != nil {
return err
}

return RestProxy(ctx, rc.TableName)
}

// Put implements Controller.
func (rc RestController) Put(ctx Context) error {
model := createObjectFromAnyData(rc.Model)
json.Unmarshal(ctx.RequestContext().Request.Body(), model)

if err := Validate(model); err != nil {
return err
}

return RestProxy(ctx, rc.TableName)
}

// ----- Rest Controller -----
// ----- Storage Controller -----
type StorageController struct {
Controller
BucketName string
Expand Down Expand Up @@ -532,6 +555,14 @@ func MarshallAndValidate(ctx *fasthttp.RequestCtx, controller any) error {
return nil
}

func createObjectFromAnyData(data any) any {
rt := reflect.TypeOf(data)
if rt.Kind() == reflect.Ptr {
rt = rt.Elem()
}
return reflect.New(rt).Interface()
}

// The function `setPayloadValue` sets the value of a field in a struct based on its type.
func setPayloadValue(fieldValue reflect.Value, value string) error {
switch fieldValue.Kind() {
Expand Down Expand Up @@ -591,8 +622,8 @@ func RestProxy(appCtx Context, TableName string) error {
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(resp)

restProxyLogger.Debug("forward request", "method", string(req.Header.Method()), "uri", string(req.URI().FullURI()))
if err := fasthttp.Do(req, resp); err != nil {
restProxyLogger.Debug("forward request", "method", string(req.Header.Method()), "uri", string(req.URI().FullURI()), "header", string(req.Header.RawHeaders()), "body", string(appCtx.RequestContext().Request.Body()))
if err := fasthttp.DoTimeout(req, resp, 30*time.Second); err != nil {
return err
}

Expand Down
22 changes: 15 additions & 7 deletions pkg/generator/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ type (
}

GenerateModelInput struct {
Table objects.Table
Relations []state.Relation
Policies objects.Policies
Table objects.Table
Relations []state.Relation
Policies objects.Policies
ValidationTags state.ModelValidationTag
}
)

Expand Down Expand Up @@ -109,7 +110,7 @@ func GenerateModel(folderPath string, input *GenerateModelInput, generateFn Gene
}

// map column data
columns, importsPath := MapTableAttributes(input.Table)
columns, importsPath := MapTableAttributes(input.Table, input.ValidationTags)
rlsTag := BuildRlsTag(input.Policies, input.Table.Name, supabase.RlsTypeModel)
raidenPath := "github.com/sev-2/raiden"
importsPath = append(importsPath, raidenPath)
Expand Down Expand Up @@ -189,7 +190,7 @@ func GenerateModel(folderPath string, input *GenerateModelInput, generateFn Gene
}

// map table to column, map pg type to go type and get dependency import path
func MapTableAttributes(table objects.Table) (columns []GenerateModelColumn, importsPath []string) {
func MapTableAttributes(table objects.Table, validationTags state.ModelValidationTag) (columns []GenerateModelColumn, importsPath []string) {
importsMap := make(map[string]any)
mapPrimaryKey := map[string]bool{}
for _, k := range table.PrimaryKeys {
Expand All @@ -199,7 +200,7 @@ func MapTableAttributes(table objects.Table) (columns []GenerateModelColumn, imp
for _, c := range table.Columns {
column := GenerateModelColumn{
Name: c.Name,
Tag: buildColumnTag(c, mapPrimaryKey),
Tag: buildColumnTag(c, mapPrimaryKey, validationTags),
Type: postgres.ToGoType(postgres.DataType(c.DataType), c.IsNullable),
}

Expand Down Expand Up @@ -232,13 +233,20 @@ func MapTableAttributes(table objects.Table) (columns []GenerateModelColumn, imp
return
}

func buildColumnTag(c objects.Column, mapPk map[string]bool) string {
func buildColumnTag(c objects.Column, mapPk map[string]bool, validationTags state.ModelValidationTag) string {
var tags []string

// append json tag
jsonTag := fmt.Sprintf("json:%q", utils.ToSnakeCase(c.Name)+",omitempty")
tags = append(tags, jsonTag)

// append validate tag
if validationTags != nil {
if vTag, exist := validationTags[c.Name]; exist {
tags = append(tags, fmt.Sprintf("validate:%q", vTag))
}
}

// append column tag
columnTags := []string{
fmt.Sprintf("name:%s", c.Name),
Expand Down
19 changes: 16 additions & 3 deletions pkg/resource/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,31 @@ func Import(flags *Flags, config *raiden.Config) error {

// dry run import errors
dryRunError := []string{}
mapModelValidationTags := make(map[string]state.ModelValidationTag)

// compare resource
ImportLogger.Info("compare supabase resource and local resource")
if (flags.All() || flags.ModelsOnly) && len(appTables.Existing) > 0 {
if !flags.DryRun {
ImportLogger.Debug("start compare table")
}

for i := range appTables.New {
nt := appTables.New[i]
if nt.ValidationTags != nil {
mapModelValidationTags[nt.Table.Name] = nt.ValidationTags
}
}

// compare table
var compareTables []objects.Table
for i := range appTables.Existing {
et := appTables.Existing[i]
compareTables = append(compareTables, et.Table)

if et.ValidationTags != nil {
mapModelValidationTags[et.Table.Name] = et.ValidationTags
}
}

if err := tables.Compare(spResource.Tables, compareTables); err != nil {
Expand Down Expand Up @@ -173,7 +186,7 @@ func Import(flags *Flags, config *raiden.Config) error {
}
if !flags.DryRun {
// generate resource
if err := generateImportResource(config, &importState, flags.ProjectPath, spResource); err != nil {
if err := generateImportResource(config, &importState, flags.ProjectPath, spResource, mapModelValidationTags); err != nil {
return err
}
PrintImportReport(importReport, false)
Expand All @@ -190,7 +203,7 @@ func Import(flags *Flags, config *raiden.Config) error {
}

// ----- Generate import data -----
func generateImportResource(config *raiden.Config, importState *state.LocalState, projectPath string, resource *Resource) error {
func generateImportResource(config *raiden.Config, importState *state.LocalState, projectPath string, resource *Resource, mapModelValidationTags map[string]state.ModelValidationTag) error {
if err := generator.CreateInternalFolder(projectPath); err != nil {
return err
}
Expand All @@ -202,7 +215,7 @@ func generateImportResource(config *raiden.Config, importState *state.LocalState
go func() {
defer wg.Done()
if len(resource.Tables) > 0 {
tableInputs := tables.BuildGenerateModelInputs(resource.Tables, resource.Policies)
tableInputs := tables.BuildGenerateModelInputs(resource.Tables, resource.Policies, mapModelValidationTags)
ImportLogger.Info("start generate tables")
captureFunc := ImportDecorateFunc(tableInputs, func(item *generator.GenerateModelInput, input generator.GenerateInput) bool {
if i, ok := input.BindData.(generator.GenerateModelData); ok {
Expand Down
10 changes: 7 additions & 3 deletions pkg/resource/tables/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ func getMapTableKey(schema, name string) string {
return fmt.Sprintf("%s.%s", schema, name)
}

func BuildGenerateModelInputs(tables []objects.Table, policies objects.Policies) []*generator.GenerateModelInput {
func BuildGenerateModelInputs(tables []objects.Table, policies objects.Policies, mapModelValidationTags map[string]state.ModelValidationTag) []*generator.GenerateModelInput {
mapTable := tableToMap(tables)
mapRelations := buildGenerateMapRelations(mapTable)
return buildGenerateModelInput(mapTable, mapRelations, policies)
return buildGenerateModelInput(mapTable, mapRelations, policies, mapModelValidationTags)
}

// ---- build table relation for generated -----
Expand Down Expand Up @@ -166,7 +166,7 @@ func mergeGenerateManyToManyCandidate(candidates []*ManyToManyTable, mapRelation
}

// --- attach relation to table
func buildGenerateModelInput(mapTable MapTable, mapRelations MapRelations, policies objects.Policies) []*generator.GenerateModelInput {
func buildGenerateModelInput(mapTable MapTable, mapRelations MapRelations, policies objects.Policies, mapModelValidationTags map[string]state.ModelValidationTag) []*generator.GenerateModelInput {
generateInputs := make([]*generator.GenerateModelInput, 0)
for k, v := range mapTable {
input := generator.GenerateModelInput{
Expand All @@ -182,6 +182,10 @@ func buildGenerateModelInput(mapTable MapTable, mapRelations MapRelations, polic
}
}

vTag, exist := mapModelValidationTags[input.Table.Name]
if exist && vTag != nil {
input.ValidationTags = vTag
}
generateInputs = append(generateInputs, &input)
}
return generateInputs
Expand Down
2 changes: 1 addition & 1 deletion pkg/resource/tables/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func TestBuildGenerateModelInputs(t *testing.T) {
err := json.Unmarshal([]byte(jsonStrData), &sourceTables)
assert.NoError(t, err)

rs := tables.BuildGenerateModelInputs(sourceTables, nil)
rs := tables.BuildGenerateModelInputs(sourceTables, nil, nil)

for _, r := range rs {
assert.Equal(t, 2, len(r.Relations))
Expand Down
4 changes: 2 additions & 2 deletions pkg/resource/tables/print_diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ func GenerateDiffMessage(diffData CompareDiffResult, sRelation MapRelations, tRe
diffColumns := []string{}

// mas source column
mapSColumns, _ := generator.MapTableAttributes(diffData.SourceResource)
mapTColumns, _ := generator.MapTableAttributes(diffData.TargetResource)
mapSColumns, _ := generator.MapTableAttributes(diffData.SourceResource, nil)
mapTColumns, _ := generator.MapTableAttributes(diffData.TargetResource, nil)

// find source column
for ic := range diffData.DiffItems.ChangeColumnItems {
Expand Down
13 changes: 13 additions & 0 deletions pkg/state/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ import (
"github.com/sev-2/raiden/pkg/utils"
)

type ModelValidationTag map[string]string

type ExtractTableItem struct {
Table objects.Table
ValidationTags ModelValidationTag
ExtractedPolicies ExtractedPolicies
}

Expand Down Expand Up @@ -75,6 +78,7 @@ func buildTableFromModel(model any) (ei ExtractTableItem) {
}

ei.Table.Name = raiden.GetTableName(model)
ei.ValidationTags = make(ModelValidationTag)

// add metadata
metadataField, isExist := modelType.FieldByName("Metadata")
Expand Down Expand Up @@ -114,6 +118,10 @@ func buildTableFromModel(model any) (ei ExtractTableItem) {
// set is unique to false if is primary key
c.IsUnique = false
}

if vTag := field.Tag.Get("validate"); len(vTag) > 0 {
ei.ValidationTags[c.Name] = vTag
}
}

if join := field.Tag.Get("join"); len(join) > 0 {
Expand Down Expand Up @@ -151,6 +159,7 @@ func buildTableFromState(model any, state TableState) (ei ExtractTableItem) {
// Get the reflect.Type of the struct
ei.Table = state.Table
ei.Table.Name = raiden.GetTableName(model)
ei.ValidationTags = make(ModelValidationTag)

// map column for make check if column exist and reuse default
mapColumn := make(map[string]objects.Column)
Expand Down Expand Up @@ -231,6 +240,10 @@ func buildTableFromState(model any, state TableState) (ei ExtractTableItem) {
}
}

if vTag := field.Tag.Get("validate"); len(vTag) > 0 {
ei.ValidationTags[c.Name] = vTag
}

columns = append(columns, c)
}

Expand Down
1 change: 1 addition & 0 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ func (r *router) registerRestHandler(route *Route) {

restController := RestController{
Controller: route.Controller,
Model: route.Model,
TableName: GetTableName(route.Model),
}

Expand Down

0 comments on commit f381c7c

Please sign in to comment.