diff --git a/cluster/advertise.go b/cluster/advertise.go index fb71cf0170..c6f44b927b 100644 --- a/cluster/advertise.go +++ b/cluster/advertise.go @@ -20,6 +20,11 @@ import ( "github.com/pkg/errors" ) +type getPrivateIPFunc func() (string, error) + +// This is overriden in unit tests to mock the sockaddr.GetPrivateIP function. +var getPrivateAddress getPrivateIPFunc = sockaddr.GetPrivateIP + // calculateAdvertiseAddress attempts to clone logic from deep within memberlist // (NetTransport.FinalAdvertiseAddr) in order to surface its conclusions to the // application, so we can provide more actionable error messages if the user has @@ -39,12 +44,12 @@ func calculateAdvertiseAddress(bindAddr, advertiseAddr string) (net.IP, error) { } if isAny(bindAddr) { - privateIP, err := sockaddr.GetPrivateIP() + privateIP, err := getPrivateAddress() if err != nil { return nil, errors.Wrap(err, "failed to get private IP") } if privateIP == "" { - return nil, errors.Wrap(err, "no private IP found, explicit advertise addr not provided") + return nil, errors.New("no private IP found, explicit advertise addr not provided") } ip := net.ParseIP(privateIP) if ip == nil { diff --git a/cluster/advertise_test.go b/cluster/advertise_test.go new file mode 100644 index 0000000000..dc6b02e899 --- /dev/null +++ b/cluster/advertise_test.go @@ -0,0 +1,92 @@ +// Copyright 2018 Prometheus Team +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cluster + +import ( + "errors" + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCalculateAdvertiseAddress(t *testing.T) { + old := getPrivateAddress + defer func() { + getPrivateAddress = old + }() + + cases := []struct { + fn getPrivateIPFunc + bind, advertise string + + expectedIP net.IP + err bool + }{ + { + bind: "192.0.2.1", + advertise: "", + + expectedIP: net.ParseIP("192.0.2.1"), + err: false, + }, + { + bind: "192.0.2.1", + advertise: "192.0.2.2", + + expectedIP: net.ParseIP("192.0.2.2"), + err: false, + }, + { + fn: func() (string, error) { return "192.0.2.1", nil }, + bind: "0.0.0.0", + advertise: "", + + expectedIP: net.ParseIP("192.0.2.1"), + err: false, + }, + { + fn: func() (string, error) { return "", errors.New("some error") }, + bind: "0.0.0.0", + advertise: "", + + err: true, + }, + { + fn: func() (string, error) { return "invalid", nil }, + bind: "0.0.0.0", + advertise: "", + + err: true, + }, + { + fn: func() (string, error) { return "", nil }, + bind: "0.0.0.0", + advertise: "", + + err: true, + }, + } + + for _, c := range cases { + getPrivateAddress = c.fn + got, err := calculateAdvertiseAddress(c.bind, c.advertise) + if c.err { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, c.expectedIP.String(), got.String()) + } + } +}