Skip to content

Commit

Permalink
Improve WAF detection (#241)
Browse files Browse the repository at this point in the history
  • Loading branch information
svkirillov committed May 8, 2024
1 parent 561e801 commit 4bf549f
Show file tree
Hide file tree
Showing 11 changed files with 435 additions and 122 deletions.
3 changes: 2 additions & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func run(ctx context.Context, logger *logrus.Logger) error {

logger.Info("Try to identify WAF solution")

name, vendor, err := detector.DetectWAF(ctx)
name, vendor, checkFunc, err := detector.DetectWAF(ctx)
if err != nil {
return errors.Wrap(err, "couldn't detect")
}
Expand All @@ -126,6 +126,7 @@ func run(ctx context.Context, logger *logrus.Logger) error {
"vendor": vendor,
}).Info("WAF was identified. Force enabling `--followCookies' and `--renewSession' options")

cfg.CheckBlockFunc = checkFunc
cfg.FollowCookies = true
cfg.RenewSession = true
cfg.WAFName = fmt.Sprintf("%s (%s)", name, vendor)
Expand Down
4 changes: 4 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package config

import "github.com/wallarm/gotestwaf/internal/scanner/detectors"

type Config struct {
URL string `mapstructure:"url"`
WebSocketURL string `mapstructure:"wsURL"`
Expand Down Expand Up @@ -37,4 +39,6 @@ type Config struct {
AddHeader string `mapstructure:"addHeader"`
AddDebugHeader bool `mapstructure:"addDebugHeader"`
OpenAPIFile string `mapstructure:"openapiFile"`

CheckBlockFunc detectors.Check
}
131 changes: 99 additions & 32 deletions internal/scanner/detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,28 @@ const (
)

type WAFDetector struct {
client *http.Client
headers map[string]string
hostHeader string
target string
clientSettings *ClientSettings
headers map[string]string
hostHeader string
target string
}

type ClientSettings struct {
dnsResolver *dnscache.Resolver
insecureSkipVerify bool
idleConnTimeout time.Duration
maxIdleConns int
maxIdleConnsPerHost int
proxyURL *url.URL
}

func NewDetector(cfg *config.Config, dnsResolver *dnscache.Resolver) (*WAFDetector, error) {
tr := &http.Transport{
DialContext: dnscache.DialFunc(dnsResolver, nil),
TLSClientConfig: &tls.Config{InsecureSkipVerify: !cfg.TLSVerify},
IdleConnTimeout: time.Duration(cfg.IdleConnTimeout) * time.Second,
MaxIdleConns: cfg.MaxIdleConns,
MaxIdleConnsPerHost: cfg.MaxIdleConns, // net.http hardcodes DefaultMaxIdleConnsPerHost to 2!
clientSettings := &ClientSettings{
dnsResolver: dnsResolver,
insecureSkipVerify: !cfg.TLSVerify,
idleConnTimeout: time.Duration(cfg.IdleConnTimeout) * time.Second,
maxIdleConns: cfg.MaxIdleConns,
maxIdleConnsPerHost: cfg.MaxIdleConns,
}

if cfg.Proxy != "" {
Expand All @@ -46,17 +55,7 @@ func NewDetector(cfg *config.Config, dnsResolver *dnscache.Resolver) (*WAFDetect
return nil, errors.Wrap(err, "couldn't parse proxy URL")
}

tr.Proxy = http.ProxyURL(proxyURL)
}

jar, err := cookiejar.New(nil)
if err != nil {
return nil, errors.Wrap(err, "couldn't create cookie jar")
}

client := &http.Client{
Transport: tr,
Jar: jar,
clientSettings.proxyURL = proxyURL
}

target, err := url.Parse(cfg.URL)
Expand All @@ -73,20 +72,71 @@ func NewDetector(cfg *config.Config, dnsResolver *dnscache.Resolver) (*WAFDetect
}

return &WAFDetector{
client: client,
headers: configuredHeaders,
hostHeader: configuredHeaders["Host"],
target: GetTargetURLStr(target),
clientSettings: clientSettings,
headers: configuredHeaders,
hostHeader: configuredHeaders["Host"],
target: GetTargetURLStr(target),
}, nil
}

// doRequest sends HTTP-request with malicious payload to trigger WAF.
func (w *WAFDetector) getHttpClient() (*http.Client, error) {
tr := &http.Transport{
DialContext: dnscache.DialFunc(w.clientSettings.dnsResolver, nil),
TLSClientConfig: &tls.Config{InsecureSkipVerify: w.clientSettings.insecureSkipVerify},
IdleConnTimeout: w.clientSettings.idleConnTimeout,
MaxIdleConns: w.clientSettings.maxIdleConns,
MaxIdleConnsPerHost: w.clientSettings.maxIdleConns, // net.http hardcodes DefaultMaxIdleConnsPerHost to 2!
}

if w.clientSettings.proxyURL != nil {
tr.Proxy = http.ProxyURL(w.clientSettings.proxyURL)
}

jar, err := cookiejar.New(nil)
if err != nil {
return nil, errors.Wrap(err, "couldn't create cookie jar")
}

client := &http.Client{
Transport: tr,
Jar: jar,
}

return client, nil
}

// doRequest sends HTTP-request without malicious payload to trigger WAF.
func (w *WAFDetector) doRequest(ctx context.Context) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, w.target, nil)
if err != nil {
return nil, errors.Wrap(err, "couldn't create request")
}

for header, value := range w.headers {
req.Header.Set(header, value)
}
req.Host = w.hostHeader

client, err := w.getHttpClient()
if err != nil {
return nil, errors.Wrap(err, "couldn't create HTTP client")
}

resp, err := client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to sent request")
}

return resp, nil
}

