diff --git a/pkg/client/client.go b/pkg/client/client.go index e283ccc50..444aca423 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -17,11 +17,12 @@ package client +// Client represents the interface of http/dubbo clients type Client interface { Init() error Close() error Call(req *Request) (resp Response, err error) - // MappingParams mapping param, uri, query, body ... - MappingParams(req *Request) (types []string, reqData []interface{}, err error) + // MapParams mapping param, uri, query, body ... + MapParams(req *Request) (reqData interface{}, err error) } diff --git a/pkg/client/dubbo/dubbo.go b/pkg/client/dubbo/dubbo.go index a5c341a57..11805c555 100644 --- a/pkg/client/dubbo/dubbo.go +++ b/pkg/client/dubbo/dubbo.go @@ -19,7 +19,6 @@ package dubbo import ( "context" - "reflect" "strings" "sync" "time" @@ -29,7 +28,6 @@ import ( "github.com/apache/dubbo-go/common/constant" dg "github.com/apache/dubbo-go/config" "github.com/apache/dubbo-go/protocol/dubbo" - "github.com/pkg/errors" ) import ( @@ -44,12 +42,6 @@ const ( JavaLangClassName = "java.lang.Long" ) -var mappers = map[string]client.ParamMapper{ - "queryStrings": queryStringsMapper{}, - "headers": headerMapper{}, - "requestBody": bodyMapper{}, -} - var ( dubboClient *Client onceClient = sync.Once{} @@ -140,11 +132,11 @@ func (dc *Client) Close() error { // Call invoke service func (dc *Client) Call(req *client.Request) (resp client.Response, err error) { dm := req.API.Method.IntegrationRequest - types, values, err := dc.MappingParams(req) + types := req.API.IntegrationRequest.ParamTypes + values, err := dc.MapParams(req) if err != nil { return *client.EmptyResponse, err } - method := dm.Method logger.Debugf("[dubbo-go-proxy] invoke, method:%s, types:%s, reqData:%v", method, types, values) @@ -165,22 +157,22 @@ func (dc *Client) Call(req *client.Request) (resp client.Response, err error) { return *NewDubboResponse(rst), nil } -// MappingParams param mapping to api. -func (dc *Client) MappingParams(req *client.Request) ([]string, []interface{}, error) { +// MapParams param mapping to api. +func (dc *Client) MapParams(req *client.Request) (interface{}, error) { r := req.API.Method.IntegrationRequest var values []interface{} for _, mappingParam := range r.MappingParams { source, _, err := client.ParseMapSource(mappingParam.Name) if err != nil { - return nil, nil, err + return nil, err } if mapper, ok := mappers[source]; ok { if err := mapper.Map(mappingParam, *req, &values); err != nil { - return nil, nil, err + return nil, err } } } - return req.API.IntegrationRequest.ParamTypes, values, nil + return values, nil } func (dc *Client) get(key string) *dg.GenericService { @@ -246,30 +238,3 @@ func (dc *Client) create(key string, irequest config.IntegrationRequest) *dg.Gen dc.GenericServicePool[key] = clientService return clientService } - -func validateTarget(target interface{}) (reflect.Value, error) { - rv := reflect.ValueOf(target) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return rv, errors.New("Target params must be a non-nil pointer") - } - if _, ok := target.(*[]interface{}); !ok { - return rv, errors.New("Target params for dubbo backend must be *[]interface{}") - } - return rv, nil -} - -func setTarget(rv reflect.Value, pos int, value interface{}) { - if rv.Kind() != reflect.Ptr && rv.Type().Name() != "" && rv.CanAddr() { - rv = rv.Addr() - } else { - rv = rv.Elem() - } - - tempValue := rv.Interface().([]interface{}) - if len(tempValue) <= pos { - list := make([]interface{}, pos+1-len(tempValue)) - tempValue = append(tempValue, list...) - } - tempValue[pos] = value - rv.Set(reflect.ValueOf(tempValue)) -} diff --git a/pkg/client/dubbo/dubbo_test.go b/pkg/client/dubbo/dubbo_test.go index b51fdadc5..e29735f34 100644 --- a/pkg/client/dubbo/dubbo_test.go +++ b/pkg/client/dubbo/dubbo_test.go @@ -104,10 +104,10 @@ func TestMappingParams(t *testing.T) { }, } req := client.NewReq(context.TODO(), r, api) - _, params, err := dClient.MappingParams(req) + params, err := dClient.MapParams(req) assert.Nil(t, err) - assert.Equal(t, params[0], "12345") - assert.Equal(t, params[1], "19") + assert.Equal(t, params.([]interface{})[0], "12345") + assert.Equal(t, params.([]interface{})[1], "19") r, _ = http.NewRequest("GET", "/mock/test?id=12345&age=19", bytes.NewReader([]byte(""))) api = mock.GetMockAPI(config.MethodGet, "/mock/test") @@ -127,11 +127,11 @@ func TestMappingParams(t *testing.T) { } r.Header.Set("Auth", "1234567") req = client.NewReq(context.TODO(), r, api) - _, params, err = dClient.MappingParams(req) + params, err = dClient.MapParams(req) assert.Nil(t, err) - assert.Equal(t, params[0], "12345") - assert.Equal(t, params[1], "19") - assert.Equal(t, params[2], "1234567") + assert.Equal(t, params.([]interface{})[0], "12345") + assert.Equal(t, params.([]interface{})[1], "19") + assert.Equal(t, params.([]interface{})[2], "1234567") r, _ = http.NewRequest("POST", "/mock/test?id=12345&age=19", bytes.NewReader([]byte(`{"sex": "male", "name":{"firstName": "Joe", "lastName": "Biden"}}`))) api = mock.GetMockAPI(config.MethodGet, "/mock/test") @@ -159,46 +159,11 @@ func TestMappingParams(t *testing.T) { } r.Header.Set("Auth", "1234567") req = client.NewReq(context.TODO(), r, api) - _, params, err = dClient.MappingParams(req) + params, err = dClient.MapParams(req) assert.Nil(t, err) - assert.Equal(t, params[0], "12345") - assert.Equal(t, params[1], "19") - assert.Equal(t, params[2], "1234567") - assert.Equal(t, params[3], "male") - assert.Equal(t, params[4], "Joe") -} - -func TestValidateTarget(t *testing.T) { - target := []interface{}{} - val, err := validateTarget(&target) - assert.Nil(t, err) - assert.NotNil(t, val) - _, err = validateTarget(target) - assert.EqualError(t, err, "Target params must be a non-nil pointer") - target2 := "" - _, err = validateTarget(&target2) - assert.EqualError(t, err, "Target params for dubbo backend must be *[]interface{}") -} - -func TestParseMapSource(t *testing.T) { - from, key, err := client.ParseMapSource("queryStrings.id") - assert.Nil(t, err) - assert.Equal(t, from, "queryStrings") - assert.Equal(t, key[0], "id") - - from, key, err = client.ParseMapSource("headers.id") - assert.Nil(t, err) - assert.Equal(t, from, "headers") - assert.Equal(t, key[0], "id") - - from, key, err = client.ParseMapSource("requestBody.user.id") - assert.Nil(t, err) - assert.Equal(t, from, "requestBody") - assert.Equal(t, key[0], "user") - assert.Equal(t, key[1], "id") - - from, key, err = client.ParseMapSource("what.user.id") - assert.EqualError(t, err, "Parameter mapping config incorrect. Please fix it") - from, key, err = client.ParseMapSource("requestBody.*userid") - assert.EqualError(t, err, "Parameter mapping config incorrect. Please fix it") + assert.Equal(t, params.([]interface{})[0], "12345") + assert.Equal(t, params.([]interface{})[1], "19") + assert.Equal(t, params.([]interface{})[2], "1234567") + assert.Equal(t, params.([]interface{})[3], "male") + assert.Equal(t, params.([]interface{})[4], "Joe") } diff --git a/pkg/client/dubbo/mapper.go b/pkg/client/dubbo/mapper.go index 27a293757..d43208e27 100644 --- a/pkg/client/dubbo/mapper.go +++ b/pkg/client/dubbo/mapper.go @@ -21,6 +21,7 @@ import ( "bytes" "encoding/json" "io/ioutil" + "net/url" "reflect" "strconv" ) @@ -35,6 +36,12 @@ import ( "github.com/dubbogo/dubbo-go-proxy/pkg/config" ) +var mappers = map[string]client.ParamMapper{ + constant.QueryStrings: queryStringsMapper{}, + constant.Headers: headerMapper{}, + constant.RequestBody: bodyMapper{}, +} + type queryStringsMapper struct{} func (qm queryStringsMapper) Map(mp config.MappingParam, c client.Request, target interface{}) error { @@ -42,7 +49,10 @@ func (qm queryStringsMapper) Map(mp config.MappingParam, c client.Request, targe if err != nil { return err } - c.IngressRequest.ParseForm() + queryValues, err := url.ParseQuery(c.IngressRequest.URL.RawQuery) + if err != nil { + return errors.Wrap(err, "Error happened when parsing the query paramters") + } _, key, err := client.ParseMapSource(mp.Name) if err != nil { return err @@ -51,12 +61,12 @@ func (qm queryStringsMapper) Map(mp config.MappingParam, c client.Request, targe if err != nil { return errors.Errorf("Parameter mapping %v incorrect", mp) } - formValue := c.IngressRequest.Form.Get(key[0]) - if len(formValue) == 0 { + qValue := queryValues.Get(key[0]) + if len(qValue) == 0 { return errors.Errorf("Query parameter %s does not exist", key) } - setTarget(rv, pos, formValue) + setTarget(rv, pos, qValue) return nil } @@ -84,6 +94,7 @@ func (hm headerMapper) Map(mp config.MappingParam, c client.Request, target inte type bodyMapper struct{} func (bm bodyMapper) Map(mp config.MappingParam, c client.Request, target interface{}) error { + // TO-DO: add support for content-type other than application/json rv, err := validateTarget(target) if err != nil { return err @@ -106,28 +117,38 @@ func (bm bodyMapper) Map(mp config.MappingParam, c client.Request, target interf if err != nil { return err } - - val, err := getMapValue(mapBody, keys) + val, err := client.GetMapValue(mapBody, keys) setTarget(rv, pos, val) c.IngressRequest.Body = ioutil.NopCloser(bytes.NewBuffer(rawBody)) return nil } -func getMapValue(sourceMap map[string]interface{}, keys []string) (interface{}, error) { - if len(keys) == 1 && keys[0] == constant.DefaultBodyAll { - return sourceMap, nil - } - for i, key := range keys { - _, ok := sourceMap[key] - if !ok { - return nil, errors.Errorf("%s does not exist in request body", key) - } - rvalue := reflect.ValueOf(sourceMap[key]) - if rvalue.Type().Kind() != reflect.Map { - return rvalue.Interface(), nil - } - return getMapValue(sourceMap[key].(map[string]interface{}), keys[i+1:]) - } - return nil, nil +// validateTarget verify if the incoming target for the Map function +// can be processed as expected. +func validateTarget(target interface{}) (reflect.Value, error) { + rv := reflect.ValueOf(target) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return rv, errors.New("Target params must be a non-nil pointer") + } + if _, ok := target.(*[]interface{}); !ok { + return rv, errors.New("Target params for dubbo backend must be *[]interface{}") + } + return rv, nil +} + +func setTarget(rv reflect.Value, pos int, value interface{}) { + if rv.Kind() != reflect.Ptr && rv.Type().Name() != "" && rv.CanAddr() { + rv = rv.Addr() + } else { + rv = rv.Elem() + } + + tempValue := rv.Interface().([]interface{}) + if len(tempValue) <= pos { + list := make([]interface{}, pos+1-len(tempValue)) + tempValue = append(tempValue, list...) + } + tempValue[pos] = value + rv.Set(reflect.ValueOf(tempValue)) } diff --git a/pkg/client/dubbo/mapper_test.go b/pkg/client/dubbo/mapper_test.go index 5158fe000..7db815b1c 100644 --- a/pkg/client/dubbo/mapper_test.go +++ b/pkg/client/dubbo/mapper_test.go @@ -122,6 +122,10 @@ func TestBodyMapper(t *testing.T) { Name: "requestBody.name.lastName", MapTo: "1", }, + { + Name: "requestBody.name", + MapTo: "2", + }, } bm := bodyMapper{} target := []interface{}{} @@ -134,4 +138,20 @@ func TestBodyMapper(t *testing.T) { err = bm.Map(api.IntegrationRequest.MappingParams[1], *req, &target) assert.Nil(t, err) assert.Equal(t, target[1], "Biden") + + err = bm.Map(api.IntegrationRequest.MappingParams[2], *req, &target) + assert.Nil(t, err) + assert.Equal(t, target[2], map[string]interface{}(map[string]interface{}{"firstName": "Joe", "lastName": "Biden"})) +} + +func TestValidateTarget(t *testing.T) { + target := []interface{}{} + val, err := validateTarget(&target) + assert.Nil(t, err) + assert.NotNil(t, val) + _, err = validateTarget(target) + assert.EqualError(t, err, "Target params must be a non-nil pointer") + target2 := "" + _, err = validateTarget(&target2) + assert.EqualError(t, err, "Target params for dubbo backend must be *[]interface{}") } diff --git a/pkg/client/http/http.go b/pkg/client/http/http.go index e53aaf3bd..638f65b49 100644 --- a/pkg/client/http/http.go +++ b/pkg/client/http/http.go @@ -34,6 +34,7 @@ import ( import ( "github.com/dubbogo/dubbo-go-proxy/pkg/client" + "github.com/pkg/errors" ) // RestMetadata dubbo metadata, api config @@ -99,18 +100,52 @@ func (dc *Client) Call(r *client.Request) (resp client.Response, err error) { urlStr := r.API.IntegrationRequest.HTTPBackendConfig.Protocol + "://" + r.API.IntegrationRequest.HTTPBackendConfig.TargetURL httpClient := &http.Client{Timeout: 5 * time.Second} - request := r.IngressRequest.Clone(context.Background()) - request.URL, _ = url.ParseRequestURI(urlStr) - //TODO header replace, url rewrite.... + request := r.IngressRequest.Clone(r.Context) + //Map the origin paramters to backend parameters according to the API configure + transformedParams, err := dc.MapParams(r) + if err != nil { + return *client.EmptyResponse, err + } + params, _ := transformedParams.(*requestParams) + request.Body = params.Body + request.Header = params.Header + urlStr = strings.TrimRight(urlStr, "/") + "?" + params.Query.Encode() + request.URL, _ = url.ParseRequestURI(urlStr) tmpRet, err := httpClient.Do(request) ret := client.Response{Data: tmpRet} return ret, err } -// MappingParams param mapping to api. -func (dc *Client) MappingParams(req *client.Request) (types []string, reqData []interface{}, err error) { - return nil, nil, nil +// MapParams param mapping to api. +func (dc *Client) MapParams(req *client.Request) (reqData interface{}, err error) { + mp := req.API.IntegrationRequest.MappingParams + r := newRequestParams() + if len(mp) == 0 { + r.Body, err = req.IngressRequest.GetBody() + if err != nil { + return nil, errors.New("Retrieve request body failed") + } + r.Header = req.IngressRequest.Header.Clone() + queryValues, err := url.ParseQuery(req.IngressRequest.URL.RawQuery) + if err != nil { + return nil, errors.New("Retrieve request query parameters failed") + } + r.Query = queryValues + return r, nil + } + for i := 0; i < len(mp); i++ { + source, _, err := client.ParseMapSource(mp[i].Name) + if err != nil { + return nil, err + } + if mapper, ok := mappers[source]; ok { + if err := mapper.Map(mp[i], *req, r); err != nil { + return nil, err + } + } + } + return r, nil } func (dc *Client) get(key string) *dg.GenericService { diff --git a/pkg/client/http/http_test.go b/pkg/client/http/http_test.go new file mode 100644 index 000000000..b9adbf9e2 --- /dev/null +++ b/pkg/client/http/http_test.go @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package http + +import ( + "bytes" + "context" + "io/ioutil" + "net/http" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/dubbogo/dubbo-go-proxy/pkg/client" + "github.com/dubbogo/dubbo-go-proxy/pkg/common/mock" + "github.com/dubbogo/dubbo-go-proxy/pkg/config" +) + +func TestMapParams(t *testing.T) { + hClient := NewHTTPClient() + r, _ := http.NewRequest("POST", "/mock/test?team=theBoys", bytes.NewReader([]byte("{\"id\":\"12345\",\"age\":\"19\",\"testStruct\":{\"name\":\"mock\",\"test\":\"happy\",\"nickName\":\"trump\"}}"))) + r.Header.Set("Auth", "12345") + api := mock.GetMockAPI(config.MethodGet, "/mock/test") + req := client.NewReq(context.TODO(), r, api) + + val, err := hClient.MapParams(req) + assert.Nil(t, err) + p, _ := val.(*requestParams) + assert.Equal(t, p.Query.Encode(), "team=theBoys") + assert.Equal(t, p.Header.Get("Auth"), "12345") + rawBody, err := ioutil.ReadAll(p.Body) + assert.Nil(t, err) + assert.Equal(t, string(rawBody), "{\"id\":\"12345\",\"age\":\"19\",\"testStruct\":{\"name\":\"mock\",\"test\":\"happy\",\"nickName\":\"trump\"}}") + + api.IntegrationRequest.MappingParams = []config.MappingParam{ + { + Name: "queryStrings.team", + MapTo: "queryStrings.team", + }, + { + Name: "requestBody.id", + MapTo: "headers.Id", + }, + { + Name: "headers.Auth", + MapTo: "queryStrings.auth", + }, + { + Name: "requestBody.age", + MapTo: "requestBody.age", + }, + { + Name: "requestBody.testStruct", + MapTo: "requestBody.testStruct", + }, + { + Name: "requestBody.testStruct.nickName", + MapTo: "requestBody.nickName", + }, + } + api.IntegrationRequest.HTTPBackendConfig.Protocol = "https" + api.IntegrationRequest.HTTPBackendConfig.TargetURL = "localhost" + req = client.NewReq(context.TODO(), r, api) + val, err = hClient.MapParams(req) + assert.Nil(t, err) + p, _ = val.(*requestParams) + assert.Equal(t, p.Header.Get("Id"), "12345") + assert.Equal(t, p.Query.Get("auth"), "12345") + assert.Equal(t, p.Query.Get("team"), "theBoys") + rawBody, err = ioutil.ReadAll(p.Body) + assert.Nil(t, err) + assert.Equal(t, string(rawBody), "{\"age\":\"19\",\"nickName\":\"trump\",\"testStruct\":{\"name\":\"mock\",\"nickName\":\"trump\",\"test\":\"happy\"}}") + + hClient.Call(req) +} diff --git a/pkg/client/http/mapper.go b/pkg/client/http/mapper.go new file mode 100644 index 000000000..1846173ea --- /dev/null +++ b/pkg/client/http/mapper.go @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package http + +import ( + "bytes" + "encoding/json" + "io" + "io/ioutil" + "net/http" + "net/url" + "reflect" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/dubbogo/dubbo-go-proxy/pkg/client" + "github.com/dubbogo/dubbo-go-proxy/pkg/common/constant" + "github.com/dubbogo/dubbo-go-proxy/pkg/config" +) + +var mappers = map[string]client.ParamMapper{ + constant.QueryStrings: queryStringsMapper{}, + constant.Headers: headerMapper{}, + constant.RequestBody: bodyMapper{}, +} + +type requestParams struct { + Header http.Header + Query url.Values + Body io.ReadCloser +} + +func newRequestParams() *requestParams { + return &requestParams{ + Header: http.Header{}, + Query: url.Values{}, + Body: ioutil.NopCloser(bytes.NewReader([]byte(""))), + } +} + +type queryStringsMapper struct{} + +func (qs queryStringsMapper) Map(mp config.MappingParam, c client.Request, rawTarget interface{}) error { + rv, err := validateTarget(rawTarget) + if err != nil { + return err + } + _, fromKey, err := client.ParseMapSource(mp.Name) + if err != nil { + return err + } + to, toKey, err := client.ParseMapSource(mp.MapTo) + if err != nil { + return err + } + queryValues, err := url.ParseQuery(c.IngressRequest.URL.RawQuery) + if err != nil { + return errors.Wrap(err, "Error happened when parsing the query paramters") + } + rawValue := queryValues.Get(fromKey[0]) + if len(rawValue) == 0 { + return errors.Errorf("%s in query parameters not found", fromKey[0]) + } + setTarget(rv, to, toKey[0], rawValue) + return nil +} + +type headerMapper struct{} + +func (hm headerMapper) Map(mp config.MappingParam, c client.Request, rawTarget interface{}) error { + target, err := validateTarget(rawTarget) + if err != nil { + return err + } + _, fromKey, err := client.ParseMapSource(mp.Name) + if err != nil { + return err + } + to, toKey, err := client.ParseMapSource(mp.MapTo) + if err != nil { + return err + } + + rawHeader := c.IngressRequest.Header.Get(fromKey[0]) + if len(rawHeader) == 0 { + return errors.Errorf("Header %s not found", fromKey[0]) + } + setTarget(target, to, toKey[0], rawHeader) + return nil +} + +type bodyMapper struct{} + +func (bm bodyMapper) Map(mp config.MappingParam, c client.Request, rawTarget interface{}) error { + // TO-DO: add support for content-type other than application/json + target, err := validateTarget(rawTarget) + if err != nil { + return err + } + _, fromKey, err := client.ParseMapSource(mp.Name) + if err != nil { + return err + } + to, toKey, err := client.ParseMapSource(mp.MapTo) + if err != nil { + return err + } + + body, err := c.IngressRequest.GetBody() + if err != nil { + return err + } + rawBody, err := ioutil.ReadAll(body) + if err != nil { + return err + } + mapBody := map[string]interface{}{} + json.Unmarshal(rawBody, &mapBody) + val, err := client.GetMapValue(mapBody, fromKey) + + setTarget(target, to, strings.Join(toKey, constant.Dot), val) + return nil +} + +func validateTarget(target interface{}) (*requestParams, error) { + rv := reflect.ValueOf(target) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return nil, errors.New("Target params must be a non-nil pointer") + } + val, ok := target.(*requestParams) + if !ok { + return nil, errors.New("Target params must be a requestParams pointer") + } + return val, nil +} + +func setTarget(target *requestParams, to string, key string, val interface{}) error { + valType := reflect.TypeOf(val) + if (to == constant.Headers || to == constant.QueryStrings) && valType.Kind() != reflect.String { + return errors.Errorf("%s only accepts string", to) + } + switch to { + case constant.Headers: + target.Header.Set(key, val.(string)) + case constant.QueryStrings: + target.Query.Set(key, val.(string)) + case constant.RequestBody: + rawBody, err := ioutil.ReadAll(target.Body) + if err != nil { + return errors.New("Raw body parse failed") + } + mapBody := map[string]interface{}{} + json.Unmarshal(rawBody, &mapBody) + + setMapWithPath(mapBody, key, val) + rawBody, err = json.Marshal(mapBody) + if err != nil { + return errors.New("Stringify map to body failed") + } + target.Body = ioutil.NopCloser(bytes.NewReader(rawBody)) + default: + return errors.Errorf("Mapping target to %s does not support", to) + } + return nil +} + +// setMapWithPath set the value to the target map. If the origin targetMap has +// {"abc": "cde": {"f":1, "g":2}} and the path is abc, value is "123", then the +// targetMap will be updated to {"abc", "123"} +func setMapWithPath(targetMap map[string]interface{}, path string, val interface{}) map[string]interface{} { + keys := strings.Split(path, constant.Dot) + + _, ok := targetMap[keys[0]] + if len(keys) == 1 { + targetMap[keys[0]] = val + return targetMap + } + if !ok && len(keys) != 1 { + targetMap[keys[0]] = make(map[string]interface{}) + targetMap[keys[0]] = setMapWithPath(targetMap[keys[0]].(map[string]interface{}), strings.Join(keys[1:len(keys)], constant.Dot), val) + } + return targetMap +} diff --git a/pkg/client/http/mapper_test.go b/pkg/client/http/mapper_test.go new file mode 100644 index 000000000..1bc833cf5 --- /dev/null +++ b/pkg/client/http/mapper_test.go @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package http + +import ( + "bytes" + "context" + "io/ioutil" + "net/http" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/dubbogo/dubbo-go-proxy/pkg/client" + "github.com/dubbogo/dubbo-go-proxy/pkg/common/constant" + "github.com/dubbogo/dubbo-go-proxy/pkg/common/mock" + "github.com/dubbogo/dubbo-go-proxy/pkg/config" +) + +func TestQueryMapper(t *testing.T) { + qs := queryStringsMapper{} + r, _ := http.NewRequest("GET", "/mock/test?id=12345&age=19&name=joe&nickName=trump", bytes.NewReader([]byte(""))) + api := mock.GetMockAPI(config.MethodGet, "/mock/test") + api.IntegrationRequest.MappingParams = []config.MappingParam{ + { + Name: "queryStrings.id", + MapTo: "headers.Id", + }, + { + Name: "queryStrings.name", + MapTo: "queryStrings.name", + }, + { + Name: "queryStrings.age", + MapTo: "requestBody.age", + }, + { + Name: "queryStrings.nickName", + MapTo: "requestBody.nickName", + }, + } + req := client.NewReq(context.TODO(), r, api) + + target := newRequestParams() + err := qs.Map(api.IntegrationRequest.MappingParams[0], *req, target) + assert.Nil(t, err) + assert.Equal(t, target.Header.Get("Id"), "12345") + + err = qs.Map(api.IntegrationRequest.MappingParams[1], *req, target) + assert.Nil(t, err) + assert.Equal(t, target.Query.Get("name"), "joe") + + err = qs.Map(api.IntegrationRequest.MappingParams[2], *req, target) + assert.Nil(t, err) + err = qs.Map(api.IntegrationRequest.MappingParams[3], *req, target) + assert.Nil(t, err) + rawBody, _ := ioutil.ReadAll(target.Body) + assert.Equal(t, string(rawBody), "{\"age\":\"19\",\"nickName\":\"trump\"}") + + err = qs.Map(config.MappingParam{Name: "queryStrings.doesNotExistField", MapTo: "queryStrings.whatever"}, *req, target) + assert.EqualError(t, err, "doesNotExistField in query parameters not found") +} + +func TestHeaderMapper(t *testing.T) { + hm := headerMapper{} + r, _ := http.NewRequest("GET", "/mock/test?id=12345&age=19&name=joe&nickName=trump", bytes.NewReader([]byte(""))) + r.Header.Set("Auth", "xxxx12345xxx") + r.Header.Set("Token", "ttttt12345ttt") + r.Header.Set("Origin-Passcode", "whoseyourdaddy") + r.Header.Set("Pokemon-Name", "Pika") + api := mock.GetMockAPI(config.MethodGet, "/mock/test") + api.IntegrationRequest.MappingParams = []config.MappingParam{ + { + Name: "headers.Auth", + MapTo: "headers.Auth", + }, + { + Name: "headers.Token", + MapTo: "headers.Token", + }, + { + Name: "headers.Origin-Passcode", + MapTo: "queryStrings.originPasscode", + }, + { + Name: "headers.Pokemon-Name", + MapTo: "requestBody.pokeMonName", + }, + } + req := client.NewReq(context.TODO(), r, api) + + target := newRequestParams() + err := hm.Map(api.IntegrationRequest.MappingParams[0], *req, target) + assert.Nil(t, err) + assert.Equal(t, target.Header.Get("Auth"), "xxxx12345xxx") + err = hm.Map(api.IntegrationRequest.MappingParams[1], *req, target) + assert.Nil(t, err) + assert.Equal(t, target.Header.Get("Token"), "ttttt12345ttt") + + err = hm.Map(api.IntegrationRequest.MappingParams[2], *req, target) + assert.Nil(t, err) + assert.Equal(t, target.Query.Get("originPasscode"), "whoseyourdaddy") + + err = hm.Map(api.IntegrationRequest.MappingParams[3], *req, target) + assert.Nil(t, err) + rawBody, err := ioutil.ReadAll(target.Body) + assert.Nil(t, err) + assert.Equal(t, string(rawBody), "{\"pokeMonName\":\"Pika\"}") + + err = hm.Map(config.MappingParam{Name: "headers.doesNotExistField", MapTo: "headers.whatever"}, *req, target) + assert.EqualError(t, err, "Header doesNotExistField not found") +} + +func TestBodyMapper(t *testing.T) { + bm := bodyMapper{} + r, _ := http.NewRequest("POST", "/mock/test", bytes.NewReader([]byte("{\"id\":\"12345\",\"age\":\"19\",\"testStruct\":{\"name\":\"mock\",\"test\":\"happy\",\"nickName\":\"trump\"}}"))) + api := mock.GetMockAPI(config.MethodGet, "/mock/test") + api.IntegrationRequest.MappingParams = []config.MappingParam{ + { + Name: "requestBody.id", + MapTo: "headers.Id", + }, + { + Name: "requestBody.age", + MapTo: "requestBody.age", + }, + { + Name: "requestBody.testStruct", + MapTo: "requestBody.testStruct", + }, + { + Name: "requestBody.testStruct.nickName", + MapTo: "requestBody.nickName", + }, + } + req := client.NewReq(context.TODO(), r, api) + + target := newRequestParams() + err := bm.Map(api.IntegrationRequest.MappingParams[0], *req, target) + assert.Nil(t, err) + assert.Equal(t, target.Header.Get("Id"), "12345") + + target = newRequestParams() + err = bm.Map(api.IntegrationRequest.MappingParams[1], *req, target) + assert.Nil(t, err) + + err = bm.Map(api.IntegrationRequest.MappingParams[2], *req, target) + assert.Nil(t, err) + + err = bm.Map(api.IntegrationRequest.MappingParams[3], *req, target) + assert.Nil(t, err) + rawBody, err := ioutil.ReadAll(target.Body) + assert.Nil(t, err) + assert.Equal(t, string(rawBody), "{\"age\":\"19\",\"nickName\":\"trump\",\"testStruct\":{\"name\":\"mock\",\"nickName\":\"trump\",\"test\":\"happy\"}}") +} + +func TestSetTarget(t *testing.T) { + emptyRequestParams := newRequestParams() + err := setTarget(emptyRequestParams, constant.Headers, "Auth", "1234565") + assert.Nil(t, err) + assert.Equal(t, emptyRequestParams.Header.Get("Auth"), "1234565") + err = setTarget(emptyRequestParams, constant.Headers, "Auth", 1234565) + assert.EqualError(t, err, "headers only accepts string") + + err = setTarget(emptyRequestParams, constant.QueryStrings, "id", "123") + assert.Nil(t, err) + assert.Equal(t, emptyRequestParams.Query.Get("id"), "123") + err = setTarget(emptyRequestParams, constant.QueryStrings, "id", 123) + assert.EqualError(t, err, "queryStrings only accepts string") + + err = setTarget(emptyRequestParams, constant.RequestBody, "testStruct", map[string]interface{}{"test": "happy", "name": "mock"}) + assert.Nil(t, err) + rawBody, err := ioutil.ReadAll(emptyRequestParams.Body) + assert.Nil(t, err) + assert.Equal(t, string(rawBody), "{\"testStruct\":{\"name\":\"mock\",\"test\":\"happy\"}}") + + err = setTarget(emptyRequestParams, constant.RequestBody, "testStruct.secondLayer", map[string]interface{}{"test": "happy", "name": "mock"}) + assert.Nil(t, err) + rawBody, err = ioutil.ReadAll(emptyRequestParams.Body) + assert.Nil(t, err) + assert.Equal(t, string(rawBody), "{\"testStruct\":{\"secondLayer\":{\"name\":\"mock\",\"test\":\"happy\"}}}") + + err = setTarget(emptyRequestParams, constant.RequestBody, "testStruct.secondLayer.thirdLayer", map[string]interface{}{"test": "happy", "name": "mock"}) + assert.Nil(t, err) + rawBody, err = ioutil.ReadAll(emptyRequestParams.Body) + assert.Nil(t, err) + assert.Equal(t, string(rawBody), "{\"testStruct\":{\"secondLayer\":{\"thirdLayer\":{\"name\":\"mock\",\"test\":\"happy\"}}}}") + + nonEmptyRequestParams := newRequestParams() + nonEmptyRequestParams.Body = ioutil.NopCloser(bytes.NewReader([]byte("{\"testStruct\":\"abcde\"}"))) + err = setTarget(nonEmptyRequestParams, constant.RequestBody, "testStruct", map[string]interface{}{"test": "happy", "name": "mock"}) + assert.Nil(t, err) + rawBody, err = ioutil.ReadAll(nonEmptyRequestParams.Body) + assert.Nil(t, err) + assert.Equal(t, string(rawBody), "{\"testStruct\":{\"name\":\"mock\",\"test\":\"happy\"}}") + + nonEmptyRequestParams = newRequestParams() + nonEmptyRequestParams.Body = ioutil.NopCloser(bytes.NewReader([]byte("{\"otherStructure\":\"abcde\"}"))) + err = setTarget(nonEmptyRequestParams, constant.RequestBody, "testStruct", map[string]interface{}{"test": "happy", "name": "mock"}) + assert.Nil(t, err) + rawBody, err = ioutil.ReadAll(nonEmptyRequestParams.Body) + assert.Nil(t, err) + assert.Equal(t, string(rawBody), "{\"otherStructure\":\"abcde\",\"testStruct\":{\"name\":\"mock\",\"test\":\"happy\"}}") +} diff --git a/pkg/client/mapper.go b/pkg/client/mapper.go index 3acd5f847..672686cd3 100644 --- a/pkg/client/mapper.go +++ b/pkg/client/mapper.go @@ -28,6 +28,7 @@ import ( import ( "github.com/dubbogo/dubbo-go-proxy/pkg/config" + "reflect" ) // ParamMapper defines the interface about how to map the params in the inbound request. @@ -38,10 +39,26 @@ type ParamMapper interface { // ParseMapSource parses the source parameter config in the mappingParams // the source parameter in config could be queryStrings.*, headers.*, requestBody.* func ParseMapSource(source string) (from string, params []string, err error) { - reg := regexp.MustCompile(`^([queryStrings|headers|requestBody][\w|\d]+)\.([\w|\d|\.]+)$`) + reg := regexp.MustCompile(`^([queryStrings|headers|requestBody][\w|\d]+)\.([\w|\d|\.|\-]+)$`) if !reg.MatchString(source) { return "", nil, errors.New("Parameter mapping config incorrect. Please fix it") } ps := reg.FindStringSubmatch(source) return ps[1], strings.Split(ps[2], "."), nil } + +// GetMapValue return the value from map base on the path +func GetMapValue(sourceMap map[string]interface{}, keys []string) (interface{}, error) { + _, ok := sourceMap[keys[0]] + if !ok { + return nil, errors.Errorf("%s does not exist in request body", keys[0]) + } + rvalue := reflect.ValueOf(sourceMap[keys[0]]) + if ok && len(keys) == 1 { + return rvalue.Interface(), nil + } + if rvalue.Type().Kind() != reflect.Map { + return rvalue.Interface(), nil + } + return GetMapValue(sourceMap[keys[0]].(map[string]interface{}), keys[1:]) +} diff --git a/pkg/client/mapper_test.go b/pkg/client/mapper_test.go new file mode 100644 index 000000000..7c06f54e9 --- /dev/null +++ b/pkg/client/mapper_test.go @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package client + +import ( + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +func TestParseMapSource(t *testing.T) { + from, key, err := ParseMapSource("queryStrings.id") + assert.Nil(t, err) + assert.Equal(t, from, "queryStrings") + assert.Equal(t, key[0], "id") + + from, key, err = ParseMapSource("headers.id") + assert.Nil(t, err) + assert.Equal(t, from, "headers") + assert.Equal(t, key[0], "id") + + from, key, err = ParseMapSource("requestBody.user.id") + assert.Nil(t, err) + assert.Equal(t, from, "requestBody") + assert.Equal(t, key[0], "user") + assert.Equal(t, key[1], "id") + + from, key, err = ParseMapSource("what.user.id") + assert.EqualError(t, err, "Parameter mapping config incorrect. Please fix it") + from, key, err = ParseMapSource("requestBody.*userid") + assert.EqualError(t, err, "Parameter mapping config incorrect. Please fix it") +} + +func TestGetMapValue(t *testing.T) { + testMap := map[string]interface{}{ + "Test": "test", + "structure": map[string]interface{}{ + "name": "joe", + "age": 77, + }, + } + val, err := GetMapValue(testMap, []string{"Test"}) + assert.Nil(t, err) + assert.Equal(t, val, "test") + val, err = GetMapValue(testMap, []string{"test"}) + assert.Nil(t, val) + assert.EqualError(t, err, "test does not exist in request body") + val, err = GetMapValue(testMap, []string{"structure"}) + assert.Nil(t, err) + assert.Equal(t, val, testMap["structure"]) + val, err = GetMapValue(testMap, []string{"structure", "name"}) + assert.Nil(t, err) + assert.Equal(t, val, "joe") +} diff --git a/pkg/common/constant/url.go b/pkg/common/constant/url.go index dd396d53a..4d35b8afc 100644 --- a/pkg/common/constant/url.go +++ b/pkg/common/constant/url.go @@ -28,3 +28,14 @@ const ( // RetriesKey retry times RetriesKey = "retries" ) + +const ( + // RequestBody name of api config mapping from/to + RequestBody = "requestBody" + // QueryStrings name of api config mapping from/to + QueryStrings = "queryStrings" + // Headers name of api config mapping from/to + Headers = "headers" + // Dot defines the . which will be used to present the path to specific field in the body + Dot = "." +)