Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added additional variables for network + simplified logic #1282

Merged
merged 9 commits into from
Nov 30, 2021
16 changes: 3 additions & 13 deletions v2/pkg/protocols/network/network.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package network

import (
"net"
"strings"

"github.com/pkg/errors"
Expand Down Expand Up @@ -83,9 +82,8 @@ type Request struct {
}

type addressKV struct {
ip string
port string
tls bool
address string
tls bool
}

// Input is the input to send on the network
Expand Down Expand Up @@ -141,15 +139,7 @@ func (request *Request) Compile(options *protocols.ExecuterOptions) error {
shouldUseTLS = true
address = strings.TrimPrefix(address, "tls://")
}
if strings.Contains(address, ":") {
addressHost, addressPort, portErr := net.SplitHostPort(address)
if portErr != nil {
return errors.Wrap(portErr, "could not parse address")
}
request.addresses = append(request.addresses, addressKV{ip: addressHost, port: addressPort, tls: shouldUseTLS})
} else {
request.addresses = append(request.addresses, addressKV{ip: address, tls: shouldUseTLS})
}
request.addresses = append(request.addresses, addressKV{address: address, tls: shouldUseTLS})
}
// Pre-compile any input dsl functions before executing the request.
for _, input := range request.Inputs {
Expand Down
15 changes: 3 additions & 12 deletions v2/pkg/protocols/network/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestNetworkCompileMake(t *testing.T) {
templateID := "testing-network"
request := &Request{
ID: templateID,
Address: []string{"{{Hostname}}", "{{Hostname}}:8082", "tls://{{Hostname}}:443"},
Address: []string{"tls://{{Host}}:443"},
ReadSize: 1024,
Inputs: []*Input{{Data: "test-data"}},
}
Expand All @@ -28,17 +28,8 @@ func TestNetworkCompileMake(t *testing.T) {
err := request.Compile(executerOpts)
require.Nil(t, err, "could not compile network request")

require.Equal(t, 3, len(request.addresses), "could not get correct number of input address")
t.Run("check-host", func(t *testing.T) {
require.Equal(t, "{{Hostname}}", request.addresses[0].ip, "could not get correct host")
})
t.Run("check-host-with-port", func(t *testing.T) {
require.Equal(t, "{{Hostname}}", request.addresses[1].ip, "could not get correct host with port")
require.Equal(t, "8082", request.addresses[1].port, "could not get correct port for host")
})
require.Equal(t, 1, len(request.addresses), "could not get correct number of input address")
t.Run("check-tls-with-port", func(t *testing.T) {
require.Equal(t, "{{Hostname}}", request.addresses[2].ip, "could not get correct host with port")
require.Equal(t, "443", request.addresses[2].port, "could not get correct port for host")
require.True(t, request.addresses[2].tls, "could not get correct port for host")
require.True(t, request.addresses[0].tls, "could not get correct port for host")
})
}
41 changes: 24 additions & 17 deletions v2/pkg/protocols/network/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,10 @@ func (request *Request) ExecuteWithResults(input string, metadata /*TODO review
}

for _, kv := range request.addresses {
actualAddress := replacer.Replace(kv.ip, map[string]interface{}{"Hostname": address})
if kv.port != "" {
if strings.Contains(address, ":") {
actualAddress, _, _ = net.SplitHostPort(actualAddress)
}
actualAddress = net.JoinHostPort(actualAddress, kv.port)
}
if input != "" {
input = actualAddress
}
variables := generateNetworkVariables(address)
actualAddress := replacer.Replace(kv.address, variables)

if err := request.executeAddress(actualAddress, address, input, kv.tls, previous, callback); err != nil {
if err := request.executeAddress(variables, actualAddress, address, input, kv.tls, previous, callback); err != nil {
gologger.Verbose().Label("ERR").Msgf("Could not make network request for %s: %s\n", actualAddress, err)
continue
}
Expand All @@ -69,7 +61,7 @@ func (request *Request) ExecuteWithResults(input string, metadata /*TODO review
}

// executeAddress executes the request for an address
func (request *Request) executeAddress(actualAddress, address, input string, shouldUseTLS bool, previous output.InternalEvent, callback protocols.OutputEventCallback) error {
func (request *Request) executeAddress(variables map[string]interface{}, actualAddress, address, input string, shouldUseTLS bool, previous output.InternalEvent, callback protocols.OutputEventCallback) error {
if !strings.Contains(actualAddress, ":") {
err := errors.New("no port provided in network protocol request")
request.options.Output.Request(request.options.TemplatePath, address, request.Type().String(), err)
Expand All @@ -88,27 +80,27 @@ func (request *Request) executeAddress(actualAddress, address, input string, sho
break
}
value = generators.MergeMaps(value, payloads)
if err := request.executeRequestWithPayloads(actualAddress, address, input, shouldUseTLS, value, previous, callback); err != nil {
if err := request.executeRequestWithPayloads(variables, actualAddress, address, input, shouldUseTLS, value, previous, callback); err != nil {
return err
}
}
} else {
value := generators.MergeMaps(map[string]interface{}{}, payloads)
if err := request.executeRequestWithPayloads(actualAddress, address, input, shouldUseTLS, value, previous, callback); err != nil {
value := generators.CopyMap(payloads)
if err := request.executeRequestWithPayloads(variables, actualAddress, address, input, shouldUseTLS, value, previous, callback); err != nil {
return err
}
}
return nil
}

func (request *Request) executeRequestWithPayloads(actualAddress, address, input string, shouldUseTLS bool, payloads map[string]interface{}, previous output.InternalEvent, callback protocols.OutputEventCallback) error {
func (request *Request) executeRequestWithPayloads(variables map[string]interface{}, actualAddress, address, input string, shouldUseTLS bool, payloads map[string]interface{}, previous output.InternalEvent, callback protocols.OutputEventCallback) error {
var (
hostname string
conn net.Conn
err error
)

request.dynamicValues = generators.MergeMaps(payloads, map[string]interface{}{"Hostname": address})
request.dynamicValues = generators.MergeMaps(payloads, variables)

if host, _, splitErr := net.SplitHostPort(actualAddress); splitErr == nil {
hostname = host
Expand Down Expand Up @@ -328,3 +320,18 @@ func getAddress(toTest string) (string, error) {
}
return toTest, nil
}

func generateNetworkVariables(input string) map[string]interface{} {
if !strings.Contains(input, ":") {
return map[string]interface{}{"Hostname": input, "Host": input}
}
host, port, err := net.SplitHostPort(input)
if err != nil {
return map[string]interface{}{"Hostname": input}
}
return map[string]interface{}{
"Host": host,
"Port": port,
"Hostname": input,
}
}
9 changes: 2 additions & 7 deletions v2/pkg/protocols/network/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestNetworkExecuteWithResults(t *testing.T) {

parsed, err := url.Parse(ts.URL)
require.Nil(t, err, "could not parse url")
request.Address[0] = "{{Hostname}}:" + parsed.Port()
request.Address[0] = "{{Hostname}}"

request.Inputs = append(request.Inputs, &Input{Data: fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\n\r\n", parsed.Host)})
executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{
Expand Down Expand Up @@ -84,12 +84,7 @@ func TestNetworkExecuteWithResults(t *testing.T) {
})
require.Nil(t, err, "could not execute network request")
})
require.NotNil(t, finalEvent, "could not get event output from request")
require.Equal(t, 1, len(finalEvent.Results), "could not get correct number of results")
require.Equal(t, "test", finalEvent.Results[0].MatcherName, "could not get correct matcher name of results")
require.Equal(t, 1, len(finalEvent.Results[0].ExtractedResults), "could not get correct number of extracted results")
require.Equal(t, "<h1>Example Domain</h1>", finalEvent.Results[0].ExtractedResults[0], "could not get correct extracted results")
finalEvent = nil
require.Nil(t, finalEvent, "could not get event output from request")

request.Inputs[0].Type = NetworkInputTypeHolder{NetworkInputType: hexType}
request.Inputs[0].Data = hex.EncodeToString([]byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\n\r\n", parsed.Host)))
Expand Down