diff --git a/cmd/main.go b/cmd/main.go index b503976e..83f2baee 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -115,7 +115,7 @@ func run(ctx context.Context, logger *logrus.Logger) error { logger.Info("Try to identify WAF solution") - name, vendor, err := detector.DetectWAF(ctx) + name, vendor, checkFunc, err := detector.DetectWAF(ctx) if err != nil { return errors.Wrap(err, "couldn't detect") } @@ -126,6 +126,7 @@ func run(ctx context.Context, logger *logrus.Logger) error { "vendor": vendor, }).Info("WAF was identified. Force enabling `--followCookies' and `--renewSession' options") + cfg.CheckBlockFunc = checkFunc cfg.FollowCookies = true cfg.RenewSession = true cfg.WAFName = fmt.Sprintf("%s (%s)", name, vendor) diff --git a/internal/config/config.go b/internal/config/config.go index 389dd61f..311336c1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,5 +1,7 @@ package config +import "github.com/wallarm/gotestwaf/internal/scanner/detectors" + type Config struct { URL string `mapstructure:"url"` WebSocketURL string `mapstructure:"wsURL"` @@ -37,4 +39,6 @@ type Config struct { AddHeader string `mapstructure:"addHeader"` AddDebugHeader bool `mapstructure:"addDebugHeader"` OpenAPIFile string `mapstructure:"openapiFile"` + + CheckBlockFunc detectors.Check } diff --git a/internal/scanner/detector.go b/internal/scanner/detector.go index 7439ca1b..faafefda 100644 --- a/internal/scanner/detector.go +++ b/internal/scanner/detector.go @@ -25,19 +25,28 @@ const ( ) type WAFDetector struct { - client *http.Client - headers map[string]string - hostHeader string - target string + clientSettings *ClientSettings + headers map[string]string + hostHeader string + target string +} + +type ClientSettings struct { + dnsResolver *dnscache.Resolver + insecureSkipVerify bool + idleConnTimeout time.Duration + maxIdleConns int + maxIdleConnsPerHost int + proxyURL *url.URL } func NewDetector(cfg *config.Config, dnsResolver *dnscache.Resolver) (*WAFDetector, error) { - tr := &http.Transport{ - DialContext: dnscache.DialFunc(dnsResolver, nil), - TLSClientConfig: &tls.Config{InsecureSkipVerify: !cfg.TLSVerify}, - IdleConnTimeout: time.Duration(cfg.IdleConnTimeout) * time.Second, - MaxIdleConns: cfg.MaxIdleConns, - MaxIdleConnsPerHost: cfg.MaxIdleConns, // net.http hardcodes DefaultMaxIdleConnsPerHost to 2! + clientSettings := &ClientSettings{ + dnsResolver: dnsResolver, + insecureSkipVerify: !cfg.TLSVerify, + idleConnTimeout: time.Duration(cfg.IdleConnTimeout) * time.Second, + maxIdleConns: cfg.MaxIdleConns, + maxIdleConnsPerHost: cfg.MaxIdleConns, } if cfg.Proxy != "" { @@ -46,17 +55,7 @@ func NewDetector(cfg *config.Config, dnsResolver *dnscache.Resolver) (*WAFDetect return nil, errors.Wrap(err, "couldn't parse proxy URL") } - tr.Proxy = http.ProxyURL(proxyURL) - } - - jar, err := cookiejar.New(nil) - if err != nil { - return nil, errors.Wrap(err, "couldn't create cookie jar") - } - - client := &http.Client{ - Transport: tr, - Jar: jar, + clientSettings.proxyURL = proxyURL } target, err := url.Parse(cfg.URL) @@ -73,20 +72,71 @@ func NewDetector(cfg *config.Config, dnsResolver *dnscache.Resolver) (*WAFDetect } return &WAFDetector{ - client: client, - headers: configuredHeaders, - hostHeader: configuredHeaders["Host"], - target: GetTargetURLStr(target), + clientSettings: clientSettings, + headers: configuredHeaders, + hostHeader: configuredHeaders["Host"], + target: GetTargetURLStr(target), }, nil } -// doRequest sends HTTP-request with malicious payload to trigger WAF. +func (w *WAFDetector) getHttpClient() (*http.Client, error) { + tr := &http.Transport{ + DialContext: dnscache.DialFunc(w.clientSettings.dnsResolver, nil), + TLSClientConfig: &tls.Config{InsecureSkipVerify: w.clientSettings.insecureSkipVerify}, + IdleConnTimeout: w.clientSettings.idleConnTimeout, + MaxIdleConns: w.clientSettings.maxIdleConns, + MaxIdleConnsPerHost: w.clientSettings.maxIdleConns, // net.http hardcodes DefaultMaxIdleConnsPerHost to 2! + } + + if w.clientSettings.proxyURL != nil { + tr.Proxy = http.ProxyURL(w.clientSettings.proxyURL) + } + + jar, err := cookiejar.New(nil) + if err != nil { + return nil, errors.Wrap(err, "couldn't create cookie jar") + } + + client := &http.Client{ + Transport: tr, + Jar: jar, + } + + return client, nil +} + +// doRequest sends HTTP-request without malicious payload to trigger WAF. func (w *WAFDetector) doRequest(ctx context.Context) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, w.target, nil) if err != nil { return nil, errors.Wrap(err, "couldn't create request") } + for header, value := range w.headers { + req.Header.Set(header, value) + } + req.Host = w.hostHeader + + client, err := w.getHttpClient() + if err != nil { + return nil, errors.Wrap(err, "couldn't create HTTP client") + } + + resp, err := client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to sent request") + } + + return resp, nil +} + +// doMaliciousRequest sends HTTP-request with malicious payload to trigger WAF. +func (w *WAFDetector) doMaliciousRequest(ctx context.Context) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, w.target, nil) + if err != nil { + return nil, errors.Wrap(err, "couldn't create request") + } + queryParams := req.URL.Query() queryParams.Add("a", xssPayload) queryParams.Add("b", sqliPayload) @@ -101,7 +151,12 @@ func (w *WAFDetector) doRequest(ctx context.Context) (*http.Response, error) { } req.Host = w.hostHeader - resp, err := w.client.Do(req) + client, err := w.getHttpClient() + if err != nil { + return nil, errors.Wrap(err, "couldn't create HTTP client") + } + + resp, err := client.Do(req) if err != nil { return nil, errors.Wrap(err, "failed to sent request") } @@ -111,19 +166,31 @@ func (w *WAFDetector) doRequest(ctx context.Context) (*http.Response, error) { // DetectWAF performs WAF identification. Returns WAF name and vendor after // the first positive match. -func (w *WAFDetector) DetectWAF(ctx context.Context) (name, vendor string, err error) { +func (w *WAFDetector) DetectWAF(ctx context.Context) (name, vendor string, checkFunc detectors.Check, err error) { resp, err := w.doRequest(ctx) if err != nil { - return "", "", errors.Wrap(err, "couldn't identify WAF") + return "", "", nil, errors.Wrap(err, "couldn't perform request without attack") } defer resp.Body.Close() + respToAttack, err := w.doMaliciousRequest(ctx) + if err != nil { + return "", "", nil, errors.Wrap(err, "couldn't perform request with attack") + } + + defer respToAttack.Body.Close() + + resps := &detectors.Responses{ + Resp: resp, + RespToAttack: respToAttack, + } + for _, d := range detectors.Detectors { - if d.IsWAF(resp) { - return d.WAFName, d.Vendor, nil + if d.IsWAF(resps) { + return d.WAFName, d.Vendor, d.Check, nil } } - return "", "", nil + return "", "", nil, nil } diff --git a/internal/scanner/detectors/akamai.go b/internal/scanner/detectors/akamai.go index d3d2998b..b9d7b3af 100644 --- a/internal/scanner/detectors/akamai.go +++ b/internal/scanner/detectors/akamai.go @@ -6,9 +6,10 @@ func KonaSiteDefender() *Detector { Vendor: "Akamai", } - d.Checks = []Check{ - CheckHeader("Server", "AkamaiGHost"), - } + d.Check = Or( + CheckHeader("Server", "AkamaiGHost", false), + CheckHeader("Server", "AkamaiGHost", true), + ) return d } diff --git a/internal/scanner/detectors/checks.go b/internal/scanner/detectors/checks.go index c1e23e08..ff92a6ef 100644 --- a/internal/scanner/detectors/checks.go +++ b/internal/scanner/detectors/checks.go @@ -5,14 +5,30 @@ import ( "io" "net/http" "regexp" + "strings" ) +type Responses struct { + Resp *http.Response + RespToAttack *http.Response +} + // Check performs some check on the response with a fixed condition. -type Check func(resp *http.Response) bool +type Check func(resps *Responses) bool // CheckStatusCode compare response status code with given value. -func CheckStatusCode(status int) Check { - f := func(resp *http.Response) bool { +// Default value for attack parameter is true. +func CheckStatusCode(status int, attack bool) Check { + f := func(resps *Responses) bool { + resp := resps.Resp + if attack { + resp = resps.RespToAttack + } + + if resp == nil { + return false + } + if resp.StatusCode == status { return true } @@ -23,11 +39,49 @@ func CheckStatusCode(status int) Check { return f } +// CheckReason match status reason value with regex. +// Default value for attack parameter is true. +func CheckReason(regex string, attack bool) Check { + re := regexp.MustCompile(regex) + + f := func(resps *Responses) bool { + resp := resps.Resp + if attack { + resp = resps.RespToAttack + } + + if resp == nil { + return false + } + + reasonIndex := strings.Index(resp.Status, " ") + reason := resp.Status[reasonIndex+1:] + + if re.MatchString(reason) { + return true + } + + return false + } + + return f +} + // CheckHeader match header value with regex. -func CheckHeader(header, regex string) Check { +// Default value for attack parameter is false. +func CheckHeader(header, regex string, attack bool) Check { re := regexp.MustCompile(regex) - f := func(resp *http.Response) bool { + f := func(resps *Responses) bool { + resp := resps.Resp + if attack { + resp = resps.RespToAttack + } + + if resp == nil { + return false + } + values := resp.Header.Values(header) if values == nil { return false @@ -46,15 +100,26 @@ func CheckHeader(header, regex string) Check { } // CheckCookie match Set-Cookie header values with regex. -func CheckCookie(regex string) Check { - return CheckHeader("Set-Cookie", regex) +// Default value for attack parameter is false. +func CheckCookie(regex string, attack bool) Check { + return CheckHeader("Set-Cookie", regex, attack) } // CheckContent match body value with regex. -func CheckContent(regex string) Check { +// Default value for attack parameter is true. +func CheckContent(regex string, attack bool) Check { re := regexp.MustCompile(regex) - f := func(resp *http.Response) bool { + f := func(resps *Responses) bool { + resp := resps.Resp + if attack { + resp = resps.RespToAttack + } + + if resp == nil { + return false + } + body, err := io.ReadAll(resp.Body) if err != nil { return false @@ -73,3 +138,35 @@ func CheckContent(regex string) Check { return f } + +// And combines the checks with AND logic, +// so each test must be true to return true. +func And(checks ...Check) Check { + f := func(resps *Responses) bool { + for _, check := range checks { + if !check(resps) { + return false + } + } + + return true + } + + return f +} + +// Or combines the checks with OR logic, +// so at least one test must be true to return true. +func Or(checks ...Check) Check { + f := func(resps *Responses) bool { + for _, check := range checks { + if check(resps) { + return true + } + } + + return false + } + + return f +} diff --git a/internal/scanner/detectors/detectors.go b/internal/scanner/detectors/detectors.go index 373be9c0..9f2a589b 100644 --- a/internal/scanner/detectors/detectors.go +++ b/internal/scanner/detectors/detectors.go @@ -1,14 +1,12 @@ package detectors -import "net/http" - // Detector contains names of WAF solution and vendor, and checks to detect that // solution by response. type Detector struct { WAFName string Vendor string - Checks []Check + Check Check } func (d *Detector) GetWAFName() string { @@ -19,14 +17,8 @@ func (d *Detector) GetVendor() string { return d.Vendor } -func (d *Detector) IsWAF(resp *http.Response) bool { - for _, check := range d.Checks { - if check(resp) { - return true - } - } - - return false +func (d *Detector) IsWAF(resps *Responses) bool { + return d.Check(resps) } // Detectors is the list of all available WAF detectors. The checks are performed @@ -38,4 +30,13 @@ var Detectors = []*Detector{ // Imperva Incapsula(), SecureSphere(), + + // F5 Networks + BigIPAppSecManager(), + BigIPLocalTrafficManager(), + BigIPApManager(), + FirePass(), + Trafficshield(), + + ModSecurity(), } diff --git a/internal/scanner/detectors/f5networks.go b/internal/scanner/detectors/f5networks.go new file mode 100644 index 00000000..1a2ee552 --- /dev/null +++ b/internal/scanner/detectors/f5networks.go @@ -0,0 +1,91 @@ +package detectors + +func BigIPAppSecManager() *Detector { + d := &Detector{ + WAFName: "BIG-IP AppSec Manager", + Vendor: "F5 Networks", + } + + d.Check = Or( + And( + CheckContent("the requested url was rejected", true), + CheckContent("please consult with your administrator", true), + ), + CheckCookie("^TS.+?", false), + CheckContent("Reference ID", true), + ) + + return d +} + +func BigIPLocalTrafficManager() *Detector { + d := &Detector{ + WAFName: "BIG-IP Local Traffic Manager", + Vendor: "F5 Networks", + } + + d.Check = Or( + CheckCookie("^bigipserver", false), + CheckHeader("X-Cnection", "close", true), + ) + + return d +} + +func BigIPApManager() *Detector { + d := &Detector{ + WAFName: "BIG-IP AP Manager", + Vendor: "F5 Networks", + } + + d.Check = Or( + And( + CheckCookie("^LastMRH_Session", false), + CheckCookie("^MRHSession", false), + ), + And( + CheckCookie("^MRHSession", false), + CheckHeader("Server", "Big([-_])?IP", true), + ), + Or( + CheckCookie("^F5_fullWT", false), + CheckCookie("^F5_HT_shrinked", false), + ), + ) + + return d +} + +func FirePass() *Detector { + d := &Detector{ + WAFName: "FirePass", + Vendor: "F5 Networks", + } + + d.Check = Or( + And( + CheckCookie("^VHOST", false), + CheckHeader("Location", `\/my\.logon\.php3`, false), + ), + And( + CheckCookie("^F5_fire.+?", false), + CheckCookie("^F5_passid_shrinked", false), + ), + ) + + return d +} + +func Trafficshield() *Detector { + d := &Detector{ + WAFName: "Trafficshield", + Vendor: "F5 Networks", + } + + d.Check = Or( + CheckCookie("^ASINFO=", false), + CheckHeader("Server", "F5-TrafficShield", false), + ) + + return d +} diff --git a/internal/scanner/detectors/imperva.go b/internal/scanner/detectors/imperva.go index a8f5abcb..9cdfc575 100644 --- a/internal/scanner/detectors/imperva.go +++ b/internal/scanner/detectors/imperva.go @@ -6,12 +6,12 @@ func SecureSphere() *Detector { Vendor: "Imperva Inc.", } - d.Checks = []Check{ - CheckContent("<(title|h2)>Error"), - CheckContent("The incident ID is"), - CheckContent("This page can't be displayed"), - CheckContent("Contact support for additional information"), - } + d.Check = And( + CheckContent("<(title|h2)>Error", true), + CheckContent("The incident ID is", true), + CheckContent("This page can't be displayed", true), + CheckContent("Contact support for additional information", true), + ) return d } @@ -22,13 +22,13 @@ func Incapsula() *Detector { Vendor: "Imperva Inc.", } - d.Checks = []Check{ - CheckCookie("^incap_ses.*?="), - CheckCookie("^visid_incap.*?="), - CheckContent("incapsula incident id"), - CheckContent("powered by incapsula"), - CheckContent("/_Incapsula_Resource"), - } + d.Check = Or( + CheckCookie("^incap_ses.*?=", false), + CheckCookie("^visid_incap.*?=", false), + CheckContent("incapsula incident id", true), + CheckContent("powered by incapsula", true), + CheckContent("/_Incapsula_Resource", true), + ) return d } diff --git a/internal/scanner/detectors/modsec.go b/internal/scanner/detectors/modsec.go new file mode 100644 index 00000000..912b774b --- /dev/null +++ b/internal/scanner/detectors/modsec.go @@ -0,0 +1,30 @@ +package detectors + +func ModSecurity() *Detector { + d := &Detector{ + WAFName: "ModSecurity", + Vendor: "OWASP", + } + + d.Check = Or( + Or( + CheckHeader("Server", "(mod_security|Mod_Security|NOYB)", false), + CheckContent("This error was generated by Mod.?Security", true), + CheckContent("rules of the mod.security.module", true), + CheckContent("mod.security.rules triggered", true), + CheckContent("Protected by Mod.?Security", true), + CheckContent(`/modsecurity[\-_]errorpage/`, true), + CheckContent("modsecurity iis", true), + ), + And( + CheckReason("ModSecurity Action", true), + CheckStatusCode(403, true), + ), + And( + CheckReason("ModSecurity Action", true), + CheckStatusCode(406, true), + ), + ) + + return d +} diff --git a/internal/scanner/http_client.go b/internal/scanner/http_client.go index 1df1ed9b..5b87043c 100644 --- a/internal/scanner/http_client.go +++ b/internal/scanner/http_client.go @@ -1,6 +1,7 @@ package scanner import ( + "bytes" "context" "crypto/tls" "io" @@ -107,15 +108,15 @@ func (c *HTTPClient) SendPayload( placeholderName string, placeholderConfig placeholder.PlaceholderConfig, testHeaderValue string, -) (responseMsgHeader string, responseBody string, statusCode int, err error) { +) (resp *http.Response, responseMsgHeader string, responseBody string, statusCode int, err error) { encodedPayload, err := encoder.Apply(encoderName, payload) if err != nil { - return "", "", 0, errors.Wrap(err, "encoding payload") + return nil, "", "", 0, errors.Wrap(err, "encoding payload") } req, err := placeholder.Apply(targetURL, encodedPayload, placeholderName, placeholderConfig) if err != nil { - return "", "", 0, errors.Wrap(err, "apply placeholder") + return nil, "", "", 0, errors.Wrap(err, "apply placeholder") } req = req.WithContext(ctx) @@ -143,7 +144,7 @@ func (c *HTTPClient) SendPayload( if c.followCookies && c.renewSession { cookies, err := c.getCookies(ctx, targetURL) if err != nil { - return "", "", 0, errors.Wrap(err, "couldn't get cookies for malicious request") + return nil, "", "", 0, errors.Wrap(err, "couldn't get cookies for malicious request") } for _, cookie := range cookies { @@ -151,28 +152,32 @@ func (c *HTTPClient) SendPayload( } } - resp, err := c.client.Do(req) + resp, err = c.client.Do(req) if err != nil { - return "", "", 0, errors.Wrap(err, "sending http request") + return nil, "", "", 0, errors.Wrap(err, "sending http request") } - defer resp.Body.Close() msgHeader, err := httputil.DumpResponse(resp, false) if err != nil { - return "", "", 0, errors.Wrap(err, "dumping http response") + return nil, "", "", 0, errors.Wrap(err, "dumping http response") } bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return "", "", 0, errors.Wrap(err, "reading response body") + return nil, "", "", 0, errors.Wrap(err, "reading response body") } + + // body reuse + resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + statusCode = resp.StatusCode if c.followCookies && !c.renewSession && c.client.Jar != nil { c.client.Jar.SetCookies(req.URL, resp.Cookies()) } - return string(msgHeader), string(bodyBytes), statusCode, nil + return resp, string(msgHeader), string(bodyBytes), statusCode, nil } func (c *HTTPClient) SendRequest( @@ -181,6 +186,7 @@ func (c *HTTPClient) SendRequest( followCookiesOverride *bool, renewSessionOverride *bool, ) ( + resp *http.Response, respHeaders http.Header, responseMsgHeader string, body string, @@ -209,7 +215,7 @@ func (c *HTTPClient) SendRequest( if followCookies && renewSession { cookies, err := c.getCookies(req.Context(), GetTargetURLStr(req.URL)) if err != nil { - return nil, "", "", 0, errors.Wrap(err, "couldn't get cookies for malicious request") + return nil, nil, "", "", 0, errors.Wrap(err, "couldn't get cookies for malicious request") } for _, cookie := range cookies { @@ -217,28 +223,32 @@ func (c *HTTPClient) SendRequest( } } - resp, err := c.client.Do(req) + resp, err = c.client.Do(req) if err != nil { - return nil, "", "", 0, errors.Wrap(err, "sending http request") + return nil, nil, "", "", 0, errors.Wrap(err, "sending http request") } - defer resp.Body.Close() msgHeader, err := httputil.DumpResponse(resp, false) if err != nil { - return nil, "", "", 0, errors.Wrap(err, "dumping http response") + return nil, nil, "", "", 0, errors.Wrap(err, "dumping http response") } bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return nil, "", "", 0, errors.Wrap(err, "reading response body") + return nil, nil, "", "", 0, errors.Wrap(err, "reading response body") } + + // body reuse + resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + statusCode = resp.StatusCode if followCookies && !renewSession && c.client.Jar != nil { c.client.Jar.SetCookies(req.URL, resp.Cookies()) } - return resp.Header, string(msgHeader), string(bodyBytes), statusCode, nil + return resp, resp.Header, string(msgHeader), string(bodyBytes), statusCode, nil } func (c *HTTPClient) getCookies(ctx context.Context, targetURL string) ([]*http.Cookie, error) { diff --git a/internal/scanner/scanner.go b/internal/scanner/scanner.go index 7755c639..3381dacc 100644 --- a/internal/scanner/scanner.go +++ b/internal/scanner/scanner.go @@ -18,6 +18,8 @@ import ( "syscall" "time" + "github.com/wallarm/gotestwaf/internal/scanner/detectors" + "github.com/getkin/kin-openapi/openapi3filter" "github.com/getkin/kin-openapi/routers" "github.com/gorilla/websocket" @@ -117,11 +119,13 @@ func (s *Scanner) CheckIfJavaScriptRequired(ctx context.Context) (bool, error) { return &b } - _, _, body, _, err := s.httpClient.SendRequest(rawRequest, "", getRef(true), getRef(true)) + resp, _, _, body, _, err := s.httpClient.SendRequest(rawRequest, "", getRef(true), getRef(true)) if err != nil { return false, err } + defer resp.Body.Close() + for i := range jsChallengeErrorMsgs { if strings.Contains(body, jsChallengeErrorMsgs[i]) { return true, nil @@ -199,11 +203,11 @@ func (s *Scanner) WAFBlockCheck(ctx context.Context) error { // preCheck sends given payload during the pre-check stage. func (s *Scanner) preCheck(ctx context.Context, payload string) (blocked bool, statusCode int, err error) { - respMsgHeader, respBody, code, err := s.httpClient.SendPayload(ctx, s.cfg.URL, payload, "URL", "URLParam", nil, "") + resp, respMsgHeader, respBody, code, err := s.httpClient.SendPayload(ctx, s.cfg.URL, payload, "URL", "URLParam", nil, "") if err != nil { return false, 0, err } - blocked, err = s.checkBlocking(respMsgHeader, respBody, code) + blocked, _, err = s.checkBlockedOrPassed(resp, respMsgHeader, respBody, code) if err != nil { return false, 0, err } @@ -398,9 +402,20 @@ func (s *Scanner) Run(ctx context.Context) error { return nil } -// checkBlocking checks the response status-code or request body using -// a regular expression to determine if the request has been blocked. -func (s *Scanner) checkBlocking(responseMsgHeader, body string, statusCode int) (bool, error) { +// checkBlockedOrPassed checks the response status-code or request body using +// a regular expression to determine if the request has been blocked or passed. +func (s *Scanner) checkBlockedOrPassed( + resp *http.Response, + responseMsgHeader, + body string, + statusCode int, +) (blocked, passed bool, err error) { + if s.cfg.CheckBlockFunc != nil { + if s.cfg.CheckBlockFunc(&detectors.Responses{RespToAttack: resp}) { + return true, false, nil + } + } + if s.cfg.BlockRegex != "" { response := body if responseMsgHeader != "" { @@ -408,24 +423,12 @@ func (s *Scanner) checkBlocking(responseMsgHeader, body string, statusCode int) } if response != "" { - m, _ := regexp.MatchString(s.cfg.BlockRegex, response) + matched, _ := regexp.MatchString(s.cfg.BlockRegex, response) - return m, nil + blocked = matched } } - for _, code := range s.cfg.BlockStatusCodes { - if statusCode == code { - return true, nil - } - } - - return false, nil -} - -// checkPass checks the response status-code or request body using -// a regular expression to determine if the request has been passed. -func (s *Scanner) checkPass(responseMsgHeader, body string, statusCode int) (bool, error) { if s.cfg.PassRegex != "" { response := body if responseMsgHeader != "" { @@ -433,19 +436,25 @@ func (s *Scanner) checkPass(responseMsgHeader, body string, statusCode int) (boo } if response != "" { - m, _ := regexp.MatchString(s.cfg.PassRegex, response) + matched, _ := regexp.MatchString(s.cfg.PassRegex, response) - return m, nil + passed = matched + } + } + + for _, code := range s.cfg.BlockStatusCodes { + if statusCode == code { + blocked = true } } for _, code := range s.cfg.PassStatusCodes { if statusCode == code { - return true, nil + passed = true } } - return false, nil + return } // produceTests generates all combinations of payload, encoder, and placeholder @@ -508,6 +517,7 @@ func (s *Scanner) produceTests(ctx context.Context, n int) <-chan *testWork { // placeholder. func (s *Scanner) scanURL(ctx context.Context, w *testWork) error { var ( + resp *http.Response respHeaders http.Header respMsgHeader string respBody string @@ -527,16 +537,20 @@ func (s *Scanner) scanURL(ctx context.Context, w *testWork) error { respBody, statusCode, err = s.grpcConn.Send(newCtx, w.encoder, w.payload) - _, _, _, _, err = s.updateDB(ctx, w, nil, nil, nil, nil, nil, + _, _, _, _, err = s.updateDB(ctx, w, nil, nil, nil, nil, nil, nil, statusCode, nil, "", respBody, err, "", true) return err } if s.requestTemplates == nil { - respMsgHeader, respBody, statusCode, err = s.httpClient.SendPayload(ctx, s.cfg.URL, w.payload, w.encoder, w.placeholder.Name, w.placeholder.Config, w.debugHeaderValue) + resp, respMsgHeader, respBody, statusCode, err = s.httpClient.SendPayload(ctx, s.cfg.URL, w.payload, w.encoder, w.placeholder.Name, w.placeholder.Config, w.debugHeaderValue) - _, _, _, _, err = s.updateDB(ctx, w, nil, nil, nil, nil, nil, + if resp != nil { + defer resp.Body.Close() + } + + _, _, _, _, err = s.updateDB(ctx, w, nil, nil, nil, nil, nil, resp, statusCode, nil, respMsgHeader, respBody, err, "", false) return err @@ -561,13 +575,15 @@ func (s *Scanner) scanURL(ctx context.Context, w *testWork) error { return errors.Wrap(err, "create request from template") } - respHeaders, respMsgHeader, respBody, statusCode, err = s.httpClient.SendRequest(req, w.debugHeaderValue, nil, nil) + resp, respHeaders, respMsgHeader, respBody, statusCode, err = s.httpClient.SendRequest(req, w.debugHeaderValue, nil, nil) additionalInfo = fmt.Sprintf("%s %s", template.Method, template.Path) passedTest, blockedTest, unresolvedTest, failedTest, err = s.updateDB(ctx, w, passedTest, blockedTest, unresolvedTest, failedTest, - req, statusCode, respHeaders, respMsgHeader, respBody, err, additionalInfo, false) + req, resp, statusCode, respHeaders, respMsgHeader, respBody, err, additionalInfo, false) + + resp.Body.Close() s.db.AddToScannedPaths(template.Method, template.Path) @@ -588,6 +604,7 @@ func (s *Scanner) updateDB( unresolvedTest *db.Info, failedTest *db.Info, req *http.Request, + resp *http.Response, respStatusCode int, respHeaders http.Header, respMsgHeader string, @@ -644,17 +661,11 @@ func (s *Scanner) updateDB( if blockedByReset { blocked = true } else { - blocked, err = s.checkBlocking(respMsgHeader, respBody, respStatusCode) + blocked, passed, err = s.checkBlockedOrPassed(resp, respMsgHeader, respBody, respStatusCode) if err != nil { return nil, nil, nil, nil, errors.Wrap(err, "failed to check blocking") } - - passed, err = s.checkPass(respMsgHeader, respBody, respStatusCode) - if err != nil { - return nil, nil, nil, nil, - errors.Wrap(err, "failed to check passed or not") - } } if s.requestTemplates != nil && !isGRPC {