Skip to content

Commit

Permalink
Added new variables feature for configuration and header controls
Browse files Browse the repository at this point in the history
inject headers now and configure basic auth easily.

Signed-off-by: Dave Shanley <dave@quobix.com>
  • Loading branch information
daveshanley committed Jul 14, 2023
1 parent f2571be commit e400aed
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 73 deletions.
69 changes: 53 additions & 16 deletions cmd/root_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ var (
}

var config shared.WiretapConfiguration

if configFlag == "" {
// see if a configuration file exists in the current directory or in the user's home directory.
local, _ := os.Stat("wiretap.yaml")
home, _ := os.Stat(filepath.Join(os.Getenv("HOME"), "wiretap.yaml"))
if home != nil {
configFlag = filepath.Join(os.Getenv("HOME"), "wiretap.yaml")
}
if local != nil {
configFlag = local.Name()
}
}

if configFlag != "" {

cBytes, err := os.ReadFile(configFlag)
Expand All @@ -110,6 +123,7 @@ var (
config.StaticIndex = staticIndex
}
} else {

pterm.Info.Println("No wiretap configuration located. Using defaults")
config.StaticIndex = staticIndex
}
Expand Down Expand Up @@ -151,7 +165,9 @@ var (
redirectScheme = parsedURL.Scheme
redirectBasePath = parsedURL.Path

config.Contract = spec
if spec != "" {
config.Contract = spec
}
config.RedirectURL = redirectURL
config.RedirectHost = redirectHost
config.RedirectBasePath = redirectBasePath
Expand All @@ -175,16 +191,23 @@ var (
}
config.FS = FS

// variables
if len(config.Variables) > 0 {
config.CompileVariables()
printLoadedVariables(config.Variables)
}

// paths
if len(config.PathConfigurations) > 0 {
printLoadedPathConfigurations(config.PathConfigurations)
config.CompilePaths()
printLoadedPathConfigurations(config.PathConfigurations)
}

if config.Headers != nil && len(config.Headers.DropHeaders) > 0 {
pterm.Info.Printf("Dropping the following %d %s:\n", len(config.Headers.DropHeaders),
pterm.Info.Printf("Dropping the following %d %s globally:\n", len(config.Headers.DropHeaders),
shared.Pluralize(len(config.Headers.DropHeaders), "header", "headers"))
for _, header := range config.Headers.DropHeaders {
pterm.Printf("🗑️ %s\n", pterm.LightMagenta(header))
pterm.Printf("🗑️ %s\n", pterm.LightRed(header))
}
pterm.Println()
}
Expand All @@ -194,7 +217,7 @@ var (
pterm.Info.Printf("Mapping %d static %s to '%s':\n", len(config.StaticPaths),
shared.Pluralize(len(config.StaticPaths), "path", "paths"), staticPath)
for _, path := range config.StaticPaths {
pterm.Printf("⛱️ %s\n", pterm.LightMagenta(path))
pterm.Printf("⛱️ %s\n", pterm.LightMagenta(path))
}
pterm.Println()
}
Expand Down Expand Up @@ -238,21 +261,35 @@ func Execute(version, commit, date string, fs embed.FS) {
}

func printLoadedPathConfigurations(configs map[string]*shared.WiretapPathConfig) {
plural := func(count int) string {
if count == 1 {
return ""
}
return "s"
}

pterm.Info.Printf("Loaded %d path configuration%s:\n", len(configs), plural(len(configs)))
pterm.Info.Printf("Loaded %d path %s:\n", len(configs),
shared.Pluralize(len(configs), "configuration", "configurations"))
pterm.Println()

for k, v := range configs {
pterm.Printf("%s\n", pterm.LightMagenta(k))
for k, p := range v.PathRewrite {
pterm.Printf("✏️ '%s' re-written to '%s'\n", pterm.LightCyan(k), pterm.LightGreen(p))
pterm.Printf("%s --> %s\n", pterm.LightMagenta(k), pterm.LightCyan(v.Target))
for ka, p := range v.PathRewrite {
pterm.Printf("✏️ '%s' re-written to '%s'\n", pterm.LightCyan(ka), pterm.LightGreen(p))
}
if v.Headers != nil {
for kb, h := range v.Headers.InjectHeaders {
pterm.Printf("💉 '%s' injected with '%s'\n", pterm.LightCyan(kb), pterm.LightGreen(h))
}
for _, h := range v.Headers.DropHeaders {
pterm.Printf("🗑️ '%s' is being %s\n", pterm.LightCyan(h), pterm.LightRed("dropped"))
}
}
if v.Auth != "" {
pterm.Printf("🔐 Basic authentication implemented for '%s'\n", pterm.LightMagenta(k))
}
pterm.Println()
}
}

func printLoadedVariables(variables map[string]string) {
pterm.Info.Printf("Loaded %d %s:\n", len(variables),
shared.Pluralize(len(variables), "variable", "variables"))
for k, v := range variables {
pterm.Printf("📌 Variable '${%s}' points to '%s'\n", pterm.LightCyan(k), pterm.LightGreen(v))
}
pterm.Println()
}
87 changes: 53 additions & 34 deletions config/paths.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,63 @@
package config

import (
"fmt"
"github.com/pb33f/wiretap/shared"
"fmt"
"github.com/pb33f/wiretap/shared"
"strings"
)

func FindPaths(path string, configuration *shared.WiretapConfiguration) []*shared.WiretapPathConfig {
var foundConfigurations []*shared.WiretapPathConfig
for key := range configuration.CompiledPaths {
if configuration.CompiledPaths[key].CompiledKey.Match(path) {
foundConfigurations = append(foundConfigurations, configuration.CompiledPaths[key].PathConfig)
}
}
return foundConfigurations
var foundConfigurations []*shared.WiretapPathConfig
for key := range configuration.CompiledPaths {
if configuration.CompiledPaths[key].CompiledKey.Match(path) {
foundConfigurations = append(foundConfigurations, configuration.CompiledPaths[key].PathConfig)
}
}
return foundConfigurations
}

func RewritePath(path string, configuration *shared.WiretapConfiguration) string {
paths := FindPaths(path, configuration)
var replaced string
if len(paths) > 0 {
// extract first path
pathConfig := paths[0]
replaced = ""
for key := range pathConfig.CompiledPath.CompiledPathRewrite {
if pathConfig.CompiledPath.CompiledPathRewrite[key].MatchString(path) {
replace := pathConfig.PathRewrite[key]
rex := pathConfig.CompiledPath.CompiledPathRewrite[key]
replacedPath := rex.ReplaceAllString(path, replace)

scheme := "http://"
if pathConfig.Secure {
scheme = "https://"
}
if replacedPath[0] != '/' && pathConfig.Target[len(pathConfig.Target)-1] != '/' {
replacedPath = fmt.Sprintf("/%s", replacedPath)
}
replaced = fmt.Sprintf("%s%s%s", scheme, pathConfig.Target, replacedPath)
break
}
}
}
return replaced
paths := FindPaths(path, configuration)
var replaced string
if len(paths) > 0 {
// extract first path
pathConfig := paths[0]
replaced = ""
for key := range pathConfig.CompiledPath.CompiledPathRewrite {
if pathConfig.CompiledPath.CompiledPathRewrite[key].MatchString(path) {
replace := pathConfig.PathRewrite[key]
rex := pathConfig.CompiledPath.CompiledPathRewrite[key]
replacedPath := rex.ReplaceAllString(path, replace)

scheme := "http://"
if pathConfig.Secure {
scheme = "https://"
}
if replacedPath[0] != '/' && pathConfig.Target[len(pathConfig.Target)-1] != '/' {
replacedPath = fmt.Sprintf("/%s", replacedPath)
}
target := strings.ReplaceAll(strings.ReplaceAll(configuration.ReplaceWithVariables(pathConfig.Target),
"http://", ""), "https://", "")

replaced = fmt.Sprintf("%s%s%s", scheme, target, replacedPath)
break
}
}
// no rewriting, just replace target.
if replaced == "" {
scheme := "http://"
if pathConfig.Secure {
scheme = "https://"
}
target := strings.ReplaceAll(strings.ReplaceAll(configuration.ReplaceWithVariables(pathConfig.Target),
"http://", ""), "https://", "")

if path[0] != '/' && pathConfig.Target[len(pathConfig.Target)-1] != '/' {
path = fmt.Sprintf("/%s", path)
}
replaced = fmt.Sprintf("%s%s%s", scheme, target, path)
}
}

return replaced
}
54 changes: 44 additions & 10 deletions daemon/handle_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/pb33f/libopenapi-validator/responses"
"github.com/pb33f/ranch/model"
"github.com/pb33f/ranch/plank/utils"
configModel "github.com/pb33f/wiretap/config"
"github.com/pb33f/wiretap/shared"
"io"
"net/http"
Expand Down Expand Up @@ -100,20 +101,53 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) {
}
}

var dropHeaders []string
var injectHeaders map[string]string

// add global headers with injection.
if config.Headers != nil {
dropHeaders = config.Headers.DropHeaders
injectHeaders = config.Headers.InjectHeaders
}

// now add path specific headers.
matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config)
auth := ""
if len(matchedPaths) > 0 {
for _, path := range matchedPaths {
auth = path.Auth
if path.Headers != nil {
dropHeaders = append(dropHeaders, path.Headers.DropHeaders...)
newInjectHeaders := path.Headers.InjectHeaders
for key := range injectHeaders {
newInjectHeaders[key] = injectHeaders[key]
}
injectHeaders = newInjectHeaders
}
break
}
}

