Skip to content

Commit

Permalink
Merge pull request #36 from phelmkamp/sort_func
Browse files Browse the repository at this point in the history
Sort func
  • Loading branch information
phelmkamp committed Dec 31, 2019
2 parents dec7cf3 + 91d2d09 commit 5ecc577
Show file tree
Hide file tree
Showing 13 changed files with 210 additions and 50 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,14 @@ Uses value receiver by default.

`sort` (slice only)

Generates `Len` and `Swap` methods to implement [sort.Interface](https://golang.org/pkg/sort/#Interface), along with a `Sort` convenience method. Include the `stringer` option to generate a `Less` method that compares elements by their string representations. Otherwise, a `Less` method must be implemented separately.
Generates `Len` and `Swap` methods to implement [sort.Interface](https://golang.org/pkg/sort/#Interface), along with a `Sort` convenience method.
A `Less` method can be implemented separately or generated using one of the options.
Uses value receivers by default.

Options
* `stringer`: generate a `Less` method that compares elements by their string representations
* `func`: generate a `Less` method that accepts a less function

`wrapper` (slice only)

Indicates that the struct is a "wrapper" for the given slice. Enables `omitfield` and `chain` options for all subsequent directives.
Expand Down
76 changes: 70 additions & 6 deletions directive/directive.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
const (
optOmitField = "omitfield"
optStringer = "stringer"
optFunc = "func"
optReflect = "reflect"
optChain = "chain"
)
Expand Down Expand Up @@ -221,32 +222,95 @@ func sort(tgt *Target, opts []string) {

log.Println("Adding method: Len")
log.Println("Adding method: Swap")
log.Println("Adding method: Sort")
sort := meta.Method{
lenSwap := meta.Method{
RcvName: tgt.RcvName,
RcvType: tgt.RcvType,
FldName: fldNm,
Tmpl: "sort",
Tmpl: "len_swap",
}
tgt.MetaFile.Methods = append(tgt.MetaFile.Methods, &sort)
tgt.MetaFile.Methods = append(tgt.MetaFile.Methods, &lenSwap)

var isStringer bool
var isStringer, isFunc bool
for i := range opts {
if opts[i] == optFunc {
isFunc = true
break
}
if opts[i] == optStringer {
isStringer = true
break
}
}

if isFunc {
elemType := strings.TrimPrefix(tgt.FldType, "[]")
lesserNm := lowerFirst(tgt.RcvType) + "Lesser"

log.Println("Adding type: " + lesserNm)
lesser := meta.Type{
Name: lesserNm,
Embed: tgt.RcvType,
Misc: map[string]interface{}{
"ElemType": elemType,
},
Tmpl: "type_lesser",
}
tgt.MetaFile.Types = append(tgt.MetaFile.Types, lesser)

log.Println("Adding method: Less")
less := meta.Method{
RcvName: tgt.RcvName,
RcvType: lesserNm,
FldName: fldNm,
Misc: map[string]interface{}{
"RetStmt": fmt.Sprintf(
"return %s.less(%s.%s[i], %s.%s[j])",
tgt.RcvName, tgt.RcvName, fldNm, tgt.RcvName, fldNm,
),
},
Tmpl: "less",
}
tgt.MetaFile.Methods = append(tgt.MetaFile.Methods, &less)

log.Println("Adding method: Sort")
sort := meta.Method{
RcvName: tgt.RcvName,
RcvType: tgt.RcvType,
ArgType: elemType,
Misc: map[string]interface{}{
"Lesser": lesserNm,
},
Tmpl: "sort_func",
}
tgt.MetaFile.Methods = append(tgt.MetaFile.Methods, &sort)
return
}

if isStringer {
log.Println("Adding method: Less")
less := meta.Method{
RcvName: tgt.RcvName,
RcvType: tgt.RcvType,
FldName: fldNm,
Tmpl: "less",
Misc: map[string]interface{}{
"RetStmt": fmt.Sprintf(
"return %s.%s[i].String() < %s.%s[j].String()",
tgt.RcvName, fldNm, tgt.RcvName, fldNm,
),
},
Tmpl: "less",
}
tgt.MetaFile.Methods = append(tgt.MetaFile.Methods, &less)
}

log.Println("Adding method: Sort")
sort := meta.Method{
RcvName: tgt.RcvName,
RcvType: tgt.RcvType,
FldName: fldNm,
Tmpl: "sort",
}
tgt.MetaFile.Methods = append(tgt.MetaFile.Methods, &sort)
}

// stringer adds each name of the given field to the String() implementation.
Expand Down
2 changes: 1 addition & 1 deletion internal/testdata/foobar/foo_meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
package foobar

import (
"fmt"
"reflect"
"time"
"fmt"
)

// NewFoo creates a new Foo with the given initial values.
Expand Down
2 changes: 1 addition & 1 deletion internal/testdata/person/person.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ type Person struct {
}

type Persons struct {
result []Person `meta:"wrapper;new;filter;mapper,int;sort,stringer;getter"`
result []Person `meta:"wrapper;new;filter;mapper,int;sort,func;getter"`
}
24 changes: 16 additions & 8 deletions internal/testdata/person/person_meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ import (
"sort"
)

type personsLesser struct {
Persons
less func(vi, vj Person) bool
}

// String returns the "native" format of Person. Implements the fmt.Stringer interface.
func (p Person) String() string {
return fmt.Sprintf("%v", p.Name)
Expand Down Expand Up @@ -63,16 +68,19 @@ func (p Persons) Swap(i, j int) {
p.result[i], p.result[j] = p.result[j], p.result[i]
}

// Sort is a convenience method.
func (p Persons) Sort() Persons {
sort.Sort(p)
return p
}

// Less reports whether the element with
// index i should sort before the element with index j.
func (p Persons) Less(i, j int) bool {
return p.result[i].String() < p.result[j].String()
func (p personsLesser) Less(i, j int) bool {
return p.less(p.result[i], p.result[j])
}

// Sort sorts the collection using the given less function.
func (p Persons) Sort(less func(vi, vj Person) bool) Persons {
sort.Sort(personsLesser{
Persons: p,
less: less,
})
return p
}

// Result returns the value of result.
Expand Down
8 changes: 6 additions & 2 deletions internal/testdata/person/person_meta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ func Example() {
return p.Name != "Bob"
}).
// sort by name
Sort().
Sort(func(vi, vj Person) bool {
return vi.Name < vj.Name
}).
// map to ages
MapToInt(func(p Person) int {
return time.Now().Year() - p.Birthdate.Year()
Expand Down Expand Up @@ -79,7 +81,9 @@ func TestPersons_Sort(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.p.Sort()
tt.p.Sort(func(vi, vj Person) bool {
return vi.Name < vj.Name
})
if !reflect.DeepEqual(tt.p.Result(), tt.want) {
t.Errorf("got = %v, want %v", tt.p.Result(), tt.want)
}
Expand Down
68 changes: 50 additions & 18 deletions meta/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ import (

const (
topComment = "// GENERATED BY metatag, DO NOT EDIT\n// (or edit away - I'm a comment, not a cop)\n\n"
fileTemplate = "package %s\n%s%s"
fileTemplate = "package %s\n%s%s%s"
)

// File represents a generated code file
type File struct {
Package string
Imports Imports
Types Types
Methods Methods `meta:"ptr;filter"`
}

Expand All @@ -32,7 +33,7 @@ func NewFile(pkg string) *File {

// String generates the file content
func (f *File) String() string {
return topComment + fmt.Sprintf(fileTemplate, f.Package, f.Imports, f.Methods)
return topComment + fmt.Sprintf(fileTemplate, f.Package, f.Imports, f.Types, f.Methods)
}

// Imports represents a set of import paths
Expand All @@ -55,6 +56,33 @@ func (is Imports) String() string {
return sb.String()
}

// Type represents a type declaration
type Type struct {
Name string
Embed string
Misc map[string]interface{}
Tmpl string
}

// String generates the type code
func (t Type) String() string {
return executeTmpl(t.Tmpl, t)
}

// Types represents a set of types
type Types []Type

// String generates the code for all types
func (ts Types) String() string {
sb := strings.Builder{}
for i := range ts {
sb.WriteString("\n")
sb.WriteString(ts[i].String())
sb.WriteString("\n")
}
return sb.String()
}

// Method represents a generated method
type Method struct {
RcvName, RcvType string
Expand All @@ -68,22 +96,7 @@ type Method struct {

// String generates the method code
func (m Method) String() string {
tmplBytes, err := templates.Asset(m.Tmpl + ".tmpl")
if err != nil {
log.Fatal(err)
}

tmplMessage, err := template.New(m.Tmpl).Parse(string(tmplBytes))
if err != nil {
log.Fatal(err)
}

var buf bytes.Buffer
if err := tmplMessage.Execute(&buf, m); err != nil {
log.Fatal(err)
}

return buf.String()
return executeTmpl(m.Tmpl, m)
}

// Methods represents a collection of generated methods
Expand All @@ -99,3 +112,22 @@ func (ms Methods) String() string {
}
return sb.String()
}

func executeTmpl(tmpl string, data interface{}) string {
tmplBytes, err := templates.Asset(tmpl + ".tmpl")
if err != nil {
log.Fatal(err)
}

tmplMessage, err := template.New(tmpl).Parse(string(tmplBytes))
if err != nil {
log.Fatal(err)
}

var buf bytes.Buffer
if err := tmplMessage.Execute(&buf, data); err != nil {
log.Fatal(err)
}

return buf.String()
}
Loading

0 comments on commit 5ecc577

Please sign in to comment.