Skip to content

Commit

Permalink
feat: add in-memory opt to load threat dataset on initialization
Browse files Browse the repository at this point in the history
This commit adds a new boolean option, `InMemory`, to the `teler.Teler` middleware.
When `InMemory` is set to true, the threat dataset will be loaded into memory on
initialization, which can improve performance when running the middleware on a
distroless or runtime image. Note that this option may not be suitable for larger
datasets, as it may consume a significant amount of memory. When `InMemory` is set
to false, the threat dataset will be downloaded and stored under the user-level
cache directory on the first request, and subsequent requests will use the cached
dataset. Close #23.
  • Loading branch information
dwisiswant0 committed Mar 2, 2023
1 parent 532f387 commit dae7eeb
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 44 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.16

require (
github.com/hashicorp/go-getter v1.6.2
github.com/klauspost/compress v1.16.0 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/scorpionknifes/go-pcre v0.0.0-20210805092536-77486363b797
github.com/stretchr/testify v1.8.2
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/klauspost/compress v1.11.2 h1:MiK62aErc3gIiVEtyzKfeOHgW7atJb5g/KNX5m3c2nQ=
github.com/klauspost/compress v1.11.2/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/klauspost/compress v1.16.0 h1:iULayQNOReoYUe+1qtKOqw9CwJv3aNQu8ivo7lw1HU4=
github.com/klauspost/compress v1.16.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
Expand Down
35 changes: 27 additions & 8 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,25 @@ import "github.com/kitabisa/teler-waf/threat"
// Options is a struct for specifying configuration options for the teler.Teler middleware.
type Options struct {
// Excludes is a list of threat types to exclude from the security checks.
//
// These threat types are defined in the threat.Threat type.
Excludes []threat.Threat

// Whitelists is a list of regular expressions that match request elements
// that should be excluded from the security checks. The request elements
// that can be matched are request URI (path and query parameters), HTTP headers,
// or client IP address.
// that should be excluded from the security checks.
//
// The request elements that can be matched are request URI (path and query parameters),
// HTTP headers, or client IP address.
Whitelists []string

// Customs is a list of custom security rules to apply to incoming requests.
//
// These rules can be used to create custom security checks or to override
// the default security checks provided by teler-waf.
Customs []Rule

// LogFile is the file path for the log file to store the security logs.
//
// If LogFile is specified, log messages will be written to the specified
// file in addition to stderr (if NoStderr is false).
LogFile string
Expand All @@ -29,17 +33,32 @@ type Options struct {
// LogRotate bool

// NoStderr is a boolean flag indicating whether or not to suppress log messages
// from being printed to the standard error (stderr) stream. When set to true, log messages
// will not be printed to stderr. If set to false, log messages will be printed to stderr.
// By default, log messages are printed to stderr (false).
// from being printed to the standard error (stderr) stream.
//
// When set to true, log messages will not be printed to stderr. If set to false,
// log messages will be printed to stderr. By default, log messages are printed
// to stderr (false).
NoStderr bool

// NoUpdateCheck is a boolean flag indicating whether or not to disable automatic threat
// dataset updates. When set to true, automatic updates will be disabled. If set to false,
// automatic updates will be enabled. By default, automatic updates are enabled (false).
// dataset updates.
//
// When set to true, automatic updates will be disabled. If set to false, automatic
// updates will be enabled. By default, automatic updates are enabled (false).
NoUpdateCheck bool

// Development is a boolean flag that determines whether the request is cached or not.
//
// By default, development mode is disabled (false), meaning requests will be cached.
Development bool

// InMemory is a boolean flag that specifies whether or not to load the threat dataset
// into memory on initialization.
//
// When set to true, the threat dataset will be loaded into memory, which can be useful
// when running your service or application on a distroless or runtime image, where file
// access may be limited or slow. If InMemory is set to false, the threat dataset will
// be downloaded and stored under the user-level cache directory on the first startup.
// Subsequent startups will use the cached dataset.
InMemory bool
}
9 changes: 6 additions & 3 deletions request/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package request
import "net/http"

// Method is a defined type whose underlying type is string, used to
// represent HTTP methods.
//
// It is declared as a distinct type (not a true alias, which would be
// written `type Method = string`) to allow for custom methods to be added
// in the future, while still maintaining type-safety.
type Method string

// Constants representing common HTTP methods. These constants are of type Method
// and are assigned the values of the corresponding HTTP methods from the net/http package.
// Using these constants allows users of the request package to specify HTTP methods
// Constants representing common HTTP methods.
//
// These constants are of type Method and are assigned the values of the
// corresponding HTTP methods from the net/http package. Using these
// constants allows users of the request package to specify HTTP methods
// in a type-safe manner, rather than using raw strings.
const (
GET Method = http.MethodGet // GET is the HTTP GET method.
Expand Down
124 changes: 94 additions & 30 deletions teler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ import (
"strings"
"time"

"archive/tar"
"encoding/json"
"net/http"
"net/url"

"github.com/kitabisa/teler-waf/request"
"github.com/kitabisa/teler-waf/threat"
"github.com/klauspost/compress/zstd"
"github.com/patrickmn/go-cache"
"github.com/scorpionknifes/go-pcre"
"github.com/valyala/fastjson"
Expand Down Expand Up @@ -90,6 +92,9 @@ func New(opts ...Options) *Teler {
threat: &Threat{},
}

// Set the opt field of the Teler struct to the options
t.opt = o

// Retrieve the data for each threat category
err := t.getResources()
if err != nil {
Expand Down Expand Up @@ -204,13 +209,11 @@ func New(opts ...Options) *Teler {
t.cache = cache.New(15*time.Minute, 20*time.Minute)
}

// Set the opt field of the Teler struct to the options
t.opt = o

return t
}

// postAnalyze is a function that processes the HTTP response after an error is returned from the analyzeRequest function.
// postAnalyze is a function that processes the HTTP response after
// an error is returned from the analyzeRequest function.
func (t *Teler) postAnalyze(w http.ResponseWriter, r *http.Request, k threat.Threat, err error) {
// If there is no error, return early.
if err == nil {
Expand Down Expand Up @@ -268,50 +271,111 @@ func (t *Teler) getResources() error {
}

// Download the datasets of threat ruleset from teler-resources
// if threat datasets is not up-to-date and NoUpdateCheck is false
if !updated && !t.opt.NoUpdateCheck {
// if threat datasets are not up-to-date, update check is enabled
// (NoUpdateCheck is false) and in-memory option is false
if !updated && !t.opt.NoUpdateCheck && !t.opt.InMemory {
if err := threat.Get(); err != nil {
return err
}
}

// Initialize files for in-memory threat datasets
files := make(map[string][]byte, 0)

// If the InMemory option is enabled, retrieve the threat data from the DB URL and
// decompress it from Zstandard format, then extract the contents of each regular file
// from the tar archive and store them in a map indexed by their file name
if t.opt.InMemory {
resp, err := http.Get(threat.DbURL)
if err != nil {
return err
}
defer resp.Body.Close()

zstdReader, err := zstd.NewReader(resp.Body)
if err != nil {
return err
}
defer zstdReader.Close()

tarReader := tar.NewReader(zstdReader)

for {
// Read the next header from the tar archive
header, err := tarReader.Next()
if err == io.EOF {
break
}

if err != nil {
return err
}

// Skip non-regular files
if header.Typeflag != tar.TypeReg {
continue
}

// Read the contents of the file
fileContent, err := io.ReadAll(tarReader)
if err != nil {
return err
}

// Store the file content in the map indexed by the file name
files[header.Name] = fileContent
}
}

// Initialize the data field of the Threat struct to a new map
// that will be used to store the threat data
t.threat.data = make(map[threat.Threat]string)

for _, k := range threat.List() {
// Get the location of respective threat type
path, err := k.Filename(true)
// Initialize error & threat dataset content variables
var err error
var b []byte

// Get the file name and the path of respective threat type
path, err := k.Filename(!t.opt.InMemory)
if err != nil {
return err
}

// Read the contents of the data file at the specified path and store it
// as a string in the data field of the Threat struct. If the file is not
// found, the function will attempt to retrieve the threat from an external
// source using the `Get()` method on the `threat` object. If the threat
// retrieval fails, an error will be returned. Otherwise, the function will
// retry reading the file as usual. If any other error occurs while reading
// the file, it will be returned immediately.
b, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
// If the error is a file not found error, attempt to retrieve the
// threat from an external source using the `Get()` method on the
// `threat` object.
if err := threat.Get(); err != nil {
return err
}
// If the data is loaded in memory, retrieve it from the files map. Otherwise,
// read the contents of the data file at the specified path and store it as a
// string in the data field of the Threat struct. If the file is not found,
// the function will attempt to retrieve the threat from an external source
// using the `Get()` method on the `threat` object. If the threat retrieval
// fails, an error will be returned. Otherwise, the function will retry reading
// the file as usual. If any other error occurs while reading the file, it will
// be returned immediately.
if t.opt.InMemory {
b = files[path]
} else {
b, err = os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
// If the error is a file not found error, attempt to retrieve the
// threat from an external source using the `Get()` method on the
// `threat` object.
if err := threat.Get(); err != nil {
return err
}

// Retry reading the file after retrieving the threat.
b, err = os.ReadFile(path)
if err != nil {
// Retry reading the file after retrieving the threat.
b, err = os.ReadFile(path)
if err != nil {
return err
}
} else {
// If the error is not a file not found error, return it immediately.
return err
}
} else {
// If the error is not a file not found error, return it immediately.
return err
}
}

// Store the threat dataset contents in Threat struct as a string
t.threat.data[k] = string(b)

err = t.processResource(k)
Expand Down
21 changes: 21 additions & 0 deletions teler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,27 @@ func TestNewWithMalformedDataset(t *testing.T) {
}
}

// TestNewWithInMemory checks that a Teler middleware created with the
// InMemory option enabled can be initialized and can serve a request
// sent through a test server.
func TestNewWithInMemory(t *testing.T) {
	// Initialize teler with the InMemory option enabled
	telerMiddleware := New(Options{NoStderr: true, InMemory: true})
	wrappedHandler := telerMiddleware.Handler(handler)

	// Create a test server with the wrapped handler
	ts := httptest.NewServer(wrappedHandler)
	defer ts.Close()

	// Create a request to send to the test server
	req, err := http.NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Close the response body so the underlying connection is released
	// instead of being leaked for the remainder of the test run.
	defer resp.Body.Close()
}

func TestNewCustom(t *testing.T) {
// Initialize teler
telerMiddleware := New(Options{
Expand Down
2 changes: 0 additions & 2 deletions threat/const.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package threat

import "fmt"

const (
repoURL = "https://github.com/kitabisa/teler-resources"
cachePath = "/teler-waf"
Expand Down
3 changes: 2 additions & 1 deletion threat/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
)

// Get retrieves all the teler threat datasets.
//
// It returns an error if there was an issue when retrieving the datasets.
func Get() error {
// Get the destination location for the datasets
Expand All @@ -34,7 +35,7 @@ func Get() error {
}

// Retrieve the compressed archive DB file from the GitHub repository using go-getter
err = getter.Get(dst, fmt.Sprintf("%s?%s", dbURL, dbQuery))
err = getter.Get(dst, fmt.Sprintf("%s?%s", DbURL, dbQuery))
if err != nil {
// If there was an error retrieving the files, return the error
return err
Expand Down
2 changes: 2 additions & 0 deletions threat/var.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package threat

import "fmt"

var (
DbURL = fmt.Sprintf("%s/raw/master/db/db.tar.zst", repoURL)
dbQuery = fmt.Sprintf("checksum=file:%s/raw/master/db/MD5SUMS", repoURL)
Expand Down

0 comments on commit dae7eeb

Please sign in to comment.