Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

loadbalance strategy extension #230

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions roundrobin/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ func RebalancerDebug(debug bool) RebalancerOption {
}

// ServerOption provides various options for server, e.g. weight.
type ServerOption func(*server) error
type ServerOption func(s Server) error

// Weight is an optional functional argument that sets weight of the server.
func Weight(w int) ServerOption {
return func(s *server) error {
return func(s Server) error {
if w < 0 {
return fmt.Errorf("Weight should be >= 0")
return fmt.Errorf("Weight should be >= 0 ")
}
s.weight = w
s.Set(w)
return nil
}
}
Expand Down
4 changes: 2 additions & 2 deletions roundrobin/rebalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type BalancerHandler interface {
ServerWeight(u *url.URL) (int, bool)
RemoveServer(u *url.URL) error
UpsertServer(u *url.URL, options ...ServerOption) error
NextServer() (*url.URL, error)
NextServer(w http.ResponseWriter, req *http.Request, neq *http.Request) (*url.URL, error)
Next() http.Handler
}

Expand Down Expand Up @@ -144,7 +144,7 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

if !stuck {
fwdURL, err := rb.next.NextServer()
fwdURL, err := rb.next.NextServer(w, req, &newReq)
if err != nil {
rb.errHandler.ServeHTTP(w, req, err)
return
Expand Down
8 changes: 4 additions & 4 deletions roundrobin/rebalancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ func TestRebalancerRecovery(t *testing.T) {
assert.Equal(t, 1, rb.servers[0].curWeight)
assert.Equal(t, FSMMaxWeight, rb.servers[1].curWeight)

assert.Equal(t, 1, lb.servers[0].weight)
assert.Equal(t, FSMMaxWeight, lb.servers[1].weight)
assert.Equal(t, 1, lb.servers[0].Weight())
assert.Equal(t, FSMMaxWeight, lb.servers[1].Weight())

// server a is now recovering, the weights should go back to the original state
rb.servers[0].meter.(*testMeter).rating = 0
Expand All @@ -141,8 +141,8 @@ func TestRebalancerRecovery(t *testing.T) {
assert.Equal(t, 1, rb.servers[1].curWeight)

// Make sure we have applied the weights to the inner load balancer
assert.Equal(t, 1, lb.servers[0].weight)
assert.Equal(t, 1, lb.servers[1].weight)
assert.Equal(t, 1, lb.servers[0].Weight())
assert.Equal(t, 1, lb.servers[1].Weight())
}

// Test scenario when increaing the weight on good endpoints made it worse.
Expand Down
49 changes: 33 additions & 16 deletions roundrobin/rr.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package roundrobin

import (
"fmt"
"math/rand"
"net/http"
"net/url"
"sync"
Expand All @@ -17,7 +18,7 @@ type RoundRobin struct {
errHandler utils.ErrorHandler
// Current index (starts from -1)
index int
servers []*server
servers []Server
currentWeight int
stickySession *StickySession
requestRewriteListener RequestRewriteListener
Expand All @@ -32,7 +33,7 @@ func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) {
next: next,
index: -1,
mutex: &sync.Mutex{},
servers: []*server{},
servers: []Server{},
stickySession: nil,

log: &utils.NoopLogger{},
Expand Down Expand Up @@ -76,7 +77,7 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

if !stuck {
uri, err := r.NextServer()
uri, err := r.NextServer(w, req, &newReq)
if err != nil {
r.errHandler.ServeHTTP(w, req, err)
return
Expand All @@ -103,15 +104,19 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

// NextServer gets the next server.
func (r *RoundRobin) NextServer() (*url.URL, error) {
srv, err := r.nextServer()
func (r *RoundRobin) NextServer(w http.ResponseWriter, req *http.Request, neq *http.Request) (*url.URL, error) {
// Use extension balance server, if extension return multiple servers, choose anyone.
if ss := Strategy().Next(w, req, neq, r.servers); len(ss) > 0 && (len(ss) < len(r.servers) || len(r.servers) < 1) {
return Strategy().Strip(w, req, neq, utils.CopyURL(ss[rand.Intn(len(ss))].URL())), nil
}
srv, err := r.nextServer(w, req)
if err != nil {
return nil, err
}
return utils.CopyURL(srv.url), nil
return utils.CopyURL(srv.URL()), nil
}

func (r *RoundRobin) nextServer() (*server, error) {
func (r *RoundRobin) nextServer(w http.ResponseWriter, req *http.Request) (Server, error) {
r.mutex.Lock()
defer r.mutex.Unlock()

Expand Down Expand Up @@ -140,7 +145,7 @@ func (r *RoundRobin) nextServer() (*server, error) {
}
}
srv := r.servers[r.index]
if srv.weight >= r.currentWeight {
if srv.Weight() >= r.currentWeight {
return srv, nil
}
}
Expand All @@ -167,7 +172,7 @@ func (r *RoundRobin) Servers() []*url.URL {

out := make([]*url.URL, len(r.servers))
for i, srv := range r.servers {
out[i] = srv.url
out[i] = srv.URL()
}
return out
}
Expand All @@ -178,7 +183,7 @@ func (r *RoundRobin) ServerWeight(u *url.URL) (int, bool) {
defer r.mutex.Unlock()

if s, _ := r.findServerByURL(u); s != nil {
return s.weight, true
return s.Weight(), true
}
return -1, false
}
Expand Down Expand Up @@ -227,12 +232,12 @@ func (r *RoundRobin) resetState() {
r.resetIterator()
}

func (r *RoundRobin) findServerByURL(u *url.URL) (*server, int) {
func (r *RoundRobin) findServerByURL(u *url.URL) (Server, int) {
if len(r.servers) == 0 {
return nil, -1
}
for i, s := range r.servers {
if sameURL(u, s.url) {
if sameURL(u, s.URL()) {
return s, i
}
}
Expand All @@ -242,8 +247,8 @@ func (r *RoundRobin) findServerByURL(u *url.URL) (*server, int) {
func (r *RoundRobin) maxWeight() int {
max := -1
for _, s := range r.servers {
if s.weight > max {
max = s.weight
if s.Weight() > max {
max = s.Weight()
}
}
return max
Expand All @@ -253,9 +258,9 @@ func (r *RoundRobin) weightGcd() int {
divisor := -1
for _, s := range r.servers {
if divisor == -1 {
divisor = s.weight
divisor = s.Weight()
} else {
divisor = gcd(divisor, s.weight)
divisor = gcd(divisor, s.Weight())
}
}
return divisor
Expand All @@ -275,6 +280,18 @@ type server struct {
weight int
}

func (that *server) URL() *url.URL {
return that.url
}

func (that *server) Weight() int {
return that.weight
}

func (that *server) Set(weight int) {
that.weight = weight
}

var defaultWeight = 1

// SetDefaultWeight sets the default server weight.
Expand Down
80 changes: 80 additions & 0 deletions roundrobin/strategy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package roundrobin

import (
"net/http"
"net/url"
"sort"
)

func init() {
var _ LBStrategy = new(CompositeStrategy)
}
func Strategy() LBStrategy {
return strategies
}

func Provide(lbs LBStrategy) {
strategies.Add(lbs)
}

var strategies = new(CompositeStrategy)

type Server interface {

// URL server url.
URL() *url.URL

// Weight Relative weight for the endpoint to other endpoints in the load balancer.
Weight() int

// Set the weight.
Set(weight int)
}

type LBStrategy interface {

// Name is the strategy name.
Name() string

// Priority more than has more priority.
Priority() int

// Next servers
// Load balancer extension for custom rules filter.
Next(w http.ResponseWriter, req *http.Request, neq *http.Request, servers []Server) []Server

// Strip filter the server URL
Strip(w http.ResponseWriter, req *http.Request, neq *http.Request, uri *url.URL) *url.URL
}

type CompositeStrategy struct {
strategies []LBStrategy
}

func (that *CompositeStrategy) Add(lbs LBStrategy) *CompositeStrategy {
that.strategies = append(that.strategies, lbs)
sort.Slice(that.strategies, func(i, j int) bool { return that.strategies[i].Priority() < that.strategies[j].Priority() })
return that
}

func (that *CompositeStrategy) Name() string {
return "composite"
}

func (that *CompositeStrategy) Priority() int {
return 0
}

func (that *CompositeStrategy) Next(w http.ResponseWriter, req *http.Request, neq *http.Request, servers []Server) []Server {
for _, strategy := range that.strategies {
servers = strategy.Next(w, req, neq, servers)
}
return servers
}

func (that *CompositeStrategy) Strip(w http.ResponseWriter, req *http.Request, neq *http.Request, uri *url.URL) *url.URL {
for _, strategy := range that.strategies {
uri = strategy.Strip(w, req, neq, uri)
}
return uri
}