Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
290 lines (266 sloc) 7.46 KB
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !codes
package testutil
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"regexp"
"runtime"
"strings"
"github.com/pingcap/check"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
)
// CompareUnorderedStringSlice compare two string slices.
// If a and b is exactly the same except the order, it returns true.
// In otherwise return false.
func CompareUnorderedStringSlice(a []string, b []string) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
if len(a) != len(b) {
return false
}
m := make(map[string]int, len(a))
for _, i := range a {
_, ok := m[i]
if !ok {
m[i] = 1
} else {
m[i]++
}
}
for _, i := range b {
_, ok := m[i]
if !ok {
return false
}
m[i]--
if m[i] == 0 {
delete(m, i)
}
}
return len(m) == 0
}
// datumEqualsChecker is a checker for DatumEquals.
type datumEqualsChecker struct {
*check.CheckerInfo
}
// DatumEquals checker verifies that the obtained value is equal to
// the expected value.
// For example:
// c.Assert(value, DatumEquals, NewDatum(42))
var DatumEquals check.Checker = &datumEqualsChecker{
&check.CheckerInfo{Name: "DatumEquals", Params: []string{"obtained", "expected"}},
}
func (checker *datumEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) {
defer func() {
if v := recover(); v != nil {
result = false
error = fmt.Sprint(v)
}
}()
paramFirst, ok := params[0].(types.Datum)
if !ok {
panic("the first param should be datum")
}
paramSecond, ok := params[1].(types.Datum)
if !ok {
panic("the second param should be datum")
}
sc := new(stmtctx.StatementContext)
res, err := paramFirst.CompareDatum(sc, &paramSecond)
if err != nil {
panic(err)
}
return res == 0, ""
}
// RowsWithSep is a convenient function to wrap args to a slice of []interface.
// The arg represents a row, split by sep.
func RowsWithSep(sep string, args ...string) [][]interface{} {
rows := make([][]interface{}, len(args))
for i, v := range args {
strs := strings.Split(v, sep)
row := make([]interface{}, len(strs))
for j, s := range strs {
row[j] = s
}
rows[i] = row
}
return rows
}
// record is a flag used for generate test result.
var record bool
func init() {
flag.BoolVar(&record, "record", false, "to generate test result")
}
type testCases struct {
Name string
Cases *json.RawMessage // For delayed parse.
decodedOut interface{} // For generate output.
}
// TestData stores all the data of a test suite.
type TestData struct {
input []testCases
output []testCases
filePathPrefix string
funcMap map[string]int
}
// LoadTestSuiteData loads test suite data from file.
func LoadTestSuiteData(dir, suiteName string) (res TestData, err error) {
res.filePathPrefix = filepath.Join(dir, suiteName)
res.input, err = loadTestSuiteCases(fmt.Sprintf("%s_in.json", res.filePathPrefix))
if err != nil {
return res, err
}
if record {
res.output = make([]testCases, len(res.input), len(res.input))
for i := range res.input {
res.output[i].Name = res.input[i].Name
}
} else {
res.output, err = loadTestSuiteCases(fmt.Sprintf("%s_out.json", res.filePathPrefix))
if err != nil {
return res, err
}
if len(res.input) != len(res.output) {
return res, errors.New(fmt.Sprintf("Number of test input cases %d does not match test output cases %d", len(res.input), len(res.output)))
}
}
res.funcMap = make(map[string]int, len(res.input))
for i, test := range res.input {
res.funcMap[test.Name] = i
if test.Name != res.output[i].Name {
return res, errors.New(fmt.Sprintf("Input name of the %d-case %s does not match output %s", i, test.Name, res.output[i].Name))
}
}
return res, nil
}
func loadTestSuiteCases(filePath string) (res []testCases, err error) {
jsonFile, err := os.Open(filePath)
if err != nil {
return res, err
}
defer jsonFile.Close()
byteValue, err := ioutil.ReadAll(jsonFile)
if err != nil {
return res, err
}
// Remove comments, since they are not allowed in json.
re := regexp.MustCompile("(?s)//.*?\n")
err = json.Unmarshal(re.ReplaceAll(byteValue, nil), &res)
return res, err
}
// GetTestCasesByName gets the test cases for a test function by its name.
func (t *TestData) GetTestCasesByName(caseName string, c *check.C, in interface{}, out interface{}) {
casesIdx, ok := t.funcMap[caseName]
c.Assert(ok, check.IsTrue, check.Commentf("Must get test %s", caseName))
err := json.Unmarshal(*t.input[casesIdx].Cases, in)
c.Assert(err, check.IsNil)
if !record {
err = json.Unmarshal(*t.output[casesIdx].Cases, out)
c.Assert(err, check.IsNil)
} else {
// Init for generate output file.
inputLen := reflect.ValueOf(in).Elem().Len()
v := reflect.ValueOf(out).Elem()
if v.Kind() == reflect.Slice {
v.Set(reflect.MakeSlice(v.Type(), inputLen, inputLen))
}
}
t.output[casesIdx].decodedOut = out
}
// GetTestCases gets the test cases for a test function.
func (t *TestData) GetTestCases(c *check.C, in interface{}, out interface{}) {
// Extract caller's name.
pc, _, _, ok := runtime.Caller(1)
c.Assert(ok, check.IsTrue)
details := runtime.FuncForPC(pc)
funcNameIdx := strings.LastIndex(details.Name(), ".")
funcName := details.Name()[funcNameIdx+1:]
casesIdx, ok := t.funcMap[funcName]
c.Assert(ok, check.IsTrue, check.Commentf("Must get test %s", funcName))
err := json.Unmarshal(*t.input[casesIdx].Cases, in)
c.Assert(err, check.IsNil)
if !record {
err = json.Unmarshal(*t.output[casesIdx].Cases, out)
c.Assert(err, check.IsNil)
} else {
// Init for generate output file.
inputLen := reflect.ValueOf(in).Elem().Len()
v := reflect.ValueOf(out).Elem()
if v.Kind() == reflect.Slice {
v.Set(reflect.MakeSlice(v.Type(), inputLen, inputLen))
}
}
t.output[casesIdx].decodedOut = out
}
// OnRecord execute the function to update result.
func (t *TestData) OnRecord(updateFunc func()) {
if record {
updateFunc()
}
}
// ConvertRowsToStrings converts [][]interface{} to []string.
func (t *TestData) ConvertRowsToStrings(rows [][]interface{}) (rs []string) {
for _, row := range rows {
s := fmt.Sprintf("%v", row)
// Trim the leftmost `[` and rightmost `]`.
s = s[1 : len(s)-1]
rs = append(rs, s)
}
return rs
}
// GenerateOutputIfNeeded generate the output file.
func (t *TestData) GenerateOutputIfNeeded() error {
if !record {
return nil
}
buf := new(bytes.Buffer)
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
enc.SetIndent("", " ")
for i, test := range t.output {
err := enc.Encode(test.decodedOut)
if err != nil {
return err
}
res := make([]byte, len(buf.Bytes()), len(buf.Bytes()))
copy(res, buf.Bytes())
buf.Reset()
rm := json.RawMessage(res)
t.output[i].Cases = &rm
}
err := enc.Encode(t.output)
if err != nil {
return err
}
file, err := os.Create(fmt.Sprintf("%s_out.json", t.filePathPrefix))
if err != nil {
return err
}
defer file.Close()
_, err = file.Write(buf.Bytes())
return err
}
You can’t perform that action at this time.