newReq := cloneRequest(CloneRequest{
Request: request.HttpRequest,
Protocol: config.RedirectProtocol,
Host: config.RedirectHost,
Port: config.RedirectPort,
DropHeaders: config.Headers.DropHeaders,
Request: request.HttpRequest,
Protocol: config.RedirectProtocol,
Host: config.RedirectHost,
Port: config.RedirectPort,
DropHeaders: dropHeaders,
InjectHeaders: injectHeaders,
Auth: auth,
Variables: config.CompiledVariables,
})

apiRequest := cloneRequest(CloneRequest{
Request: request.HttpRequest,
Protocol: config.RedirectProtocol,
Host: config.RedirectHost,
Port: config.RedirectPort,
DropHeaders: config.Headers.DropHeaders,
Request: request.HttpRequest,
Protocol: config.RedirectProtocol,
Host: config.RedirectHost,
Port: config.RedirectPort,
DropHeaders: dropHeaders,
InjectHeaders: injectHeaders,
Auth: auth,
Variables: config.CompiledVariables,
})

// validate the request
Expand Down
45 changes: 38 additions & 7 deletions daemon/wiretap_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ package daemon

import (
"bytes"
"encoding/base64"
"fmt"
"github.com/pb33f/wiretap/shared"
"io"
"net/http"
"strings"
Expand Down Expand Up @@ -34,11 +36,15 @@ func reconstructURL(r *http.Request, protocol, host, port string) string {
}

type CloneRequest struct {
Request *http.Request
Protocol string
Host string
Port string
DropHeaders []string
Request *http.Request
Protocol string
Host string
Port string
PathTarget string
DropHeaders []string
InjectHeaders map[string]string
Auth string
Variables map[string]*shared.CompiledVariable
}

func cloneRequest(request CloneRequest) *http.Request {
Expand All @@ -47,9 +53,12 @@ func cloneRequest(request CloneRequest) *http.Request {
_ = request.Request.Body.Close()
request.Request.Body = io.NopCloser(bytes.NewBuffer(b))

var newURL string
var newReq *http.Request
newURL = reconstructURL(request.Request, request.Protocol, request.Host, request.Port)

// create cloned request
newURL := reconstructURL(request.Request, request.Protocol, request.Host, request.Port)
newReq, _ := http.NewRequest(request.Request.Method, newURL, io.NopCloser(bytes.NewBuffer(b)))
newReq, _ = http.NewRequest(request.Request.Method, newURL, io.NopCloser(bytes.NewBuffer(b)))

// copy headers, drop those that are specified.
for k, v := range request.Request.Header {
Expand All @@ -63,9 +72,31 @@ func cloneRequest(request CloneRequest) *http.Request {
newReq.Header.Set(k, v[0])
}
}

// inject headers
for k, v := range request.InjectHeaders {
newReq.Header.Set(k, ReplaceWithVariables(request.Variables, v))
}

// if the auth value is set, we need to base64 encode it and add it to the header.
if request.Auth != "" {
encoded := base64.StdEncoding.EncodeToString([]byte(request.Auth))
// this will overwrite any existing auth header.
newReq.Header.Set("Authorization", fmt.Sprintf("Basic %s", encoded))
}

return newReq
}

func ReplaceWithVariables(variables map[string]*shared.CompiledVariable, input string) string {
for x := range variables {
if variables[x] != nil {
input = variables[x].CompiledVariable.ReplaceAllString(input, variables[x].VariableValue)
}
}
return input
}

func cloneResponse(r *http.Response) *http.Response {
// sniff and replace body.
var b []byte
Expand Down
Loading

0 comments on commit e400aed

Please sign in to comment.