diff --git a/network.go b/network.go index 347f695..8b39723 100644 --- a/network.go +++ b/network.go @@ -15,7 +15,24 @@ var ( // GetIP returns IP address from request. func (limiter *Limiter) GetIP(r *http.Request) net.IP { - if limiter.Options.TrustForwardHeader { + return GetIP(r, limiter.Options) +} + +// GetIPWithMask returns IP address from request by applying a mask. +func (limiter *Limiter) GetIPWithMask(r *http.Request) net.IP { + return GetIPWithMask(r, limiter.Options) +} + +// GetIPKey extracts IP from request and returns hashed IP to use as store key. +func (limiter *Limiter) GetIPKey(r *http.Request) string { + return limiter.GetIPWithMask(r).String() +} + +// GetIP returns IP address from request. +// If options is defined and TrustForwardHeader is true, it will lookup IP in +// X-Forwarded-For and X-Real-IP headers. +func GetIP(r *http.Request, options ...Options) net.IP { + if len(options) >= 1 && options[0].TrustForwardHeader { ip := r.Header.Get("X-Forwarded-For") if ip != "" { parts := strings.SplitN(ip, ",", 2) @@ -39,18 +56,17 @@ func (limiter *Limiter) GetIP(r *http.Request) net.IP { } // GetIPWithMask returns IP address from request by applying a mask. -func (limiter *Limiter) GetIPWithMask(r *http.Request) net.IP { - ip := limiter.GetIP(r) +func GetIPWithMask(r *http.Request, options ...Options) net.IP { + if len(options) == 0 { + return GetIP(r) + } + + ip := GetIP(r, options[0]) if ip.To4() != nil { - return ip.Mask(limiter.Options.IPv4Mask) + return ip.Mask(options[0].IPv4Mask) } if ip.To16() != nil { - return ip.Mask(limiter.Options.IPv6Mask) + return ip.Mask(options[0].IPv6Mask) } return ip } - -// GetIPKey extracts IP from request and returns hashed IP to use as store key. -func (limiter *Limiter) GetIPKey(r *http.Request) string { - return limiter.GetIPWithMask(r).String() -}