diff --git a/fmtsort/export_test.go b/fmtsort/export_test.go new file mode 100644 index 00000000..25cbb5d4 --- /dev/null +++ b/fmtsort/export_test.go @@ -0,0 +1,11 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fmtsort + +import "reflect" + +func Compare(a, b reflect.Value) int { + return compare(a, b) +} diff --git a/fmtsort/sort.go b/fmtsort/sort.go new file mode 100644 index 00000000..70a305a3 --- /dev/null +++ b/fmtsort/sort.go @@ -0,0 +1,216 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fmtsort provides a general stable ordering mechanism +// for maps, on behalf of the fmt and text/template packages. +// It is not guaranteed to be efficient and works only for types +// that are valid map keys. +package fmtsort + +import ( + "reflect" + "sort" +) + +// Note: Throughout this package we avoid calling reflect.Value.Interface as +// it is not always legal to do so and it's easier to avoid the issue than to face it. + +// SortedMap represents a map's keys and values. The keys and values are +// aligned in index order: Value[i] is the value in the map corresponding to Key[i]. +type SortedMap struct { + Key []reflect.Value + Value []reflect.Value +} + +func (o *SortedMap) Len() int { return len(o.Key) } +func (o *SortedMap) Less(i, j int) bool { return compare(o.Key[i], o.Key[j]) < 0 } +func (o *SortedMap) Swap(i, j int) { + o.Key[i], o.Key[j] = o.Key[j], o.Key[i] + o.Value[i], o.Value[j] = o.Value[j], o.Value[i] +} + +// Sort accepts a map and returns a SortedMap that has the same keys and +// values but in a stable sorted order according to the keys, modulo issues +// raised by unorderable key values such as NaNs. +// +// The ordering rules are more general than with Go's < operator: +// +// - when applicable, nil compares low +// - ints, floats, and strings order by < +// - NaN compares less than non-NaN floats +// - bool compares false before true +// - complex compares real, then imag +// - pointers compare by machine address +// - channel values compare by machine address +// - structs compare each field in turn +// - arrays compare each element in turn. +// Otherwise identical arrays compare by length. +// - interface values compare first by reflect.Type describing the concrete type +// and then by concrete value as described in the previous rules. +// +func Sort(mapValue reflect.Value) *SortedMap { + if mapValue.Type().Kind() != reflect.Map { + return nil + } + key := make([]reflect.Value, mapValue.Len()) + value := make([]reflect.Value, len(key)) + iter := mapValue.MapRange() + for i := 0; iter.Next(); i++ { + key[i] = iter.Key() + value[i] = iter.Value() + } + sorted := &SortedMap{ + Key: key, + Value: value, + } + sort.Stable(sorted) + return sorted +} + +// compare compares two values of the same type. It returns -1, 0, 1 +// according to whether a > b (1), a == b (0), or a < b (-1). +// If the types differ, it returns -1. +// See the comment on Sort for the comparison rules. +func compare(aVal, bVal reflect.Value) int { + aType, bType := aVal.Type(), bVal.Type() + if aType != bType { + return -1 // No good answer possible, but don't return 0: they're not equal. + } + switch aVal.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + a, b := aVal.Int(), bVal.Int() + switch { + case a < b: + return -1 + case a > b: + return 1 + default: + return 0 + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + a, b := aVal.Uint(), bVal.Uint() + switch { + case a < b: + return -1 + case a > b: + return 1 + default: + return 0 + } + case reflect.String: + a, b := aVal.String(), bVal.String() + switch { + case a < b: + return -1 + case a > b: + return 1 + default: + return 0 + } + case reflect.Float32, reflect.Float64: + return floatCompare(aVal.Float(), bVal.Float()) + case reflect.Complex64, reflect.Complex128: + a, b := aVal.Complex(), bVal.Complex() + if c := floatCompare(real(a), real(b)); c != 0 { + return c + } + return floatCompare(imag(a), imag(b)) + case reflect.Bool: + a, b := aVal.Bool(), bVal.Bool() + switch { + case a == b: + return 0 + case a: + return 1 + default: + return -1 + } + case reflect.Ptr: + a, b := aVal.Pointer(), bVal.Pointer() + switch { + case a < b: + return -1 + case a > b: + return 1 + default: + return 0 + } + case reflect.Chan: + if c, ok := nilCompare(aVal, bVal); ok { + return c + } + ap, bp := aVal.Pointer(), bVal.Pointer() + switch { + case ap < bp: + return -1 + case ap > bp: + return 1 + default: + return 0 + } + case reflect.Struct: + for i := 0; i < aVal.NumField(); i++ { + if c := compare(aVal.Field(i), bVal.Field(i)); c != 0 { + return c + } + } + return 0 + case reflect.Array: + for i := 0; i < aVal.Len(); i++ { + if c := compare(aVal.Index(i), bVal.Index(i)); c != 0 { + return c + } + } + return 0 + case reflect.Interface: + if c, ok := nilCompare(aVal, bVal); ok { + return c + } + c := compare(reflect.ValueOf(aVal.Elem().Type()), reflect.ValueOf(bVal.Elem().Type())) + if c != 0 { + return c + } + return compare(aVal.Elem(), bVal.Elem()) + default: + // Certain types cannot appear as keys (maps, funcs, slices), but be explicit. + panic("bad type in compare: " + aType.String()) + } +} + +// nilCompare checks whether either value is nil. If not, the boolean is false. +// If either value is nil, the boolean is true and the integer is the comparison +// value. The comparison is defined to be 0 if both are nil, otherwise the one +// nil value compares low. Both arguments must represent a chan, func, +// interface, map, pointer, or slice. +func nilCompare(aVal, bVal reflect.Value) (int, bool) { + if aVal.IsNil() { + if bVal.IsNil() { + return 0, true + } + return -1, true + } + if bVal.IsNil() { + return 1, true + } + return 0, false +} + +// floatCompare compares two floating-point values. NaNs compare low. +func floatCompare(a, b float64) int { + switch { + case isNaN(a): + return -1 // No good answer if b is a NaN so don't bother checking. + case isNaN(b): + return 1 + case a < b: + return -1 + case a > b: + return 1 + } + return 0 +} + +func isNaN(a float64) bool { + return a != a +} diff --git a/fmtsort/sort_test.go b/fmtsort/sort_test.go new file mode 100644 index 00000000..ac81cc3c --- /dev/null +++ b/fmtsort/sort_test.go @@ -0,0 +1,247 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fmtsort_test + +import ( + "fmt" + "math" + "reflect" + "strings" + "testing" + + "github.com/rogpeppe/go-internal/fmtsort" +) + +var compareTests = [][]reflect.Value{ + ct(reflect.TypeOf(int(0)), -1, 0, 1), + ct(reflect.TypeOf(int8(0)), -1, 0, 1), + ct(reflect.TypeOf(int16(0)), -1, 0, 1), + ct(reflect.TypeOf(int32(0)), -1, 0, 1), + ct(reflect.TypeOf(int64(0)), -1, 0, 1), + ct(reflect.TypeOf(uint(0)), 0, 1, 5), + ct(reflect.TypeOf(uint8(0)), 0, 1, 5), + ct(reflect.TypeOf(uint16(0)), 0, 1, 5), + ct(reflect.TypeOf(uint32(0)), 0, 1, 5), + ct(reflect.TypeOf(uint64(0)), 0, 1, 5), + ct(reflect.TypeOf(uintptr(0)), 0, 1, 5), + ct(reflect.TypeOf(string("")), "", "a", "ab"), + ct(reflect.TypeOf(float32(0)), math.NaN(), math.Inf(-1), -1e10, 0, 1e10, math.Inf(1)), + ct(reflect.TypeOf(float64(0)), math.NaN(), math.Inf(-1), -1e10, 0, 1e10, math.Inf(1)), + ct(reflect.TypeOf(complex64(0+1i)), -1-1i, -1+0i, -1+1i, 0-1i, 0+0i, 0+1i, 1-1i, 1+0i, 1+1i), + ct(reflect.TypeOf(complex128(0+1i)), -1-1i, -1+0i, -1+1i, 0-1i, 0+0i, 0+1i, 1-1i, 1+0i, 1+1i), + ct(reflect.TypeOf(false), false, true), + ct(reflect.TypeOf(&ints[0]), &ints[0], &ints[1], &ints[2]), + ct(reflect.TypeOf(chans[0]), chans[0], chans[1], chans[2]), + ct(reflect.TypeOf(toy{}), toy{0, 1}, toy{0, 2}, toy{1, -1}, toy{1, 1}), + ct(reflect.TypeOf([2]int{}), [2]int{1, 1}, [2]int{1, 2}, [2]int{2, 0}), + ct(reflect.TypeOf(interface{}(interface{}(0))), iFace, 1, 2, 3), +} + +var iFace interface{} + +func ct(typ reflect.Type, args ...interface{}) []reflect.Value { + value := make([]reflect.Value, len(args)) + for i, v := range args { + x := reflect.ValueOf(v) + if !x.IsValid() { // Make it a typed nil. + x = reflect.Zero(typ) + } else { + x = x.Convert(typ) + } + value[i] = x + } + return value +} + +func TestCompare(t *testing.T) { + for _, test := range compareTests { + for i, v0 := range test { + for j, v1 := range test { + c := fmtsort.Compare(v0, v1) + var expect int + switch { + case i == j: + expect = 0 + // NaNs are tricky. + if typ := v0.Type(); (typ.Kind() == reflect.Float32 || typ.Kind() == reflect.Float64) && math.IsNaN(v0.Float()) { + expect = -1 + } + case i < j: + expect = -1 + case i > j: + expect = 1 + } + if c != expect { + t.Errorf("%s: compare(%v,%v)=%d; expect %d", v0.Type(), v0, v1, c, expect) + } + } + } + } +} + +type sortTest struct { + data interface{} // Always a map. + print string // Printed result using our custom printer. +} + +var sortTests = []sortTest{ + { + map[int]string{7: "bar", -3: "foo"}, + "-3:foo 7:bar", + }, + { + map[uint8]string{7: "bar", 3: "foo"}, + "3:foo 7:bar", + }, + { + map[string]string{"7": "bar", "3": "foo"}, + "3:foo 7:bar", + }, + { + map[float64]string{7: "bar", -3: "foo", math.NaN(): "nan", math.Inf(0): "inf"}, + "NaN:nan -3:foo 7:bar +Inf:inf", + }, + { + map[complex128]string{7 + 2i: "bar2", 7 + 1i: "bar", -3: "foo", complex(math.NaN(), 0i): "nan", complex(math.Inf(0), 0i): "inf"}, + "(NaN+0i):nan (-3+0i):foo (7+1i):bar (7+2i):bar2 (+Inf+0i):inf", + }, + { + map[bool]string{true: "true", false: "false"}, + "false:false true:true", + }, + { + chanMap(), + "CHAN0:0 CHAN1:1 CHAN2:2", + }, + { + pointerMap(), + "PTR0:0 PTR1:1 PTR2:2", + }, + { + map[toy]string{toy{7, 2}: "72", toy{7, 1}: "71", toy{3, 4}: "34"}, + "{3 4}:34 {7 1}:71 {7 2}:72", + }, + { + map[[2]int]string{{7, 2}: "72", {7, 1}: "71", {3, 4}: "34"}, + "[3 4]:34 [7 1]:71 [7 2]:72", + }, +} + +func sprint(data interface{}) string { + om := fmtsort.Sort(reflect.ValueOf(data)) + if om == nil { + return "nil" + } + b := new(strings.Builder) + for i, key := range om.Key { + if i > 0 { + b.WriteRune(' ') + } + b.WriteString(sprintKey(key)) + b.WriteRune(':') + b.WriteString(fmt.Sprint(om.Value[i])) + } + return b.String() +} + +// sprintKey formats a reflect.Value but gives reproducible values for some +// problematic types such as pointers. Note that it only does special handling +// for the troublesome types used in the test cases; it is not a general +// printer. +func sprintKey(key reflect.Value) string { + switch str := key.Type().String(); str { + case "*int": + ptr := key.Interface().(*int) + for i := range ints { + if ptr == &ints[i] { + return fmt.Sprintf("PTR%d", i) + } + } + return "PTR???" + case "chan int": + c := key.Interface().(chan int) + for i := range chans { + if c == chans[i] { + return fmt.Sprintf("CHAN%d", i) + } + } + return "CHAN???" + default: + return fmt.Sprint(key) + } +} + +var ( + ints [3]int + chans = [3]chan int{make(chan int), make(chan int), make(chan int)} +) + +func pointerMap() map[*int]string { + m := make(map[*int]string) + for i := 2; i >= 0; i-- { + m[&ints[i]] = fmt.Sprint(i) + } + return m +} + +func chanMap() map[chan int]string { + m := make(map[chan int]string) + for i := 2; i >= 0; i-- { + m[chans[i]] = fmt.Sprint(i) + } + return m +} + +type toy struct { + A int // Exported. + b int // Unexported. +} + +func TestOrder(t *testing.T) { + for _, test := range sortTests { + got := sprint(test.data) + if got != test.print { + t.Errorf("%s: got %q, want %q", reflect.TypeOf(test.data), got, test.print) + } + } +} + +func TestInterface(t *testing.T) { + // A map containing multiple concrete types should be sorted by type, + // then value. However, the relative ordering of types is unspecified, + // so test this by checking the presence of sorted subgroups. + m := map[interface{}]string{ + [2]int{1, 0}: "", + [2]int{0, 1}: "", + true: "", + false: "", + 3.1: "", + 2.1: "", + 1.1: "", + math.NaN(): "", + 3: "", + 2: "", + 1: "", + "c": "", + "b": "", + "a": "", + struct{ x, y int }{1, 0}: "", + struct{ x, y int }{0, 1}: "", + } + got := sprint(m) + typeGroups := []string{ + "NaN: 1.1: 2.1: 3.1:", // float64 + "false: true:", // bool + "1: 2: 3:", // int + "a: b: c:", // string + "[0 1]: [1 0]:", // [2]int + "{0 1}: {1 0}:", // struct{ x int; y int } + } + for _, g := range typeGroups { + if !strings.Contains(got, g) { + t.Errorf("sorted map should contain %q", g) + } + } +}