diff --git a/.gitattributes b/.gitattributes index 146bde3a..9f24fbbf 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,2 @@ -*.gotmpl linguist-language=Go +*.go linguist-language=Go linguist-generated=false +*.gotmpl linguist-language=Go linguist-generated=false diff --git a/Makefile b/Makefile index d5fc96cb..0fcc92f8 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,7 @@ test-postgresql: install cd internal/integration && QUERYX_ENV=test queryx db:create --schema postgresql.hcl cd internal/integration && QUERYX_ENV=test queryx db:migrate --schema postgresql.hcl cd internal/integration && QUERYX_ENV=test queryx generate --schema postgresql.hcl - cd internal/integration && go test -v ./... + cd internal/integration && go test ./... # cd internal/integration && QUERYX_ENV=test queryx db:drop --schema postgresql.hcl test-mysql: install mysql-drop diff --git a/README.md b/README.md index 27ad988a..e73f7ee0 100644 --- a/README.md +++ b/README.md @@ -172,7 +172,21 @@ post, err := c.QueryPost().Where(c.PostTitle.EQ("post title")).First() Queryx supports association definition in the schema file. It also generates corresponding preload query methods to avoid "N+1" query. -## has_one and belongs_to +## belongs_to + +```hcl +model "Post" { + belongs_to "Author" { + model_name = "User" + } +} +``` + +```go +c.QueryPost().PreloadAuthor().All() +``` + +## has_one ```hcl model "User" { @@ -199,7 +213,7 @@ c.QueryUser().PreloadAccount().All() c.QueryAccount().PreloadUser().All() ``` -## has_many and belongs_to +## has_many ```hcl model "User" { diff --git a/generator/client/golang/templates/[model].gotmpl b/generator/client/golang/templates/[model].gotmpl index f53d3851..f00023d1 100644 --- a/generator/client/golang/templates/[model].gotmpl +++ b/generator/client/golang/templates/[model].gotmpl @@ -16,10 +16,10 @@ type {{ $.model.Name }} struct { {{ $b.Name | pascal }} *{{ $b.ModelName }} `json:"{{ camel $b.Name }}"` {{- end }} {{- range $h := $.model.HasMany }} - {{ $h.Name | pascal }} []*{{ $h.ModelName }} + {{ $h.Name | pascal }} []*{{ $h.ModelName }} `json:"{{ $h.Name | camel }}"` {{- end }} {{- range $h := $.model.HasOne }} - {{ $h.Name | pascal }} *{{ $h.ModelName }} + {{ $h.Name | pascal }} *{{ $h.ModelName }} `json:"{{ $h.Name | camel }}"` {{- end }} schema *queryx.Schema diff --git a/generator/client/golang/templates/[model]_query.gotmpl b/generator/client/golang/templates/[model]_query.gotmpl index accc2dff..36e0b8dd 100644 --- a/generator/client/golang/templates/[model]_query.gotmpl +++ b/generator/client/golang/templates/[model]_query.gotmpl @@ -193,8 +193,8 @@ func (q *{{ $.model.Name }}Query) Preload{{ pascal $b.Name }}() *{{ $.model.Name func (q *{{ $.model.Name }}Query) preload{{ pascal $b.Name }}(rows []*{{ $.model.Name }}) error { ids := []int64{} for _, r := range rows { - if !r.{{ $b.ModelName }}ID.Null { - ids = append(ids, r.{{ $b.ModelName }}ID.Val) + if !r.{{ $b.ForeignKey | pascal }}.Null { + ids = append(ids, r.{{ $b.ForeignKey | pascal }}.Val) } } rows1, err := q.queries.Query{{ $b.ModelName }}().Where(q.schema.{{ $b.ModelName }}ID.In(ids)).All() @@ -207,8 +207,8 @@ func (q *{{ $.model.Name }}Query) preload{{ pascal $b.Name }}(rows []*{{ $.model m[r.ID] = r } for _, r := range rows { - if !r.{{ $b.ModelName }}ID.Null { - r.{{ $b.ModelName }} = m[r.{{ $b.ModelName }}ID.Val] + if !r.{{ $b.ForeignKey | pascal }}.Null { + r.{{ $b.Name | pascal }} = m[r.{{ $b.ForeignKey | pascal }}.Val] } } @@ -240,7 +240,11 @@ func (q *{{ $.model.Name }}Query) preload{{ pascal $h.Name }}(rows []*{{ $.model m1[r.{{ $.model.Name }}ID.Val] = append(m1[r.{{ $.model.Name }}ID.Val], r) } for _, r := range rows { - r.{{ $h.Through | pascal }} = m1[r.ID] + if m1[r.ID] != nil { + r.{{ $h.Through | pascal }} = m1[r.ID] + } else { + r.{{ $h.Through | pascal }} = make([]*{{ $m }}, 0) + } } ids1 := []int64{} @@ -264,7 +268,11 @@ func (q *{{ $.model.Name }}Query) preload{{ pascal $h.Name }}(rows []*{{ $.model m3[r.{{ $.model.Name }}ID.Val] = append(m3[r.{{ $.model.Name }}ID.Val], r.{{ $h.ModelName }}) } for _, r := range rows { - r.{{ $h.Name | pascal }} = m3[r.ID] + if m3[r.ID] != nil { + r.{{ $h.Name | pascal }} = m3[r.ID] + } else { + r.{{ $h.Name | pascal }} = make([]*{{ $h.ModelName }},0) + } } {{- else }} rows1, err := q.queries.Query{{ $h.ModelName }}().Where(q.schema.{{ $h.ModelName }}{{ $.model.Name }}ID.In(ids)).All() @@ -277,7 +285,11 @@ func (q *{{ $.model.Name }}Query) preload{{ pascal $h.Name }}(rows []*{{ $.model m[r.{{ $.model.Name }}ID.Val] = append(m[r.{{ $.model.Name }}ID.Val], r) } for _, r := range rows { - r.{{ $h.Name | pascal }} = m[r.ID] + if m[r.ID] != nil { + r.{{ $h.Name | pascal }} = m[r.ID] + } else { + r.{{ $h.Name | pascal }} = make([]*{{ $h.ModelName }}, 0) + } } {{- end }} @@ -320,7 +332,7 @@ func (q *{{.model.Name}}Query) All() ([]*{{.model.Name}}, error) { } var rows []{{ $.model.Name }} {{- $var1 := $.model.Name | camel | plural }} - var {{ $var1 }} []*{{ $.model.Name }} + {{ $var1 }} := make([]*{{ $.model.Name }}, 0) query, args := q.selectStatement.ToSQL() err := q.adapter.Query(query, args...).Scan(&rows) if err != nil { @@ -328,7 +340,7 @@ func (q *{{.model.Name}}Query) All() ([]*{{.model.Name}}, error) { } if len(rows) == 0 { - return nil, err + return {{ $var1 }}, nil } for i := range rows { diff --git a/generator/client/golang/templates/queryx/bigint.go b/generator/client/golang/templates/queryx/bigint.go index bf131af4..622929dc 100644 --- a/generator/client/golang/templates/queryx/bigint.go +++ b/generator/client/golang/templates/queryx/bigint.go @@ -5,6 +5,7 @@ package queryx import ( "database/sql" "database/sql/driver" + "encoding/json" ) type BigInt struct { @@ -43,10 +44,22 @@ func (b BigInt) Value() (driver.Value, error) { return b.Val, nil } +// MarshalJSON implements the json.Marshaler interface. func (b BigInt) MarshalJSON() ([]byte, error) { - return nil, nil + if b.Null { + return json.Marshal(nil) + } + return json.Marshal(b.Val) } -func (b *BigInt) UnmarshalJSON(text []byte) error { +// UnmarshalJSON implements the json.Unmarshaler interface. +func (b *BigInt) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + b.Null = true + return nil + } + if err := json.Unmarshal(data, &b.Val); err != nil { + return err + } return nil } diff --git a/generator/client/golang/templates/queryx/bigint_column.go b/generator/client/golang/templates/queryx/bigint_column.go index 9e57a9df..3d9e77c0 100644 --- a/generator/client/golang/templates/queryx/bigint_column.go +++ b/generator/client/golang/templates/queryx/bigint_column.go @@ -59,6 +59,11 @@ func (c *BigIntColumn) GE(v int64) *Clause { } func (c *BigIntColumn) In(v []int64) *Clause { + if len(v) == 0 { + return &Clause{ + fragment: fmt.Sprintf("1=0"), + } + } return &Clause{ fragment: fmt.Sprintf("%s.%s IN (?)", c.Table.Name, c.Name), args: []interface{}{v}, diff --git a/generator/client/golang/templates/queryx/bigint_test.go b/generator/client/golang/templates/queryx/bigint_test.go new file mode 100644 index 00000000..52bdbea6 --- /dev/null +++ b/generator/client/golang/templates/queryx/bigint_test.go @@ -0,0 +1,40 @@ +// Code generated by queryx, DO NOT EDIT. + +package queryx + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewBigInt(t *testing.T) { + b1 := NewBigInt(2) + require.Equal(t, int64(2), b1.Val) + require.Equal(t, false, b1.Null) + + b2 := NewNullableBigInt(nil) + require.Equal(t, true, b2.Null) +} + +func TestBigIntJSON(t *testing.T) { + type Foo struct { + X BigInt `json:"x"` + Y BigInt `json:"y"` + } + x := NewBigInt(2) + y := NewNullableBigInt(nil) + s := `{"x":2,"y":null}` + + f1 := Foo{X: x, Y: y} + b, err := json.Marshal(f1) + require.NoError(t, err) + require.Equal(t, s, string(b)) + + var f2 Foo + err = json.Unmarshal([]byte(s), &f2) + require.NoError(t, err) + require.Equal(t, x, f2.X) + require.Equal(t, y, f2.Y) +} diff --git a/generator/client/golang/templates/queryx/boolean.go b/generator/client/golang/templates/queryx/boolean.go index 3d6aae73..66e975cf 100644 --- a/generator/client/golang/templates/queryx/boolean.go +++ b/generator/client/golang/templates/queryx/boolean.go @@ -5,6 +5,7 @@ package queryx import ( "database/sql" "database/sql/driver" + "encoding/json" ) type Boolean struct { @@ -42,3 +43,23 @@ func (b Boolean) Value() (driver.Value, error) { } return b.Val, nil } + +// MarshalJSON implements the json.Marshaler interface. +func (b Boolean) MarshalJSON() ([]byte, error) { + if b.Null { + return json.Marshal(nil) + } + return json.Marshal(b.Val) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (b *Boolean) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + b.Null = true + return nil + } + if err := json.Unmarshal(data, &b.Val); err != nil { + return err + } + return nil +} diff --git a/generator/client/golang/templates/queryx/boolean_column.go b/generator/client/golang/templates/queryx/boolean_column.go index 0db05e90..72764e3e 100644 --- a/generator/client/golang/templates/queryx/boolean_column.go +++ b/generator/client/golang/templates/queryx/boolean_column.go @@ -29,3 +29,11 @@ func (c *BooleanColumn) NE(v bool) *Clause { args: []interface{}{v}, } } + +func (b *BooleanColumn) Asc() string { + return fmt.Sprintf("%s.%s ASC", b.Table.Name, b.Name) +} + +func (b *BooleanColumn) Desc() string { + return fmt.Sprintf("%s.%s DESC", b.Table.Name, b.Name) +} diff --git a/generator/client/golang/templates/queryx/boolean_test.go b/generator/client/golang/templates/queryx/boolean_test.go new file mode 100644 index 00000000..cb89d02e --- /dev/null +++ b/generator/client/golang/templates/queryx/boolean_test.go @@ -0,0 +1,40 @@ +// Code generated by queryx, DO NOT EDIT. + +package queryx + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewBoolean(t *testing.T) { + b1 := NewBoolean(true) + require.Equal(t, true, b1.Val) + require.Equal(t, false, b1.Null) + + b2 := NewNullableBoolean(nil) + require.Equal(t, true, b2.Null) +} + +func TestBooleanJSON(t *testing.T) { + type Foo struct { + X Boolean `json:"x"` + Y Boolean `json:"y"` + } + x := NewBoolean(true) + y := NewNullableBoolean(nil) + s := `{"x":true,"y":null}` + + f1 := Foo{X: x, Y: y} + b, err := json.Marshal(f1) + require.NoError(t, err) + require.Equal(t, s, string(b)) + + var f2 Foo + err = json.Unmarshal([]byte(s), &f2) + require.NoError(t, err) + require.Equal(t, x, f2.X) + require.Equal(t, y, f2.Y) +} diff --git a/generator/client/golang/templates/queryx/date.gotmpl b/generator/client/golang/templates/queryx/date.gotmpl index a7e3870a..ec96fcc1 100644 --- a/generator/client/golang/templates/queryx/date.gotmpl +++ b/generator/client/golang/templates/queryx/date.gotmpl @@ -62,13 +62,28 @@ func (d Date) Value() (driver.Value, error) { return d.Val, nil } +// MarshalJSON implements the json.Marshaler interface. func (d Date) MarshalJSON() ([]byte, error) { if d.Null { return json.Marshal(nil) } - return json.Marshal(d.Val) + return json.Marshal(d.Val.Format("2006-01-02")) } -func (d *Date) UnmarshalJSON(text []byte) error { +// UnmarshalJSON implements the json.Unmarshaler interface. +func (d *Date) UnmarshalJSON(data []byte) error { + s := string(data) + if s == "null" || s == "" { + d.Null = true + return nil + } + + s = s[len(`"`) : len(s)-len(`"`)] + t, err := parseDate(s) + if err != nil { + return err + } + + d.Val = *t return nil } diff --git a/generator/client/golang/templates/queryx/date_column.gotmpl b/generator/client/golang/templates/queryx/date_column.gotmpl index d6ce204f..0b273594 100644 --- a/generator/client/golang/templates/queryx/date_column.gotmpl +++ b/generator/client/golang/templates/queryx/date_column.gotmpl @@ -88,3 +88,11 @@ func (c *DateColumn) GT(v string) *Clause { err: err, } } + +func (c *DateColumn) Asc() string { + return fmt.Sprintf("%s.%s ASC", c.Table.Name, c.Name) +} + +func (c *DateColumn) Desc() string { + return fmt.Sprintf("%s.%s DESC", c.Table.Name, c.Name) +} diff --git a/generator/client/golang/templates/queryx/date_test.gotmpl b/generator/client/golang/templates/queryx/date_test.gotmpl new file mode 100644 index 00000000..52c3c739 --- /dev/null +++ b/generator/client/golang/templates/queryx/date_test.gotmpl @@ -0,0 +1,40 @@ +// Code generated by queryx, DO NOT EDIT. + +package queryx + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewDate(t *testing.T) { + d1 := NewDate("2012-11-10") + require.Equal(t, "2012-11-10", d1.Val.Format("2006-01-02")) + require.Equal(t, false, d1.Null) + + d2 := NewNullableDate(nil) + require.Equal(t, true, d2.Null) +} + +func TestDateJSON(t *testing.T) { + type Foo struct { + X Date `json:"x"` + Y Date `json:"y"` + } + x := NewDate("2012-11-10") + y := NewNullableDate(nil) + s := `{"x":"2012-11-10","y":null}` + + f1 := Foo{X: x, Y: y} + b, err := json.Marshal(f1) + require.NoError(t, err) + require.Equal(t, s, string(b)) + + var f2 Foo + err = json.Unmarshal([]byte(s), &f2) + require.NoError(t, err) + require.Equal(t, x, f2.X) + require.Equal(t, y, f2.Y) +} diff --git a/generator/client/golang/templates/queryx/datetime.gotmpl b/generator/client/golang/templates/queryx/datetime.gotmpl index aa1e70ea..b084567a 100644 --- a/generator/client/golang/templates/queryx/datetime.gotmpl +++ b/generator/client/golang/templates/queryx/datetime.gotmpl @@ -70,26 +70,33 @@ func (d Datetime) Value() (driver.Value, error) { return d.Val.UTC(), nil } +// MarshalJSON implements the json.Marshaler interface. func (d Datetime) MarshalJSON() ([]byte, error) { if d.Null { return json.Marshal(nil) } - return json.Marshal(d.Val) + return json.Marshal(d.Val.UTC()) } -func (d *Datetime) UnmarshalJSON(b []byte) error { - s := string(b) +// UnmarshalJSON implements the json.Unmarshaler interface. +func (d *Datetime) UnmarshalJSON(data []byte) error { + s := string(data) if s == "null" || s == "" { d.Null = true return nil } - t:= time.Time{} - err := t.UnmarshalJSON(b) + t := time.Time{} + err := t.UnmarshalJSON(data) if err != nil { return err } - d.Val = t + + location, err := loadLocation() + if err != nil { + return err + } + d.Val = t.In(location) return nil } diff --git a/generator/client/golang/templates/queryx/datetime_column.gotmpl b/generator/client/golang/templates/queryx/datetime_column.gotmpl index 56cd5501..ffd108cd 100644 --- a/generator/client/golang/templates/queryx/datetime_column.gotmpl +++ b/generator/client/golang/templates/queryx/datetime_column.gotmpl @@ -88,3 +88,11 @@ func (c *DatetimeColumn) GT(v string) *Clause { err: err, } } + +func (c *DatetimeColumn) Asc() string { + return fmt.Sprintf("%s.%s ASC", c.Table.Name, c.Name) +} + +func (c *DatetimeColumn) Desc() string { + return fmt.Sprintf("%s.%s DESC", c.Table.Name, c.Name) +} diff --git a/generator/client/golang/templates/queryx/datetime_test.gotmpl b/generator/client/golang/templates/queryx/datetime_test.gotmpl new file mode 100644 index 00000000..30cb5e9f --- /dev/null +++ b/generator/client/golang/templates/queryx/datetime_test.gotmpl @@ -0,0 +1,40 @@ +// Code generated by queryx, DO NOT EDIT. + +package queryx + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewDatetime(t *testing.T) { + d1 := NewDatetime("2012-12-12 11:10:09") + require.Equal(t, "2012-12-12 11:10:09", d1.Val.Format("2006-01-02 15:04:05")) + require.Equal(t, false, d1.Null) + + d2 := NewNullableDatetime(nil) + require.Equal(t, true, d2.Null) +} + +func TestDatetimeJSON(t *testing.T) { + type Foo struct { + X Datetime `json:"x"` + Y Datetime `json:"y"` + } + x := NewDatetime("2012-12-12 11:10:09") + y := NewNullableDatetime(nil) + s := `{"x":"2012-12-12T03:10:09Z","y":null}` + + f1 := Foo{X: x, Y: y} + b, err := json.Marshal(f1) + require.NoError(t, err) + require.Equal(t, s, string(b)) + + var f2 Foo + err = json.Unmarshal([]byte(s), &f2) + require.NoError(t, err) + require.Equal(t, x, f2.X) + require.Equal(t, y, f2.Y) +} diff --git a/generator/client/golang/templates/queryx/float.go b/generator/client/golang/templates/queryx/float.go index bf98a70a..acfa81cb 100644 --- a/generator/client/golang/templates/queryx/float.go +++ b/generator/client/golang/templates/queryx/float.go @@ -3,7 +3,9 @@ package queryx import ( + "database/sql" "database/sql/driver" + "encoding/json" ) type Float struct { @@ -25,7 +27,10 @@ func NewNullableFloat(v *float64) Float { // Scan implements the Scanner interface. func (f *Float) Scan(value interface{}) error { - return nil + ns := sql.NullFloat64{Float64: f.Val} + err := ns.Scan(value) + f.Val, f.Null = ns.Float64, !ns.Valid + return err } // Value implements the driver Valuer interface. @@ -35,3 +40,23 @@ func (f Float) Value() (driver.Value, error) { } return float64(f.Val), nil } + +// MarshalJSON implements the json.Marshaler interface. +func (f Float) MarshalJSON() ([]byte, error) { + if f.Null { + return json.Marshal(nil) + } + return json.Marshal(f.Val) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (f *Float) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + f.Null = true + return nil + } + if err := json.Unmarshal(data, &f.Val); err != nil { + return err + } + return nil +} diff --git a/generator/client/golang/templates/queryx/float_column.go b/generator/client/golang/templates/queryx/float_column.go index 7c10f778..d1abf7ff 100644 --- a/generator/client/golang/templates/queryx/float_column.go +++ b/generator/client/golang/templates/queryx/float_column.go @@ -57,3 +57,11 @@ func (c *FloatColumn) In(v []float64) *Clause { args: []interface{}{v}, } } + +func (c *FloatColumn) Asc() string { + return fmt.Sprintf("%s.%s ASC", c.Table.Name, c.Name) +} + +func (c *FloatColumn) Desc() string { + return fmt.Sprintf("%s.%s DESC", c.Table.Name, c.Name) +} diff --git a/generator/client/golang/templates/queryx/float_test.go b/generator/client/golang/templates/queryx/float_test.go new file mode 100644 index 00000000..9880ef66 --- /dev/null +++ b/generator/client/golang/templates/queryx/float_test.go @@ -0,0 +1,40 @@ +// Code generated by queryx, DO NOT EDIT. + +package queryx + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewFloat(t *testing.T) { + f1 := NewFloat(2.1) + require.Equal(t, 2.1, f1.Val) + require.Equal(t, false, f1.Null) + + f2 := NewNullableFloat(nil) + require.Equal(t, true, f2.Null) +} + +func TestFloatJSON(t *testing.T) { + type Foo struct { + X Float `json:"x"` + Y Float `json:"y"` + } + x := NewFloat(2.1) + y := NewNullableFloat(nil) + s := `{"x":2.1,"y":null}` + + f1 := Foo{X: x, Y: y} + b, err := json.Marshal(f1) + require.NoError(t, err) + require.Equal(t, s, string(b)) + + var f2 Foo + err = json.Unmarshal([]byte(s), &f2) + require.NoError(t, err) + require.Equal(t, x, f2.X) + require.Equal(t, y, f2.Y) +} diff --git a/generator/client/golang/templates/queryx/integer.go b/generator/client/golang/templates/queryx/integer.go index 7221e18f..a1cf76dc 100644 --- a/generator/client/golang/templates/queryx/integer.go +++ b/generator/client/golang/templates/queryx/integer.go @@ -45,15 +45,17 @@ func (i Integer) Value() (driver.Value, error) { return int64(i.Val), nil } +// MarshalJSON implements the json.Marshaler interface. func (i Integer) MarshalJSON() ([]byte, error) { if i.Null { return json.Marshal(nil) } - return nil, nil + return json.Marshal(i.Val) } -func (i *Integer) UnmarshalJSON(b []byte) error { - s := string(b) +// UnmarshalJSON implements the json.Unmarshaler interface. +func (i *Integer) UnmarshalJSON(data []byte) error { + s := string(data) if s == "null" { i.Null = true return nil diff --git a/generator/client/golang/templates/queryx/integer_test.go b/generator/client/golang/templates/queryx/integer_test.go index 55c32f1a..cc12c515 100644 --- a/generator/client/golang/templates/queryx/integer_test.go +++ b/generator/client/golang/templates/queryx/integer_test.go @@ -3,18 +3,38 @@ package queryx import ( + "encoding/json" "testing" "github.com/stretchr/testify/require" ) func TestNewInteger(t *testing.T) { - i := NewInteger(2) - require.Equal(t, int32(2), i.Val) - require.Equal(t, false, i.Null) + i1 := NewInteger(2) + require.Equal(t, int32(2), i1.Val) + require.Equal(t, false, i1.Null) + + i2 := NewNullableInteger(nil) + require.Equal(t, true, i2.Null) } -func TestNewNullableInteger(t *testing.T) { - i := NewNullableInteger(nil) - require.Equal(t, true, i.Null) +func TestIntegerJSON(t *testing.T) { + type Foo struct { + X Integer `json:"x"` + Y Integer `json:"y"` + } + x := NewInteger(2) + y := NewNullableInteger(nil) + s := `{"x":2,"y":null}` + + f1 := Foo{X: x, Y: y} + b, err := json.Marshal(f1) + require.NoError(t, err) + require.Equal(t, s, string(b)) + + var f2 Foo + err = json.Unmarshal([]byte(s), &f2) + require.NoError(t, err) + require.Equal(t, x, f2.X) + require.Equal(t, y, f2.Y) } diff --git a/generator/client/golang/templates/queryx/json.go b/generator/client/golang/templates/queryx/json.go index 64edcfb3..a49b6897 100644 --- a/generator/client/golang/templates/queryx/json.go +++ b/generator/client/golang/templates/queryx/json.go @@ -43,10 +43,26 @@ func (j JSON) Value() (driver.Value, error) { return json.Marshal(j.Val) } +// MarshalJSON implements the json.Marshaler interface. func (j JSON) MarshalJSON() ([]byte, error) { - return nil, nil + if j.Null { + return json.Marshal(nil) + } + return json.Marshal(j.Val) } -func (j *JSON) UnmarshalJSON(b []byte) error { +// UnmarshalJSON implements the json.Unmarshaler interface. +func (j *JSON) UnmarshalJSON(data []byte) error { + s := string(data) + if s == "{}" || s == "null" { + j.Null = true + return nil + } + m := map[string]interface{}{} + err := json.Unmarshal(data, &m) + if err != nil { + return err + } + j.Val = m return nil } diff --git a/generator/client/golang/templates/queryx/json_test.go b/generator/client/golang/templates/queryx/json_test.go index bdd28081..846067ad 100644 --- a/generator/client/golang/templates/queryx/json_test.go +++ b/generator/client/golang/templates/queryx/json_test.go @@ -1,6 +1,7 @@ package queryx import ( + "encoding/json" "testing" "github.com/stretchr/testify/require" @@ -8,16 +9,34 @@ import ( func TestNewJSON(t *testing.T) { m := map[string]interface{}{"a": 1} - j := NewJSON(m) - require.Equal(t, m, j.Val) - require.False(t, j.Null) + j1 := NewJSON(m) + require.Equal(t, m, j1.Val) + require.False(t, j1.Null) + + j2 := NewNullableJSON(nil) + require.True(t, j2.Null) + + j3 := NewNullableJSON(m) + require.False(t, j3.Null) } -func TestNewNullableJSON(t *testing.T) { - j1 := NewNullableJSON(nil) - require.True(t, j1.Null) +func TestJSONJSON(t *testing.T) { + type Foo struct { + X JSON `json:"x"` + Y JSON `json:"y"` + } + x := NewJSON(map[string]interface{}{"a": "b"}) + y := NewNullableJSON(nil) + s := `{"x":{"a":"b"},"y":null}` - m := map[string]interface{}{"a": 1} - j2 := NewNullableJSON(m) - require.False(t, j2.Null) + f1 := Foo{X: x, Y: y} + b, err := json.Marshal(f1) + require.NoError(t, err) + require.Equal(t, s, string(b)) + + var f2 Foo + err = json.Unmarshal([]byte(s), &f2) + require.NoError(t, err) + require.Equal(t, x, f2.X) + require.Equal(t, y, f2.Y) } diff --git a/generator/client/golang/templates/queryx/string.go b/generator/client/golang/templates/queryx/string.go index e6c1f46e..eeceeddc 100644 --- a/generator/client/golang/templates/queryx/string.go +++ b/generator/client/golang/templates/queryx/string.go @@ -25,6 +25,7 @@ func NewNullableString(v *string) String { return String{Null: true} } +// Scan implements the Scanner interface. func (s *String) Scan(value interface{}) error { ns := sql.NullString{String: s.Val} err := ns.Scan(value) @@ -32,6 +33,7 @@ func (s *String) Scan(value interface{}) error { return err } +// Value implements the driver Valuer interface. func (s String) Value() (driver.Value, error) { if s.Null { return nil, nil @@ -39,6 +41,7 @@ func (s String) Value() (driver.Value, error) { return s.Val, nil } +// MarshalJSON implements the json.Marshaler interface. func (s String) MarshalJSON() ([]byte, error) { if s.Null { return json.Marshal(nil) @@ -46,32 +49,14 @@ func (s String) MarshalJSON() ([]byte, error) { return json.Marshal(s.Val) } -func (s *String) UnmarshalJSON(text []byte) error { - // ns.Valid = false - // if string(text) == "null" { - // return nil - // } - // if err := json.Unmarshal(text, &ns.String); err == nil { - // ns.Valid = true - // } - return nil -} - -func (s *String) UnmarshalText(text []byte) error { - // ns.Valid = false - // t := string(text) - // if t == "null" { - // return nil - // } - // ns.String = t - // ns.Valid = true - return nil -} - -// String implements the fmt.Stringer. -func (s *String) String() string { - if s.Null { - return "null" +// UnmarshalJSON implements the json.Unmarshaler interface. +func (s *String) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + s.Null = true + return nil } - return s.Val + if err := json.Unmarshal(data, &s.Val); err != nil { + return err + } + return nil } diff --git a/generator/client/golang/templates/queryx/string_test.go b/generator/client/golang/templates/queryx/string_test.go new file mode 100644 index 00000000..6c56148b --- /dev/null +++ b/generator/client/golang/templates/queryx/string_test.go @@ -0,0 +1,40 @@ +// Code generated by queryx, DO NOT EDIT. + +package queryx + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewString(t *testing.T) { + s1 := NewString("ss") + require.Equal(t, "ss", s1.Val) + require.Equal(t, false, s1.Null) + + s2 := NewNullableString(nil) + require.Equal(t, true, s2.Null) +} + +func TestStringJSON(t *testing.T) { + type Foo struct { + X String `json:"x"` + Y String `json:"y"` + } + x := NewString("ss") + y := NewNullableString(nil) + s := `{"x":"ss","y":null}` + + f1 := Foo{X: x, Y: y} + b, err := json.Marshal(f1) + require.NoError(t, err) + require.Equal(t, s, string(b)) + + var f2 Foo + err = json.Unmarshal([]byte(s), &f2) + require.NoError(t, err) + require.Equal(t, x, f2.X) + require.Equal(t, y, f2.Y) +} diff --git a/generator/client/golang/templates/queryx/time.gotmpl b/generator/client/golang/templates/queryx/time.gotmpl index 59fb7192..0f23991a 100644 --- a/generator/client/golang/templates/queryx/time.gotmpl +++ b/generator/client/golang/templates/queryx/time.gotmpl @@ -62,13 +62,28 @@ func (t Time) Value() (driver.Value, error) { return t.Val, nil } +// MarshalJSON implements the json.Marshaler interface. func (t Time) MarshalJSON() ([]byte, error) { if t.Null { return json.Marshal(nil) } - return json.Marshal(t.Val) + return json.Marshal(t.Val.Format("15:04:05")) } -func (t *Time) UnmarshalJSON(text []byte) error { +// UnmarshalJSON implements the json.Unmarshaler interface. +func (t *Time) UnmarshalJSON(data []byte) error { + s := string(data) + if s == "null" || s == "" { + t.Null = true + return nil + } + + s = s[len(`"`) : len(s)-len(`"`)] + tt, err := parseTime(s) + if err != nil { + return err + } + + t.Val = *tt return nil } diff --git a/generator/client/golang/templates/queryx/time_column.gotmpl b/generator/client/golang/templates/queryx/time_column.gotmpl index 742fac13..6c9b8452 100644 --- a/generator/client/golang/templates/queryx/time_column.gotmpl +++ b/generator/client/golang/templates/queryx/time_column.gotmpl @@ -89,3 +89,11 @@ func (c *TimeColumn) EQ(v string) *Clause { err: err, } } + +func (c *TimeColumn) Asc() string { + return fmt.Sprintf("%s.%s ASC", c.Table.Name, c.Name) +} + +func (c *TimeColumn) Desc() string { + return fmt.Sprintf("%s.%s DESC", c.Table.Name, c.Name) +} diff --git a/generator/client/golang/templates/queryx/time_test.gotmpl b/generator/client/golang/templates/queryx/time_test.gotmpl new file mode 100644 index 00000000..333b2fa3 --- /dev/null +++ b/generator/client/golang/templates/queryx/time_test.gotmpl @@ -0,0 +1,40 @@ +// Code generated by queryx, DO NOT EDIT. + +package queryx + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewTime(t *testing.T) { + t1 := NewTime("12:11:10") + require.Equal(t, "12:11:10", t1.Val.Format("15:04:05")) + require.Equal(t, false, t1.Null) + + t2 := NewNullableTime(nil) + require.Equal(t, true, t2.Null) +} + +func TestTimeJSON(t *testing.T) { + type Foo struct { + X Time `json:"x"` + Y Time `json:"y"` + } + x := NewTime("12:11:10") + y := NewNullableTime(nil) + s := `{"x":"12:11:10","y":null}` + + f1 := Foo{X: x, Y: y} + b, err := json.Marshal(f1) + require.NoError(t, err) + require.Equal(t, s, string(b)) + + var f2 Foo + err = json.Unmarshal([]byte(s), &f2) + require.NoError(t, err) + require.Equal(t, x, f2.X) + require.Equal(t, y, f2.Y) +} diff --git a/generator/client/golang/templates/queryx/uuid.go b/generator/client/golang/templates/queryx/uuid.go index c7982ff1..6fb84e25 100644 --- a/generator/client/golang/templates/queryx/uuid.go +++ b/generator/client/golang/templates/queryx/uuid.go @@ -41,6 +41,7 @@ func (u UUID) Value() (driver.Value, error) { return u.Val, nil } +// MarshalJSON implements the json.Marshaler interface. func (u UUID) MarshalJSON() ([]byte, error) { if u.Null { return json.Marshal(nil) @@ -48,6 +49,14 @@ func (u UUID) MarshalJSON() ([]byte, error) { return json.Marshal(u.Val) } -func (u *UUID) UnmarshalJSON(text []byte) error { +// UnmarshalJSON implements the json.Unmarshaler interface. +func (u *UUID) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + u.Null = true + return nil + } + if err := json.Unmarshal(data, &u.Val); err != nil { + return err + } return nil } diff --git a/generator/client/golang/templates/queryx/uuid_column.go b/generator/client/golang/templates/queryx/uuid_column.go index deaa87cf..eceaa2d5 100644 --- a/generator/client/golang/templates/queryx/uuid_column.go +++ b/generator/client/golang/templates/queryx/uuid_column.go @@ -28,3 +28,11 @@ func (c *UUIDColumn) EQ(v string) *Clause { args: []interface{}{v}, } } + +func (c *UUIDColumn) Asc() string { + return fmt.Sprintf("%s.%s ASC", c.Table.Name, c.Name) +} + +func (c *UUIDColumn) Desc() string { + return fmt.Sprintf("%s.%s DESC", c.Table.Name, c.Name) +} diff --git a/generator/client/golang/templates/queryx/uuid_test.go b/generator/client/golang/templates/queryx/uuid_test.go new file mode 100644 index 00000000..2ec120e7 --- /dev/null +++ b/generator/client/golang/templates/queryx/uuid_test.go @@ -0,0 +1,40 @@ +// Code generated by queryx, DO NOT EDIT. + +package queryx + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewUUID(t *testing.T) { + u1 := NewUUID("a81e44c5-7e18-4dfe-b9b3-d9280629d2ef") + require.Equal(t, "a81e44c5-7e18-4dfe-b9b3-d9280629d2ef", u1.Val) + require.Equal(t, false, u1.Null) + + u2 := NewNullableUUID(nil) + require.Equal(t, true, u2.Null) +} + +func TestUUIDJSON(t *testing.T) { + type Foo struct { + X UUID `json:"x"` + Y UUID `json:"y"` + } + x := NewUUID("a81e44c5-7e18-4dfe-b9b3-d9280629d2ef") + y := NewNullableUUID(nil) + s := `{"x":"a81e44c5-7e18-4dfe-b9b3-d9280629d2ef","y":null}` + + f1 := Foo{X: x, Y: y} + b, err := json.Marshal(f1) + require.NoError(t, err) + require.Equal(t, s, string(b)) + + var f2 Foo + err = json.Unmarshal([]byte(s), &f2) + require.NoError(t, err) + require.Equal(t, x, f2.X) + require.Equal(t, y, f2.Y) +} diff --git a/generator/generator.go b/generator/generator.go index a9475df8..5e449765 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -88,10 +88,6 @@ func (g *Generator) Generate(schema *schema.Schema) error { database := schema.Databases[0] dir := database.Name created := []string{} - err := checkRelationship(database) - if err != nil { - return err - } for _, tpl := range g.template.Templates() { name := tpl.Name() @@ -165,38 +161,3 @@ func readDir(dir string) ([]string, error) { } return files, nil } - -func checkRelationship(d *schema.Database) error { - modeMap := make(map[string]struct{}) - for i := 0; i < len(d.Models); i++ { - modeMap[d.Models[i].Name] = struct{}{} - } - for i := 0; i < len(d.Models); i++ { - for j := 0; j < len(d.Models[i].HasOne); j++ { - if len(d.Models[i].HasOne) > 0 { - if _, ok := modeMap[inflect.Pascal(inflect.Singular(d.Models[i].HasOne[j].Name))]; !ok { - return fmt.Errorf(fmt.Sprintf("the model of %s do not exist", d.Models[i].HasOne[j].Name)) - } - } - - } - for k := 0; k < len(d.Models[i].HasMany); k++ { - if len(d.Models[i].HasMany) > 0 { - if _, ok := modeMap[inflect.Pascal(inflect.Singular(d.Models[i].HasMany[k].Name))]; !ok { - return fmt.Errorf(fmt.Sprintf("the model of %s do not exist", d.Models[i].HasMany[k].Name)) - } - } - - } - - for h := 0; h < len(d.Models[i].BelongsTo); h++ { - if len(d.Models[i].BelongsTo) > 0 { - if _, ok := modeMap[inflect.Pascal(inflect.Singular(d.Models[i].BelongsTo[h].Name))]; !ok { - return fmt.Errorf(fmt.Sprintf("the model of %s do not exist", d.Models[i].BelongsTo[h].Name)) - } - } - } - } - - return nil -} diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 69a1f219..4e5fdaf4 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -10,6 +10,47 @@ import ( var c *db.QXClient +func TestQueryOne(t *testing.T) { + user, err := c.QueryUser().Create(c.ChangeUser().SetName("test")) + require.NoError(t, err) + + var row struct { + UserID int64 `db:"user_id"` + } + err = c.QueryOne("select id as user_id from users where id = ?", user.ID).Scan(&row) + require.NoError(t, err) + require.Equal(t, user.ID, row.UserID) +} + +func TestQuery(t *testing.T) { + user1, err := c.QueryUser().Create(c.ChangeUser().SetName("test1")) + require.NoError(t, err) + user2, err := c.QueryUser().Create(c.ChangeUser().SetName("test2")) + require.NoError(t, err) + + type Foo struct { + UserName string `db:"user_name"` + } + var rows []Foo + err = c.Query("select name as user_name from users where id in (?)", []int64{user1.ID, user2.ID}).Scan(&rows) + require.NoError(t, err) + require.Equal(t, []Foo{ + {user1.Name.Val}, + {user2.Name.Val}, + }, rows) +} + +func TestExec(t *testing.T) { + user, err := c.QueryUser().Create(c.ChangeUser().SetName("test")) + require.NoError(t, err) + updated, err := c.Exec("update users set name = ? where id = ?", "test1", user.ID) + require.NoError(t, err) + require.Equal(t, int64(1), updated) + deleted, err := c.Exec("delete from users where id = ?", user.ID) + require.NoError(t, err) + require.Equal(t, int64(1), deleted) +} + func TestCreate(t *testing.T) { user, err := c.QueryUser().Create(c.ChangeUser().SetName("user").SetType("admin")) require.NoError(t, err) @@ -126,6 +167,70 @@ func TestExists(t *testing.T) { require.True(t, exists) } +func TestBelongsTo(t *testing.T) { + author, err := c.QueryUser().Create(c.ChangeUser().SetName("author")) + require.NoError(t, err) + post, err := c.QueryPost().Create(c.ChangePost().SetTitle("post title").SetAuthorID(author.ID)) + require.NoError(t, err) + post, err = c.QueryPost().PreloadAuthor().Find(post.ID) + require.NoError(t, err) + require.Equal(t, author.ID, post.Author.ID) +} + +func TestAllEmpty(t *testing.T) { + _, err := c.QueryUser().DeleteAll() + require.NoError(t, err) + + users, err := c.QueryUser().All() + require.NoError(t, err) + require.NotNil(t, users) + require.Equal(t, 0, len(users)) +} + +func TestInEmptySlice(t *testing.T) { + _, err := c.QueryUser().DeleteAll() + require.NoError(t, err) + users, err := c.QueryUser().Where(c.UserID.In([]int64{})).All() + require.NoError(t, err) + require.NotNil(t, users) + require.Equal(t, 0, len(users)) + + users, err = c.QueryUser().Where(c.UserID.In([]int64{}).And(c.UserID.EQ(1)).And(c.UserID.In([]int64{1}))).All() + require.NoError(t, err) + require.NotNil(t, users) + require.Equal(t, 0, len(users)) +} + +func TestHasManyEmpty(t *testing.T) { + user, err := c.QueryUser().Create(c.ChangeUser().SetName("user")) + require.NoError(t, err) + require.Nil(t, user.UserPosts) + require.Nil(t, user.Posts) + + user, err = c.QueryUser().PreloadUserPosts().Find(user.ID) + require.NoError(t, err) + require.NotNil(t, user.UserPosts) + require.Equal(t, 0, len(user.UserPosts)) + + user, err = c.QueryUser().PreloadPosts().Find(user.ID) + require.NoError(t, err) + require.NotNil(t, user.Posts) + require.NotNil(t, user.UserPosts) + require.Equal(t, 0, len(user.Posts)) + require.Equal(t, 0, len(user.UserPosts)) +} + +func TestHasOne(t *testing.T) { + user, err := c.QueryUser().Create(c.ChangeUser().SetName("has_one")) + require.NoError(t, err) + account, err := c.QueryAccount().Create(c.ChangeAccount().SetName("account").SetUserID(user.ID)) + require.NoError(t, err) + + user, err = c.QueryUser().PreloadAccount().Find(user.ID) + require.NoError(t, err) + require.Equal(t, account.Name, user.Account.Name) +} + func TestPreload(t *testing.T) { user1, _ := c.QueryUser().Create(c.ChangeUser().SetName("user1")) post1, _ := c.QueryPost().Create(c.ChangePost().SetTitle("post1")) @@ -149,11 +254,6 @@ func TestPreload(t *testing.T) { post, _ := c.QueryPost().PreloadUserPosts().Find(post1.ID) require.Equal(t, 1, len(post.UserPosts)) require.Equal(t, userPost1.ID, post.UserPosts[0].ID) - - // preload with zero rows - posts, err := c.QueryPost().Where(c.PostID.GT(1000)).PreloadUserPosts().All() - require.NoError(t, err) - require.Equal(t, 0, len(posts)) } func TestTx(t *testing.T) { diff --git a/internal/integration/postgresql.hcl b/internal/integration/postgresql.hcl index 6e402330..7f176e79 100644 --- a/internal/integration/postgresql.hcl +++ b/internal/integration/postgresql.hcl @@ -1,5 +1,6 @@ database "db" { adapter = "postgresql" + time_zone = "Asia/Shanghai" config "test" { url = "postgres://postgres:postgres@localhost:5432/queryx_test?sslmode=disable" @@ -64,6 +65,9 @@ database "db" { has_many "users" { through = "user_posts" } + belongs_to "author" { + model_name = "User" + } column "title" { type = string @@ -85,6 +89,7 @@ database "db" { model "Account" { belongs_to "user" {} + column "name" { type = string } diff --git a/schema/dsl.go b/schema/dsl.go index 6ae7d2cc..1d69fcd6 100644 --- a/schema/dsl.go +++ b/schema/dsl.go @@ -109,13 +109,19 @@ func (m *Model) AddHasOne(hasOne *HasOne) { } func (m *Model) AddBelongsTo(belongsTo *BelongsTo) { + if belongsTo.ModelName == "" { + belongsTo.ModelName = inflect.Pascal(inflect.Singular(belongsTo.Name)) + } + if belongsTo.ForeignKey == "" { + belongsTo.ForeignKey = fmt.Sprintf("%s_id", belongsTo.Name) + } + m.BelongsTo = append(m.BelongsTo, belongsTo) - // TODO: support foreign key, not null col := &Column{ - Name: fmt.Sprintf("%s_id", belongsTo.Name), + Name: belongsTo.ForeignKey, Type: "bigint", - Null: true, + Null: true, // TODO: support not null } m.Columns = append(m.Columns, col) } diff --git a/schema/hcl.go b/schema/hcl.go index 79999b26..f2e23877 100644 --- a/schema/hcl.go +++ b/schema/hcl.go @@ -113,6 +113,13 @@ var hclHasOne = &hcl.BodySchema{ }, } +var hclBelongsTo = &hcl.BodySchema{ + Attributes: []hcl.AttributeSchema{ + {Name: "model_name"}, + {Name: "foreign_key"}, + }, +} + func (s *Schema) databaseFromBlock(block *hcl.Block, ctx *hcl.EvalContext) (*Database, error) { name := block.Labels[0] database := s.NewDatabase(name) @@ -348,7 +355,7 @@ func indexFromBlock(block *hcl.Block, ctx *hcl.EvalContext) (*Index, error) { func primaryKeyFromBlock(block *hcl.Block, ctx *hcl.EvalContext) (*PrimaryKey, error) { primaryKey := &PrimaryKey{} - content, d := block.Body.Content(hclIndex) + content, d := block.Body.Content(hclPrimaryKey) if d.HasErrors() { return nil, d.Errs()[0] } @@ -372,7 +379,26 @@ func belongsToFromBlock(block *hcl.Block, ctx *hcl.EvalContext) (*BelongsTo, err belongsTo := &BelongsTo{ Name: block.Labels[0], } - belongsTo.ModelName = inflect.Pascal(inflect.Singular(belongsTo.Name)) + + content, d := block.Body.Content(hclBelongsTo) + if d.HasErrors() { + return nil, d.Errs()[0] + } + + for name, attr := range content.Attributes { + value, d := attr.Expr.Value(ctx) + if d.HasErrors() { + return nil, d.Errs()[0] + } + + switch name { + case "model_name": + belongsTo.ModelName = value.AsString() + case "foreign_key": + belongsTo.ForeignKey = value.AsString() + } + } + return belongsTo, nil }