Skip to content

Commit

Permalink
Make the loadbalancers servers order random
Browse files Browse the repository at this point in the history
Co-authored-by: Mathieu Lonjaret <mathieu.lonjaret@gmail.com>
Co-authored-by: Romain <rtribotte@users.noreply.github.com>
  • Loading branch information
3 people committed Sep 14, 2022
1 parent 89dc466 commit 788f8fa
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 46 deletions.
126 changes: 107 additions & 19 deletions integration/healthcheck_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net/http"
"os"
"strings"
"time"

"github.com/go-check/check"
Expand Down Expand Up @@ -319,30 +320,51 @@ func (s *HealthCheckSuite) TestPropagate(c *check.C) {

try.Sleep(time.Second)

// Verify load-balancing on root still works, and that we're getting wsp2, wsp4, wsp2, wsp4, etc.
var want string
for i := 0; i < 4; i++ {
if i%2 == 0 {
want = `IP: ` + s.whoami4IP
} else {
want = `IP: ` + s.whoami2IP
}
want2 := `IP: ` + s.whoami2IP
want4 := `IP: ` + s.whoami4IP

// Verify load-balancing on root still works, and that we're getting an alternation between wsp2, and wsp4.
reachedServers := make(map[string]int)
for i := 0; i < 4; i++ {
resp, err := client.Do(rootReq)
c.Assert(err, checker.IsNil)

body, err := io.ReadAll(resp.Body)
c.Assert(err, checker.IsNil)

c.Assert(string(body), checker.Contains, want)
if reachedServers[s.whoami4IP] > reachedServers[s.whoami2IP] {
c.Assert(string(body), checker.Contains, want2)
reachedServers[s.whoami2IP]++
continue
}

if reachedServers[s.whoami2IP] > reachedServers[s.whoami4IP] {
c.Assert(string(body), checker.Contains, want4)
reachedServers[s.whoami4IP]++
continue
}

// First iteration, so we can't tell whether it's going to be wsp2, or wsp4.
if strings.Contains(string(body), `IP: `+s.whoami4IP) {
reachedServers[s.whoami4IP]++
continue
}

if strings.Contains(string(body), `IP: `+s.whoami2IP) {
reachedServers[s.whoami2IP]++
continue
}
}

c.Assert(reachedServers[s.whoami2IP], checker.Equals, 2)
c.Assert(reachedServers[s.whoami4IP], checker.Equals, 2)

fooReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000", nil)
c.Assert(err, checker.IsNil)
fooReq.Host = "foo.localhost"

// Verify load-balancing on foo still works, and that we're getting wsp2, wsp2, wsp2, wsp2, etc.
want = `IP: ` + s.whoami2IP
want := `IP: ` + s.whoami2IP
for i := 0; i < 4; i++ {
resp, err := client.Do(fooReq)
c.Assert(err, checker.IsNil)
Expand Down Expand Up @@ -407,43 +429,109 @@ func (s *HealthCheckSuite) TestPropagate(c *check.C) {
try.Sleep(time.Second)

// Verify everything is up on root router.
wantIPs := []string{s.whoami3IP, s.whoami1IP, s.whoami4IP, s.whoami2IP}
reachedServers = make(map[string]int)
for i := 0; i < 4; i++ {
want := `IP: ` + wantIPs[i]
resp, err := client.Do(rootReq)
c.Assert(err, checker.IsNil)

body, err := io.ReadAll(resp.Body)
c.Assert(err, checker.IsNil)

c.Assert(string(body), checker.Contains, want)
if strings.Contains(string(body), `IP: `+s.whoami1IP) {
reachedServers[s.whoami1IP]++
continue
}

if strings.Contains(string(body), `IP: `+s.whoami2IP) {
reachedServers[s.whoami2IP]++
continue
}

if strings.Contains(string(body), `IP: `+s.whoami3IP) {
reachedServers[s.whoami3IP]++
continue
}

if strings.Contains(string(body), `IP: `+s.whoami4IP) {
reachedServers[s.whoami4IP]++
continue
}
}

c.Assert(reachedServers[s.whoami1IP], checker.Equals, 1)
c.Assert(reachedServers[s.whoami2IP], checker.Equals, 1)
c.Assert(reachedServers[s.whoami3IP], checker.Equals, 1)
c.Assert(reachedServers[s.whoami4IP], checker.Equals, 1)

// Verify everything is up on foo router.
wantIPs = []string{s.whoami1IP, s.whoami1IP, s.whoami3IP, s.whoami2IP}
reachedServers = make(map[string]int)
for i := 0; i < 4; i++ {
want := `IP: ` + wantIPs[i]
resp, err := client.Do(fooReq)
c.Assert(err, checker.IsNil)

body, err := io.ReadAll(resp.Body)
c.Assert(err, checker.IsNil)

c.Assert(string(body), checker.Contains, want)
if strings.Contains(string(body), `IP: `+s.whoami1IP) {
reachedServers[s.whoami1IP]++
continue
}

if strings.Contains(string(body), `IP: `+s.whoami2IP) {
reachedServers[s.whoami2IP]++
continue
}

if strings.Contains(string(body), `IP: `+s.whoami3IP) {
reachedServers[s.whoami3IP]++
continue
}

if strings.Contains(string(body), `IP: `+s.whoami4IP) {
reachedServers[s.whoami4IP]++
continue
}
}

c.Assert(reachedServers[s.whoami1IP], checker.Equals, 2)
c.Assert(reachedServers[s.whoami2IP], checker.Equals, 1)
c.Assert(reachedServers[s.whoami3IP], checker.Equals, 1)
c.Assert(reachedServers[s.whoami4IP], checker.Equals, 0)

// Verify everything is up on bar router.
wantIPs = []string{s.whoami1IP, s.whoami1IP, s.whoami3IP, s.whoami2IP}
reachedServers = make(map[string]int)
for i := 0; i < 4; i++ {
want := `IP: ` + wantIPs[i]
resp, err := client.Do(barReq)
c.Assert(err, checker.IsNil)

body, err := io.ReadAll(resp.Body)
c.Assert(err, checker.IsNil)

c.Assert(string(body), checker.Contains, want)
if strings.Contains(string(body), `IP: `+s.whoami1IP) {
reachedServers[s.whoami1IP]++
continue
}

if strings.Contains(string(body), `IP: `+s.whoami2IP) {
reachedServers[s.whoami2IP]++
continue
}

if strings.Contains(string(body), `IP: `+s.whoami3IP) {
reachedServers[s.whoami3IP]++
continue
}

if strings.Contains(string(body), `IP: `+s.whoami4IP) {
reachedServers[s.whoami4IP]++
continue
}
}

c.Assert(reachedServers[s.whoami1IP], checker.Equals, 2)
c.Assert(reachedServers[s.whoami2IP], checker.Equals, 1)
c.Assert(reachedServers[s.whoami3IP], checker.Equals, 1)
c.Assert(reachedServers[s.whoami4IP], checker.Equals, 0)
}

func (s *HealthCheckSuite) TestPropagateNoHealthCheck(c *check.C) {
Expand Down
18 changes: 6 additions & 12 deletions integration/retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,11 @@ func (s *RetrySuite) TestRetry(c *check.C) {
err = try.GetRequest("http://127.0.0.1:8080/api/rawdata", 60*time.Second, try.BodyContains("PathPrefix(`/`)"))
c.Assert(err, checker.IsNil)

start := time.Now()
// This simulates a DialTimeout when connecting to the backend server.
response, err := http.Get("http://127.0.0.1:8000/")
duration, allowed := time.Since(start), time.Millisecond*250
c.Assert(err, checker.IsNil)

// The test only verifies that the retry middleware makes sure that the working service is eventually reached.
c.Assert(response.StatusCode, checker.Equals, http.StatusOK)
c.Assert(int64(duration), checker.LessThan, int64(allowed))
}

func (s *RetrySuite) TestRetryBackoff(c *check.C) {
Expand All @@ -58,16 +56,11 @@ func (s *RetrySuite) TestRetryBackoff(c *check.C) {
err = try.GetRequest("http://127.0.0.1:8080/api/rawdata", 60*time.Second, try.BodyContains("PathPrefix(`/`)"))
c.Assert(err, checker.IsNil)

start := time.Now()
// This simulates a DialTimeout when connecting to the backend server.
response, err := http.Get("http://127.0.0.1:8000/")
duration := time.Since(start)
// test case delays: 500 + 700 + 1000ms with randomization. It should be safely > 1500ms
minAllowed := time.Millisecond * 1400

c.Assert(err, checker.IsNil)

// The test only verifies that the retry middleware allows finally to reach the working service.
c.Assert(response.StatusCode, checker.Equals, http.StatusOK)
c.Assert(int64(duration), checker.GreaterThan, int64(minAllowed))
}

func (s *RetrySuite) TestRetryWebsocket(c *check.C) {
Expand All @@ -83,11 +76,12 @@ func (s *RetrySuite) TestRetryWebsocket(c *check.C) {
err = try.GetRequest("http://127.0.0.1:8080/api/rawdata", 60*time.Second, try.BodyContains("PathPrefix(`/`)"))
c.Assert(err, checker.IsNil)

// This simulates a DialTimeout when connecting to the backend server.
// The test only verifies that the retry middleware makes sure that the working service is eventually reached.
_, response, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:8000/echo", nil)
c.Assert(err, checker.IsNil)
c.Assert(response.StatusCode, checker.Equals, http.StatusSwitchingProtocols)

// The test verifies a second time that the working service is eventually reached.
_, response, err = websocket.DefaultDialer.Dial("ws://127.0.0.1:8000/echo", nil)
c.Assert(err, checker.IsNil)
c.Assert(response.StatusCode, checker.Equals, http.StatusSwitchingProtocols)
Expand Down
4 changes: 2 additions & 2 deletions integration/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (s *UDPSuite) TestWRR(c *check.C) {
stop := make(chan struct{})
go func() {
call := map[string]int{}
for i := 0; i < 4; i++ {
for i := 0; i < 8; i++ {
out, err := guessWhoUDP("127.0.0.1:8093")
c.Assert(err, checker.IsNil)
switch {
Expand All @@ -90,7 +90,7 @@ func (s *UDPSuite) TestWRR(c *check.C) {
call["unknown"]++
}
}
c.Assert(call, checker.DeepEquals, map[string]int{"whoami-a": 2, "whoami-b": 1, "whoami-c": 1})
c.Assert(call, checker.DeepEquals, map[string]int{"whoami-a": 3, "whoami-b": 2, "whoami-c": 3})
close(stop)
}()

Expand Down
15 changes: 13 additions & 2 deletions pkg/server/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"net/http"
"net/http/httputil"
"net/url"
Expand Down Expand Up @@ -51,6 +52,7 @@ func NewManager(configs map[string]*runtime.ServiceInfo, metricsRegistry metrics
roundTripperManager: roundTripperManager,
balancers: make(map[string]healthcheck.Balancers),
configs: configs,
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}

Expand All @@ -66,6 +68,7 @@ type Manager struct {
// which is why there is not just one Balancer per service name.
balancers map[string]healthcheck.Balancers
configs map[string]*runtime.ServiceInfo
rand *rand.Rand // For the initial shuffling of load-balancers.
}

// BuildHTTP Creates a http.Handler for a service configuration.
Expand Down Expand Up @@ -212,7 +215,7 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string,
}

balancer := wrr.New(config.Sticky, config.HealthCheck)
for _, service := range config.Services {
for _, service := range shuffle(config.Services, m.rand) {
serviceHandler, err := m.BuildHTTP(ctx, service.Name)
if err != nil {
return nil, err
Expand Down Expand Up @@ -414,7 +417,7 @@ func (m *Manager) getLoadBalancer(ctx context.Context, serviceName string, servi
func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHandler, servers []dynamic.Server) error {
logger := log.FromContext(ctx)

for name, srv := range servers {
for name, srv := range shuffle(servers, m.rand) {
u, err := url.Parse(srv.URL)
if err != nil {
return fmt.Errorf("error parsing server URL %s: %w", srv.URL, err)
Expand Down Expand Up @@ -443,3 +446,11 @@ func convertSameSite(sameSite string) http.SameSite {
return 0
}
}

func shuffle[T any](values []T, r *rand.Rand) []T {
shuffled := make([]T, len(values))
copy(shuffled, values)
r.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })

return shuffled
}
27 changes: 20 additions & 7 deletions pkg/server/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
type ExpectedResult struct {
StatusCode int
XFrom string
LoadBalanced bool
SecureCookie bool
HTTPOnlyCookie bool
}
Expand Down Expand Up @@ -139,12 +140,12 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "first",
StatusCode: http.StatusOK,
LoadBalanced: true,
},
{
StatusCode: http.StatusOK,
XFrom: "second",
StatusCode: http.StatusOK,
LoadBalanced: true,
},
},
},
Expand Down Expand Up @@ -193,11 +194,9 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "first",
},
{
StatusCode: http.StatusOK,
XFrom: "first",
},
},
},
Expand Down Expand Up @@ -302,13 +301,27 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
req.Header.Set("Cookie", test.cookieRawValue)
}

var prevXFrom string
for _, expected := range test.expected {
recorder := httptest.NewRecorder()

handler.ServeHTTP(recorder, req)

assert.Equal(t, expected.StatusCode, recorder.Code)
assert.Equal(t, expected.XFrom, recorder.Header().Get("X-From"))

if expected.XFrom != "" {
assert.Equal(t, expected.XFrom, recorder.Header().Get("X-From"))
}

xFrom := recorder.Header().Get("X-From")
if prevXFrom != "" {
if expected.LoadBalanced {
assert.NotEqual(t, prevXFrom, xFrom)
} else {
assert.Equal(t, prevXFrom, xFrom)
}
}
prevXFrom = xFrom

cookieHeader := recorder.Header().Get("Set-Cookie")
if len(cookieHeader) > 0 {
Expand Down

0 comments on commit 788f8fa

Please sign in to comment.