From f3e4d091ea9ca88c61b9bb4f666ab89d99817101 Mon Sep 17 00:00:00 2001 From: Ryan Date: Fri, 18 Jun 2021 15:52:31 +0800 Subject: [PATCH] fix(ratelimit): improve rate-limiter ip detection --- resolver/resolver_test.go | 99 +++++++++++++++++++++++++++++++++++++++ resolver/validator.go | 30 +++++++++++- 2 files changed, 128 insertions(+), 1 deletion(-) diff --git a/resolver/resolver_test.go b/resolver/resolver_test.go index 190a2584..b3aea5e8 100644 --- a/resolver/resolver_test.go +++ b/resolver/resolver_test.go @@ -507,4 +507,103 @@ var _ = Describe("Resolver", func() { Expect(resp).ShouldNot(Equal(jsonrpc.Response{})) }) + + It("should rate limit", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, validator, _ := init(ctx) + defer cleanup() + + params := testutils.MockBurnParamSubmitTxV0BTC() + paramsJSON, err := json.Marshal(params) + Expect(err).ShouldNot(HaveOccurred()) + Expect(params).ShouldNot(Equal([]byte{})) + + innerCtx, innerCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer innerCancel() + + ipString := "127.0.0.1" + + httpRequest := &http.Request{ + Header: map[string][]string{}, + RemoteAddr: ipString, + } + + req, resp := validator.ValidateRequest(innerCtx, httpRequest, jsonrpc.Request{ + Version: "2.0", + ID: nil, + Method: jsonrpc.MethodSubmitTx, + Params: paramsJSON, + }) + // Response will only exist for errors + Expect(resp).Should(Equal(jsonrpc.Response{})) + Expect((req).(*jsonrpc.ParamsSubmitTx).Tx.Hash).ShouldNot(BeEmpty()) + + Eventually(func() jsonrpc.Response { + _, resp := validator.ValidateRequest(innerCtx, httpRequest, jsonrpc.Request{ + Version: "2.0", + ID: nil, + Method: jsonrpc.MethodSubmitTx, + Params: paramsJSON, + }) + return resp + }).Should(Equal( + jsonrpc.NewResponse(nil, nil, &jsonrpc.Error{ + Code: jsonrpc.ErrorCodeInvalidRequest, + Message: fmt.Sprintf("rate limit exceeded for %v", ipString), + }), + )) + + ipString = "8.8.8.8" + httpRequest.Header.Add("x-forwarded-for", ipString) + Eventually(func() jsonrpc.Response { + _, resp := validator.ValidateRequest(innerCtx, httpRequest, jsonrpc.Request{ + Version: "2.0", + ID: nil, + Method: jsonrpc.MethodSubmitTx, + Params: paramsJSON, + }) + return resp + }).Should(Equal( + jsonrpc.NewResponse(nil, nil, &jsonrpc.Error{ + Code: jsonrpc.ErrorCodeInvalidRequest, + Message: fmt.Sprintf("rate limit exceeded for %v", ipString), + }), + )) + + ipString = "1.1.1.1,9.9.9.9" + httpRequest.Header.Set("x-forwarded-for", ipString) + Eventually(func() jsonrpc.Response { + _, resp := validator.ValidateRequest(innerCtx, httpRequest, jsonrpc.Request{ + Version: "2.0", + ID: nil, + Method: jsonrpc.MethodSubmitTx, + Params: paramsJSON, + }) + return resp + }).Should(Equal( + jsonrpc.NewResponse(nil, nil, &jsonrpc.Error{ + Code: jsonrpc.ErrorCodeInvalidRequest, + Message: fmt.Sprintf("rate limit exceeded for %v", "9.9.9.9"), + }), + )) + + ipString = "1.1.1.1,9.9.9.9,,," + httpRequest.Header.Set("x-forwarded-for", ipString) + Eventually(func() jsonrpc.Response { + _, resp := validator.ValidateRequest(innerCtx, httpRequest, jsonrpc.Request{ + Version: "2.0", + ID: nil, + Method: jsonrpc.MethodSubmitTx, + Params: paramsJSON, + }) + return resp + }).Should(Equal( + jsonrpc.NewResponse(nil, nil, &jsonrpc.Error{ + Code: jsonrpc.ErrorCodeInvalidRequest, + Message: fmt.Sprintf("could not determine ip for %v", ipString), + }), + )) + }) }) diff --git a/resolver/validator.go b/resolver/validator.go index a6de0902..6ccb2221 100644 --- a/resolver/validator.go +++ b/resolver/validator.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/http" + "strings" "github.com/renproject/darknode/binding" "github.com/renproject/darknode/jsonrpc" @@ -37,17 +38,44 @@ func NewValidator(bindings binding.Bindings, pubkey *id.PubKey, store v0.CompatS // The validator usually checks if the params are in the correct shape for a given method // We override the checker for certain methods here to cast invalid v0 params into v1 versions func (validator *LightnodeValidator) ValidateRequest(ctx context.Context, r *http.Request, req jsonrpc.Request) (interface{}, jsonrpc.Response) { + // We rate limit in the validator, as it is the earliest entry point we can hook into + // for range ipString := r.Header.Get("x-forwarded-for") if ipString == "" { ipString = r.RemoteAddr + } else if ipStrings := strings.Split(ipString, ","); len(ipStrings) > 0 { + ipString = ipStrings[len(ipStrings)-1] + // if there is a trailling comma, or the x-forwarded-for header is malformed, + // skip parsing + if ipString == "" { + return nil, jsonrpc.NewResponse(req.ID, nil, &jsonrpc.Error{ + Code: jsonrpc.ErrorCodeInvalidRequest, + Message: fmt.Sprintf("could not determine ip for %v", strings.Join(ipStrings, ",")), + }) + } } ip := net.ParseIP(ipString) + // If we fail to parse a "plain" ip, we check if it is in host:port format + // This can't be done in an easy split manner due to ipv6. + // We also skip requiring an ip if we haven't picked up a string yet to + // allow for testing, as we should always have a value from r.RemoteAddr + // in an actual server + if ip == nil && ipString != "" { + ip2, _, err := net.SplitHostPort(ipString) + ip = net.ParseIP(ip2) + if err != nil { + return nil, jsonrpc.NewResponse(req.ID, nil, &jsonrpc.Error{ + Code: jsonrpc.ErrorCodeInvalidRequest, + Message: fmt.Sprintf("could not determine ip for %v", ipString), + }) + } + } if !(validator.limiter.Allow(req.Method, net.IP(ip))) { validator.logger.Warn("Rate limit exceeded for ip:", ipString) return nil, jsonrpc.NewResponse(req.ID, nil, &jsonrpc.Error{ Code: jsonrpc.ErrorCodeInvalidRequest, - Message: fmt.Sprintf("rate limit exceeded"), + Message: fmt.Sprintf("rate limit exceeded for %v", ipString), }) }