// doMaliciousRequest sends HTTP-request with malicious payload to trigger WAF.
func (w *WAFDetector) doMaliciousRequest(ctx context.Context) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, w.target, nil)
if err != nil {
return nil, errors.Wrap(err, "couldn't create request")
}

queryParams := req.URL.Query()
queryParams.Add("a", xssPayload)
queryParams.Add("b", sqliPayload)
Expand All @@ -101,7 +151,12 @@ func (w *WAFDetector) doRequest(ctx context.Context) (*http.Response, error) {
}
req.Host = w.hostHeader

resp, err := w.client.Do(req)
client, err := w.getHttpClient()
if err != nil {
return nil, errors.Wrap(err, "couldn't create HTTP client")
}

resp, err := client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to sent request")
}
Expand All @@ -111,19 +166,31 @@ func (w *WAFDetector) doRequest(ctx context.Context) (*http.Response, error) {

// DetectWAF performs WAF identification. Returns WAF name and vendor after
// the first positive match.
func (w *WAFDetector) DetectWAF(ctx context.Context) (name, vendor string, err error) {
func (w *WAFDetector) DetectWAF(ctx context.Context) (name, vendor string, checkFunc detectors.Check, err error) {
resp, err := w.doRequest(ctx)
if err != nil {
return "", "", errors.Wrap(err, "couldn't identify WAF")
return "", "", nil, errors.Wrap(err, "couldn't perform request without attack")
}

defer resp.Body.Close()

respToAttack, err := w.doMaliciousRequest(ctx)
if err != nil {
return "", "", nil, errors.Wrap(err, "couldn't perform request with attack")
}

defer respToAttack.Body.Close()

resps := &detectors.Responses{
Resp: resp,
RespToAttack: respToAttack,
}

for _, d := range detectors.Detectors {
if d.IsWAF(resp) {
return d.WAFName, d.Vendor, nil
if d.IsWAF(resps) {
return d.WAFName, d.Vendor, d.Check, nil
}
}

return "", "", nil
return "", "", nil, nil
}
7 changes: 4 additions & 3 deletions internal/scanner/detectors/akamai.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ func KonaSiteDefender() *Detector {
Vendor: "Akamai",
}

d.Checks = []Check{
CheckHeader("Server", "AkamaiGHost"),
}
d.Check = Or(
CheckHeader("Server", "AkamaiGHost", false),
CheckHeader("Server", "AkamaiGHost", true),
)

return d
}
115 changes: 106 additions & 9 deletions internal/scanner/detectors/checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,30 @@ import (
"io"
"net/http"
"regexp"
"strings"
)

type Responses struct {
Resp *http.Response
RespToAttack *http.Response
}

// Check performs some check on the response with a fixed condition.
type Check func(resp *http.Response) bool
type Check func(resps *Responses) bool

// CheckStatusCode compare response status code with given value.
func CheckStatusCode(status int) Check {
f := func(resp *http.Response) bool {
// Default value for attack parameter is true.
func CheckStatusCode(status int, attack bool) Check {
f := func(resps *Responses) bool {
resp := resps.Resp
if attack {
resp = resps.RespToAttack
}

if resp == nil {
return false
}

if resp.StatusCode == status {
return true
}
Expand All @@ -23,11 +39,49 @@ func CheckStatusCode(status int) Check {
return f
}

// CheckReason match status reason value with regex.
// Default value for attack parameter is true.
func CheckReason(regex string, attack bool) Check {
re := regexp.MustCompile(regex)

f := func(resps *Responses) bool {
resp := resps.Resp
if attack {
resp = resps.RespToAttack
}

if resp == nil {
return false
}

reasonIndex := strings.Index(resp.Status, " ")
reason := resp.Status[reasonIndex+1:]

if re.MatchString(reason) {
return true
}

return false
}

return f
}

// CheckHeader match header value with regex.
func CheckHeader(header, regex string) Check {
// Default value for attack parameter is false.
func CheckHeader(header, regex string, attack bool) Check {
re := regexp.MustCompile(regex)

f := func(resp *http.Response) bool {
f := func(resps *Responses) bool {
resp := resps.Resp
if attack {
resp = resps.RespToAttack
}

if resp == nil {
return false
}

values := resp.Header.Values(header)
if values == nil {
return false
Expand All @@ -46,15 +100,26 @@ func CheckHeader(header, regex string) Check {
}

// CheckCookie match Set-Cookie header values with regex.
func CheckCookie(regex string) Check {
return CheckHeader("Set-Cookie", regex)
// Default value for attack parameter is false.
func CheckCookie(regex string, attack bool) Check {
return CheckHeader("Set-Cookie", regex, attack)
}

// CheckContent match body value with regex.
func CheckContent(regex string) Check {
// Default value for attack parameter is true.
func CheckContent(regex string, attack bool) Check {
re := regexp.MustCompile(regex)

f := func(resp *http.Response) bool {
f := func(resps *Responses) bool {
resp := resps.Resp
if attack {
resp = resps.RespToAttack
}

if resp == nil {
return false
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return false
Expand All @@ -73,3 +138,35 @@ func CheckContent(regex string) Check {

return f
}

// And combines the checks with AND logic,
// so each test must be true to return true.
func And(checks ...Check) Check {
f := func(resps *Responses) bool {
for _, check := range checks {
if !check(resps) {
return false
}
}

return true
}

return f
}

// Or combines the checks with OR logic,
// so at least one test must be true to return true.
func Or(checks ...Check) Check {
f := func(resps *Responses) bool {
for _, check := range checks {
if check(resps) {
return true
}
}

return false
}

return f
}
Loading

0 comments on commit 4bf549f

Please sign in to comment.