From 3688d5fd49d899f49773e1e5c7250d4cf3486d54 Mon Sep 17 00:00:00 2001 From: Chris Bannister Date: Sat, 30 Dec 2017 15:55:33 +0000 Subject: [PATCH] TokenAwarePolicy: use token replicas per placement strategy (#1039) * TokenAwarePolicy: use token replicas per placement strategy Add support for finding token replicas based on the placement strategy for each keyspace. Update host policies to receive keyspace change updates and the ability to set the session once they are created. * rf is a string, parse it and provide helpful error messages * fix panic when loading ks meta * fix vet --- batch_test.go | 13 +-- cassandra_test.go | 1 + common_test.go | 1 + conn.go | 2 +- events.go | 24 ++++-- frame.go | 42 +++++++++- policies.go | 187 +++++++++++++++++++++++++++++------------ policies_test.go | 12 +-- query_executor.go | 1 + session.go | 52 +++++++----- token.go | 65 +++++++-------- token_test.go | 37 +++------ topology.go | 208 ++++++++++++++++++++++++++++++++++++++++++++++ topology_test.go | 163 ++++++++++++++++++++++++++++++++++++ 14 files changed, 652 insertions(+), 156 deletions(-) create mode 100644 topology.go create mode 100644 topology_test.go diff --git a/batch_test.go b/batch_test.go index 257ced7d2..0ebfe1d52 100644 --- a/batch_test.go +++ b/batch_test.go @@ -9,12 +9,15 @@ import ( func TestBatch_Errors(t *testing.T) { if *flagProto == 1 { - t.Skip("atomic batches not supported. Please use Cassandra >= 2.0") } session := createSession(t) defer session.Close() + if session.cfg.ProtoVersion < protoVersion2 { + t.Skip("atomic batches not supported. Please use Cassandra >= 2.0") + } + if err := createTable(session, `CREATE TABLE gocql_test.batch_errors (id int primary key, val inet)`); err != nil { t.Fatal(err) } @@ -27,13 +30,13 @@ func TestBatch_Errors(t *testing.T) { } func TestBatch_WithTimestamp(t *testing.T) { - if *flagProto < protoVersion3 { - t.Skip("Batch timestamps are only available on protocol >= 3") - } - session := createSession(t) defer session.Close() + if session.cfg.ProtoVersion < protoVersion3 { + t.Skip("Batch timestamps are only available on protocol >= 3") + } + if err := createTable(session, `CREATE TABLE gocql_test.batch_ts (id int primary key, val text)`); err != nil { t.Fatal(err) } diff --git a/cassandra_test.go b/cassandra_test.go index b318bf734..6a292f915 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -83,6 +83,7 @@ func TestEmptyHosts(t *testing.T) { } func TestInvalidPeerEntry(t *testing.T) { + t.Skip("dont mutate system tables, rewrite this to test what we mean to test") session := createSession(t) // rack, release_version, schema_version, tokens are all null diff --git a/common_test.go b/common_test.go index 8f8a57f2e..d6711fade 100644 --- a/common_test.go +++ b/common_test.go @@ -94,6 +94,7 @@ func createCluster() *ClusterConfig { } func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) { + // TODO: tb.Helper() c := *cluster c.Keyspace = "system" c.Timeout = 30 * time.Second diff --git a/conn.go b/conn.go index 74b179d54..d16a70450 100644 --- a/conn.go +++ b/conn.go @@ -895,7 +895,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter { return iter case *resultKeyspaceFrame: return &Iter{framer: framer} - case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction: + case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType: iter := &Iter{framer: framer} if err := c.awaitSchemaAgreement(); err != nil { // TODO: should have this behind a flag --- diff --git 
a/events.go b/events.go index e6d28a19b..73f5adc92 100644 --- a/events.go +++ b/events.go @@ -80,7 +80,6 @@ func (e *eventDebouncer) debounce(frame frame) { } func (s *Session) handleEvent(framer *framer) { - // TODO(zariel): need to debounce events frames, and possible also events defer framerPool.Put(framer) frame, err := framer.parseFrame() @@ -94,9 +93,10 @@ func (s *Session) handleEvent(framer *framer) { Logger.Printf("gocql: handling frame: %v\n", frame) } - // TODO: handle medatadata events switch f := frame.(type) { - case *schemaChangeKeyspace, *schemaChangeFunction, *schemaChangeTable: + case *schemaChangeKeyspace, *schemaChangeFunction, + *schemaChangeTable, *schemaChangeAggregate, *schemaChangeType: + s.schemaEvents.debounce(frame) case *topologyChangeEventFrame, *statusChangeEventFrame: s.nodeEvents.debounce(frame) @@ -106,22 +106,28 @@ func (s *Session) handleEvent(framer *framer) { } func (s *Session) handleSchemaEvent(frames []frame) { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.schemaDescriber == nil { - return - } + // TODO: debounce events for _, frame := range frames { switch f := frame.(type) { case *schemaChangeKeyspace: s.schemaDescriber.clearSchema(f.keyspace) + s.handleKeyspaceChange(f.keyspace, f.change) case *schemaChangeTable: s.schemaDescriber.clearSchema(f.keyspace) + case *schemaChangeAggregate: + s.schemaDescriber.clearSchema(f.keyspace) + case *schemaChangeFunction: + s.schemaDescriber.clearSchema(f.keyspace) + case *schemaChangeType: + s.schemaDescriber.clearSchema(f.keyspace) } } } +func (s *Session) handleKeyspaceChange(keyspace, change string) { + s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace, Change: change}) +} + func (s *Session) handleNodeEvent(frames []frame) { type nodeEvent struct { change string diff --git a/frame.go b/frame.go index ae94c0f9f..504fe4264 100644 --- a/frame.go +++ b/frame.go @@ -1112,6 +1112,14 @@ func (f schemaChangeTable) String() string { return fmt.Sprintf("[event schema_change change=%q keyspace=%q object=%q]", f.change, f.keyspace, f.object) } +type schemaChangeType struct { + frameHeader + + change string + keyspace string + object string +} + type schemaChangeFunction struct { frameHeader @@ -1121,6 +1129,15 @@ type schemaChangeFunction struct { args []string } +type schemaChangeAggregate struct { + frameHeader + + change string + keyspace string + name string + args []string +} + func (f *framer) parseResultSchemaChange() frame { if f.proto <= protoVersion2 { change := f.readString() @@ -1156,7 +1173,7 @@ func (f *framer) parseResultSchemaChange() frame { frame.keyspace = f.readString() return frame - case "TABLE", "TYPE": + case "TABLE": frame := &schemaChangeTable{ frameHeader: *f.header, change: change, @@ -1166,7 +1183,17 @@ func (f *framer) parseResultSchemaChange() frame { frame.object = f.readString() return frame - case "FUNCTION", "AGGREGATE": + case "TYPE": + frame := &schemaChangeType{ + frameHeader: *f.header, + change: change, + } + + frame.keyspace = f.readString() + frame.object = f.readString() + + return frame + case "FUNCTION": frame := &schemaChangeFunction{ frameHeader: *f.header, change: change, @@ -1176,6 +1203,17 @@ func (f *framer) parseResultSchemaChange() frame { frame.name = f.readString() frame.args = f.readStringList() + return frame + case "AGGREGATE": + frame := &schemaChangeAggregate{ + frameHeader: *f.header, + change: change, + } + + frame.keyspace = f.readString() + frame.name = f.readString() + frame.args = f.readStringList() + return frame default: 
panic(fmt.Errorf("gocql: unknown SCHEMA_CHANGE target: %q change: %q", target, change)) diff --git a/policies.go b/policies.go index 0001db118..19664319f 100644 --- a/policies.go +++ b/policies.go @@ -200,11 +200,18 @@ type HostStateNotifier interface { HostDown(host *HostInfo) } +type KeyspaceUpdateEvent struct { + Keyspace string + Change string +} + // HostSelectionPolicy is an interface for selecting // the most appropriate host to execute a given query. type HostSelectionPolicy interface { HostStateNotifier SetPartitioner + KeyspaceChanged(KeyspaceUpdateEvent) + Init(*Session) //Pick returns an iteration function over selected hosts Pick(ExecutableQuery) NextHost } @@ -239,9 +246,9 @@ type roundRobinHostPolicy struct { mu sync.RWMutex } -func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) { - // noop -} +func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {} +func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {} +func (r *roundRobinHostPolicy) Init(*Session) {} func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost { // i is used to limit the number of attempts to find a host @@ -281,19 +288,69 @@ func (r *roundRobinHostPolicy) HostDown(host *HostInfo) { r.RemoveHost(host) } +func ShuffleReplicas() func(*tokenAwareHostPolicy) { + return func(t *tokenAwareHostPolicy) { + t.shuffleReplicas = true + } +} + // TokenAwareHostPolicy is a token aware host selection policy, where hosts are // selected based on the partition key, so queries are sent to the host which // owns the partition. Fallback is used when routing information is not available. -func TokenAwareHostPolicy(fallback HostSelectionPolicy) HostSelectionPolicy { - return &tokenAwareHostPolicy{fallback: fallback} +func TokenAwareHostPolicy(fallback HostSelectionPolicy, opts ...func(*tokenAwareHostPolicy)) HostSelectionPolicy { + p := &tokenAwareHostPolicy{fallback: fallback} + for _, opt := range opts { + opt(p) + } + return p +} + +type keyspaceMeta struct { + replicas map[string]map[token][]*HostInfo } type tokenAwareHostPolicy struct { hosts cowHostList mu sync.RWMutex partitioner string - tokenRing *tokenRing fallback HostSelectionPolicy + session *Session + + tokenRing atomic.Value // *tokenRing + keyspaces atomic.Value // *keyspaceMeta + + shuffleReplicas bool +} + +func (t *tokenAwareHostPolicy) Init(s *Session) { + t.session = s +} + +func (t *tokenAwareHostPolicy) KeyspaceChanged(update KeyspaceUpdateEvent) { + meta, _ := t.keyspaces.Load().(*keyspaceMeta) + // TODO: avoid recaulating things which havnt changed + newMeta := &keyspaceMeta{ + replicas: make(map[string]map[token][]*HostInfo, len(meta.replicas)), + } + + ks, err := t.session.KeyspaceMetadata(update.Keyspace) + if err == nil { + strat := getStrategy(ks) + tr := t.tokenRing.Load().(*tokenRing) + if tr != nil { + newMeta.replicas[update.Keyspace] = strat.replicaMap(t.hosts.get(), tr.tokens) + } + } + + if meta != nil { + for ks, replicas := range meta.replicas { + if ks != update.Keyspace { + newMeta.replicas[ks] = replicas + } + } + } + + t.keyspaces.Store(newMeta) } func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) { @@ -304,31 +361,34 @@ func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) { t.fallback.SetPartitioner(partitioner) t.partitioner = partitioner - t.resetTokenRing() + t.resetTokenRing(partitioner) } } func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) { - t.mu.Lock() - defer t.mu.Unlock() - t.hosts.add(host) t.fallback.AddHost(host) - t.resetTokenRing() + 
t.mu.RLock() + partitioner := t.partitioner + t.mu.RUnlock() + t.resetTokenRing(partitioner) } func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) { - t.mu.Lock() - defer t.mu.Unlock() - t.hosts.remove(host.ConnectAddress()) t.fallback.RemoveHost(host) - t.resetTokenRing() + t.mu.RLock() + partitioner := t.partitioner + t.mu.RUnlock() + t.resetTokenRing(partitioner) } func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) { + // TODO: need to avoid doing all the work on AddHost on hostup/down + // because it now expensive to calculate the replica map for each + // token t.AddHost(host) } @@ -336,22 +396,31 @@ func (t *tokenAwareHostPolicy) HostDown(host *HostInfo) { t.RemoveHost(host) } -func (t *tokenAwareHostPolicy) resetTokenRing() { - if t.partitioner == "" { +func (t *tokenAwareHostPolicy) resetTokenRing(partitioner string) { + if partitioner == "" { // partitioner not yet set return } // create a new token ring hosts := t.hosts.get() - tokenRing, err := newTokenRing(t.partitioner, hosts) + tokenRing, err := newTokenRing(partitioner, hosts) if err != nil { Logger.Printf("Unable to update the token ring due to error: %s", err) return } // replace the token ring - t.tokenRing = tokenRing + t.tokenRing.Store(tokenRing) +} + +func (t *tokenAwareHostPolicy) getReplicas(keyspace string, token token) ([]*HostInfo, bool) { + meta, _ := t.keyspaces.Load().(*keyspaceMeta) + if meta == nil { + return nil, false + } + tokens, ok := meta.replicas[keyspace][token] + return tokens, ok } func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost { @@ -362,45 +431,62 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost { routingKey, err := qry.GetRoutingKey() if err != nil { return t.fallback.Pick(qry) + } else if routingKey == nil { + return t.fallback.Pick(qry) } - if routingKey == nil { + + tr, _ := t.tokenRing.Load().(*tokenRing) + if tr == nil { return t.fallback.Pick(qry) } - t.mu.RLock() - // TODO retrieve a list of hosts based on the replication strategy - host := t.tokenRing.GetHostForPartitionKey(routingKey) - t.mu.RUnlock() + token := tr.partitioner.Hash(routingKey) + primaryEndpoint := tr.GetHostForToken(token) - if host == nil { + if primaryEndpoint == nil || token == nil { return t.fallback.Pick(qry) } - // scope these variables for the same lifetime as the iterator function + replicas, ok := t.getReplicas(qry.Keyspace(), token) + if !ok { + replicas = []*HostInfo{primaryEndpoint} + } else if t.shuffleReplicas { + replicas = shuffleHosts(replicas) + } + var ( - hostReturned bool fallbackIter NextHost + i int ) + used := make(map[*HostInfo]bool) return func() SelectedHost { - if !hostReturned { - hostReturned = true - return (*selectedHost)(host) + for i < len(replicas) { + h := replicas[i] + i++ + + if !h.IsUp() { + // TODO: need a way to handle host distance, as we may want to not + // use hosts in specific DC's + continue + } + used[h] = true + + return (*selectedHost)(h) } - // fallback if fallbackIter == nil { + // fallback fallbackIter = t.fallback.Pick(qry) } - fallbackHost := fallbackIter() - // filter the token aware selected hosts from the fallback hosts - if fallbackHost != nil && fallbackHost.Info() == host { - fallbackHost = fallbackIter() + for fallbackHost := fallbackIter(); fallbackHost != nil; fallbackHost = fallbackIter() { + if !used[fallbackHost.Info()] { + return fallbackHost + } } - - return fallbackHost + return nil } } @@ -428,6 +514,10 @@ type hostPoolHostPolicy struct { hostMap map[string]*HostInfo } +func (r *hostPoolHostPolicy) 
Init(*Session) {} +func (r *hostPoolHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {} +func (r *hostPoolHostPolicy) SetPartitioner(string) {} + func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) { peers := make([]string, len(hosts)) hostMap := make(map[string]*HostInfo, len(hosts)) @@ -492,10 +582,6 @@ func (r *hostPoolHostPolicy) HostDown(host *HostInfo) { r.RemoveHost(host) } -func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) { - // noop -} - func (r *hostPoolHostPolicy) Pick(qry ExecutableQuery) NextHost { return func() SelectedHost { r.mu.RLock() @@ -557,11 +643,13 @@ type dcAwareRR struct { // return hosts which are in the local datacentre before returning hosts in all // other datercentres func DCAwareRoundRobinPolicy(localDC string) HostSelectionPolicy { - return &dcAwareRR{ - local: localDC, - } + return &dcAwareRR{local: localDC} } +func (r *dcAwareRR) Init(*Session) {} +func (r *dcAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {} +func (d *dcAwareRR) SetPartitioner(p string) {} + func (d *dcAwareRR) AddHost(host *HostInfo) { if host.DataCenter() == d.local { d.localHosts.add(host) @@ -578,15 +666,8 @@ func (d *dcAwareRR) RemoveHost(host *HostInfo) { } } -func (d *dcAwareRR) HostUp(host *HostInfo) { - d.AddHost(host) -} - -func (d *dcAwareRR) HostDown(host *HostInfo) { - d.RemoveHost(host) -} - -func (d *dcAwareRR) SetPartitioner(p string) {} +func (d *dcAwareRR) HostUp(host *HostInfo) { d.AddHost(host) } +func (d *dcAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) } func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost { var i int diff --git a/policies_test.go b/policies_test.go index ce2b1cfb6..bc89a64b0 100644 --- a/policies_test.go +++ b/policies_test.go @@ -14,7 +14,7 @@ import ( ) // Tests of the round-robin host selection policy implementation -func TestRoundRobinHostPolicy(t *testing.T) { +func TestHostPolicy_RoundRobin(t *testing.T) { policy := RoundRobinHostPolicy() hosts := [...]*HostInfo{ @@ -53,7 +53,7 @@ func TestRoundRobinHostPolicy(t *testing.T) { // Tests of the token-aware host selection policy implementation with a // round-robin host selection policy fallback. 
-func TestTokenAwareHostPolicy(t *testing.T) { +func TestHostPolicy_TokenAware(t *testing.T) { policy := TokenAwareHostPolicy(RoundRobinHostPolicy()) query := &Query{} @@ -110,7 +110,7 @@ func TestTokenAwareHostPolicy(t *testing.T) { } // Tests of the host pool host selection policy implementation -func TestHostPoolHostPolicy(t *testing.T) { +func TestHostPolicy_HostPool(t *testing.T) { policy := HostPoolHostPolicy(hostpool.New(nil)) hosts := []*HostInfo{ @@ -150,7 +150,7 @@ func TestHostPoolHostPolicy(t *testing.T) { actualD.Mark(nil) } -func TestRoundRobinNilHostInfo(t *testing.T) { +func TestHostPolicy_RoundRobin_NilHostInfo(t *testing.T) { policy := RoundRobinHostPolicy() host := &HostInfo{hostId: "host-1"} @@ -175,7 +175,7 @@ func TestRoundRobinNilHostInfo(t *testing.T) { } } -func TestTokenAwareNilHostInfo(t *testing.T) { +func TestHostPolicy_TokenAware_NilHostInfo(t *testing.T) { policy := TokenAwareHostPolicy(RoundRobinHostPolicy()) hosts := [...]*HostInfo{ @@ -302,7 +302,7 @@ func TestExponentialBackoffPolicy(t *testing.T) { } } -func TestDCAwareRR(t *testing.T) { +func TestHostPolicy_DCAwareRR(t *testing.T) { p := DCAwareRoundRobinPolicy("local") hosts := [...]*HostInfo{ diff --git a/query_executor.go b/query_executor.go index 4f9873016..89a02bb27 100644 --- a/query_executor.go +++ b/query_executor.go @@ -9,6 +9,7 @@ type ExecutableQuery interface { attempt(time.Duration) retryPolicy() RetryPolicy GetRoutingKey() ([]byte, error) + Keyspace() string RetryableQuery } diff --git a/session.go b/session.go index f08cbc528..918118ad6 100644 --- a/session.go +++ b/session.go @@ -112,14 +112,14 @@ func NewSession(cfg ClusterConfig) (*Session, error) { quit: make(chan struct{}), } + s.schemaDescriber = newSchemaDescriber(s) + s.nodeEvents = newEventDebouncer("NodeEvents", s.handleNodeEvent) s.schemaEvents = newEventDebouncer("SchemaEvents", s.handleSchemaEvent) s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo) - s.hostSource = &ringDescriber{ - session: s, - } + s.hostSource = &ringDescriber{session: s} if cfg.PoolConfig.HostSelectionPolicy == nil { cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy() @@ -127,6 +127,8 @@ func NewSession(cfg ClusterConfig) (*Session, error) { s.pool = cfg.PoolConfig.buildPool(s) s.policy = cfg.PoolConfig.HostSelectionPolicy + s.policy.Init(s) + s.executor = &queryExecutor{ pool: s.pool, policy: cfg.PoolConfig.HostSelectionPolicy, @@ -409,25 +411,15 @@ func (s *Session) KeyspaceMetadata(keyspace string) (*KeyspaceMetadata, error) { // fail fast if s.Closed() { return nil, ErrSessionClosed - } - - if keyspace == "" { + } else if keyspace == "" { return nil, ErrNoKeyspace } - s.mu.Lock() - // lazy-init schemaDescriber - if s.schemaDescriber == nil { - s.schemaDescriber = newSchemaDescriber(s) - } - s.mu.Unlock() - return s.schemaDescriber.getSchema(keyspace) } func (s *Session) getConn() *Conn { hosts := s.ring.allHosts() - var conn *Conn for _, host := range hosts { if !host.IsUp() { continue @@ -436,10 +428,7 @@ func (s *Session) getConn() *Conn { pool, ok := s.pool.getPool(host) if !ok { continue - } - - conn = pool.Pick() - if conn != nil { + } else if conn := pool.Pick(); conn != nil { return conn } } @@ -780,6 +769,16 @@ func (q *Query) retryPolicy() RetryPolicy { return q.rt } +// Keyspace returns the keyspace the query will be executed against. +func (q *Query) Keyspace() string { + if q.session == nil { + return "" + } + // TODO(chbannis): this should be parsed from the query or we should let + // this be set by users. 
+ return q.session.cfg.Keyspace +} + // GetRoutingKey gets the routing key to use for routing this query. If // a routing key has not been explicitly set, then the routing key will // be constructed if possible using the keyspace's schema and the query @@ -1341,9 +1340,12 @@ type Batch struct { defaultTimestamp bool defaultTimestampValue int64 context context.Context + keyspace string } // NewBatch creates a new batch operation without defaults from the cluster +// +// Depreicated: use session.NewBatch instead func NewBatch(typ BatchType) *Batch { return &Batch{Type: typ} } @@ -1351,12 +1353,22 @@ func NewBatch(typ BatchType) *Batch { // NewBatch creates a new batch operation using defaults defined in the cluster func (s *Session) NewBatch(typ BatchType) *Batch { s.mu.RLock() - batch := &Batch{Type: typ, rt: s.cfg.RetryPolicy, serialCons: s.cfg.SerialConsistency, - Cons: s.cons, defaultTimestamp: s.cfg.DefaultTimestamp} + batch := &Batch{ + Type: typ, + rt: s.cfg.RetryPolicy, + serialCons: s.cfg.SerialConsistency, + Cons: s.cons, + defaultTimestamp: s.cfg.DefaultTimestamp, + keyspace: s.cfg.Keyspace, + } s.mu.RUnlock() return batch } +func (b *Batch) Keyspace() string { + return b.keyspace +} + // Attempts returns the number of attempts made to execute the batch. func (b *Batch) Attempts() int { return b.attempts diff --git a/token.go b/token.go index 5c3fce8d1..bdfcceb98 100644 --- a/token.go +++ b/token.go @@ -58,7 +58,7 @@ func (m murmur3Token) Less(token token) bool { // order preserving partitioner and token type orderedPartitioner struct{} -type orderedToken []byte +type orderedToken string func (p orderedPartitioner) Name() string { return "OrderedPartitioner" @@ -70,15 +70,15 @@ func (p orderedPartitioner) Hash(partitionKey []byte) token { } func (p orderedPartitioner) ParseString(str string) token { - return orderedToken([]byte(str)) + return orderedToken(str) } func (o orderedToken) String() string { - return string([]byte(o)) + return string(o) } func (o orderedToken) Less(token token) bool { - return -1 == bytes.Compare(o, token.(orderedToken)) + return o < token.(orderedToken) } // random partitioner and token @@ -118,18 +118,23 @@ func (r *randomToken) Less(token token) bool { return -1 == (*big.Int)(r).Cmp((*big.Int)(token.(*randomToken))) } +type hostToken struct { + token token + host *HostInfo +} + +func (ht hostToken) String() string { + return fmt.Sprintf("{token=%v host=%v}", ht.token, ht.host.HostID()) +} + // a data structure for organizing the relationship between tokens and hosts type tokenRing struct { partitioner partitioner - tokens []token - hosts []*HostInfo + tokens []hostToken } func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) { - tokenRing := &tokenRing{ - tokens: []token{}, - hosts: []*HostInfo{}, - } + tokenRing := &tokenRing{} if strings.HasSuffix(partitioner, "Murmur3Partitioner") { tokenRing.partitioner = murmur3Partitioner{} @@ -144,8 +149,7 @@ func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) { for _, host := range hosts { for _, strToken := range host.Tokens() { token := tokenRing.partitioner.ParseString(strToken) - tokenRing.tokens = append(tokenRing.tokens, token) - tokenRing.hosts = append(tokenRing.hosts, host) + tokenRing.tokens = append(tokenRing.tokens, hostToken{token, host}) } } @@ -159,16 +163,14 @@ func (t *tokenRing) Len() int { } func (t *tokenRing) Less(i, j int) bool { - return t.tokens[i].Less(t.tokens[j]) + return t.tokens[i].token.Less(t.tokens[j].token) } func (t *tokenRing) 
Swap(i, j int) { - t.tokens[i], t.hosts[i], t.tokens[j], t.hosts[j] = - t.tokens[j], t.hosts[j], t.tokens[i], t.hosts[i] + t.tokens[i], t.tokens[j] = t.tokens[j], t.tokens[i] } func (t *tokenRing) String() string { - buf := &bytes.Buffer{} buf.WriteString("TokenRing(") if t.partitioner != nil { @@ -176,15 +178,15 @@ func (t *tokenRing) String() string { } buf.WriteString("){") sep := "" - for i := range t.tokens { + for i, th := range t.tokens { buf.WriteString(sep) sep = "," buf.WriteString("\n\t[") buf.WriteString(strconv.Itoa(i)) buf.WriteString("]") - buf.WriteString(t.tokens[i].String()) + buf.WriteString(th.token.String()) buf.WriteString(":") - buf.WriteString(t.hosts[i].ConnectAddress().String()) + buf.WriteString(th.host.ConnectAddress().String()) } buf.WriteString("\n}") return string(buf.Bytes()) @@ -200,28 +202,19 @@ func (t *tokenRing) GetHostForPartitionKey(partitionKey []byte) *HostInfo { } func (t *tokenRing) GetHostForToken(token token) *HostInfo { - if t == nil { - return nil - } - - l := len(t.tokens) - // no host tokens, no available hosts - if l == 0 { + if t == nil || len(t.tokens) == 0 { return nil } // find the primary replica - ringIndex := sort.Search( - l, - func(i int) bool { - return !t.tokens[i].Less(token) - }, - ) - - if ringIndex == l { + ringIndex := sort.Search(len(t.tokens), func(i int) bool { + return !t.tokens[i].token.Less(token) + }) + + if ringIndex == len(t.tokens) { // wrap around to the first in the ring ringIndex = 0 } - host := t.hosts[ringIndex] - return host + + return t.tokens[ringIndex].host } diff --git a/token_test.go b/token_test.go index b71ff74cc..8646c9885 100644 --- a/token_test.go +++ b/token_test.go @@ -132,18 +132,13 @@ func TestRandomToken(t *testing.T) { type intToken int -func (i intToken) String() string { - return strconv.Itoa(int(i)) -} - -func (i intToken) Less(token token) bool { - return i < token.(intToken) -} +func (i intToken) String() string { return strconv.Itoa(int(i)) } +func (i intToken) Less(token token) bool { return i < token.(intToken) } // Test of the token ring implementation based on example at the start of this // page of documentation: // http://www.datastax.com/docs/0.8/cluster_architecture/partitioning -func TestIntTokenRing(t *testing.T) { +func TestTokenRing_Int(t *testing.T) { host0 := &HostInfo{} host25 := &HostInfo{} host50 := &HostInfo{} @@ -151,17 +146,11 @@ func TestIntTokenRing(t *testing.T) { ring := &tokenRing{ partitioner: nil, // these tokens and hosts are out of order to test sorting - tokens: []token{ - intToken(0), - intToken(50), - intToken(75), - intToken(25), - }, - hosts: []*HostInfo{ - host0, - host50, - host75, - host25, + tokens: []hostToken{ + {intToken(0), host0}, + {intToken(50), host50}, + {intToken(75), host75}, + {intToken(25), host25}, }, } @@ -209,7 +198,7 @@ func TestIntTokenRing(t *testing.T) { } // Test for the behavior of a nil pointer to tokenRing -func TestNilTokenRing(t *testing.T) { +func TestTokenRing_Nil(t *testing.T) { var ring *tokenRing = nil if ring.GetHostForToken(nil) != nil { @@ -221,7 +210,7 @@ func TestNilTokenRing(t *testing.T) { } // Test of the recognition of the partitioner class -func TestUnknownTokenRing(t *testing.T) { +func TestTokenRing_UnknownPartition(t *testing.T) { _, err := newTokenRing("UnknownPartitioner", nil) if err == nil { t.Error("Expected error for unknown partitioner value, but was nil") @@ -242,7 +231,7 @@ func hostsForTests(n int) []*HostInfo { } // Test of the tokenRing with the Murmur3Partitioner -func TestMurmur3TokenRing(t 
*testing.T) { +func TestTokenRing_Murmur3(t *testing.T) { // Note, strings are parsed directly to int64, they are not murmur3 hashed hosts := hostsForTests(4) ring, err := newTokenRing("Murmur3Partitioner", hosts) @@ -272,7 +261,7 @@ func TestMurmur3TokenRing(t *testing.T) { } // Test of the tokenRing with the OrderedPartitioner -func TestOrderedTokenRing(t *testing.T) { +func TestTokenRing_Ordered(t *testing.T) { // Tokens here more or less are similar layout to the int tokens above due // to each numeric character translating to a consistently offset byte. hosts := hostsForTests(4) @@ -304,7 +293,7 @@ func TestOrderedTokenRing(t *testing.T) { } // Test of the tokenRing with the RandomPartitioner -func TestRandomTokenRing(t *testing.T) { +func TestTokenRing_Random(t *testing.T) { // String tokens are parsed into big.Int in base 10 hosts := hostsForTests(4) ring, err := newTokenRing("RandomPartitioner", hosts) diff --git a/topology.go b/topology.go new file mode 100644 index 000000000..54ee4bcc1 --- /dev/null +++ b/topology.go @@ -0,0 +1,208 @@ +package gocql + +import ( + "fmt" + "strconv" + "strings" +) + +type placementStrategy interface { + replicaMap(hosts []*HostInfo, tokens []hostToken) map[token][]*HostInfo + replicationFactor(dc string) int +} + +func getReplicationFactorFromOpts(keyspace string, val interface{}) int { + // TODO: dont really want to panic here, but is better + // than spamming + switch v := val.(type) { + case int: + if v <= 0 { + panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", v, keyspace)) + } + return v + case string: + n, err := strconv.Atoi(v) + if err != nil { + panic(fmt.Sprintf("invalid replication_factor. Is the %q keyspace configured correctly? %v", keyspace, err)) + } else if n <= 0 { + panic(fmt.Sprintf("invalid replication_factor %d. 
Is the %q keyspace configured correctly?", n, keyspace)) + } + return n + default: + panic(fmt.Sprintf("unkown replication_factor type %T", v)) + } +} + +func getStrategy(ks *KeyspaceMetadata) placementStrategy { + switch { + case strings.Contains(ks.StrategyClass, "SimpleStrategy"): + return &simpleStrategy{rf: getReplicationFactorFromOpts(ks.Name, ks.StrategyOptions["replication_factor"])} + case strings.Contains(ks.StrategyClass, "NetworkTopologyStrategy"): + dcs := make(map[string]int) + for dc, rf := range ks.StrategyOptions { + dcs[dc] = getReplicationFactorFromOpts(ks.Name+":dc="+dc, rf) + } + return &networkTopology{dcs: dcs} + default: + // TODO: handle unknown replicas and just return the primary host for a token + panic(fmt.Sprintf("unsupported strategy class: %v", ks.StrategyClass)) + } +} + +type simpleStrategy struct { + rf int +} + +func (s *simpleStrategy) replicationFactor(dc string) int { + return s.rf +} + +func (s *simpleStrategy) replicaMap(_ []*HostInfo, tokens []hostToken) map[token][]*HostInfo { + tokenRing := make(map[token][]*HostInfo, len(tokens)) + + for i, th := range tokens { + replicas := make([]*HostInfo, 0, s.rf) + for j := 0; j < len(tokens) && len(replicas) < s.rf; j++ { + // TODO: need to ensure we dont add the same hosts twice + h := tokens[(i+j)%len(tokens)] + replicas = append(replicas, h.host) + } + tokenRing[th.token] = replicas + } + + return tokenRing +} + +type networkTopology struct { + dcs map[string]int +} + +func (n *networkTopology) replicationFactor(dc string) int { + return n.dcs[dc] +} + +func (n *networkTopology) haveRF(replicaCounts map[string]int) bool { + if len(replicaCounts) != len(n.dcs) { + return false + } + + for dc, rf := range n.dcs { + if rf != replicaCounts[dc] { + return false + } + } + + return true +} + +func (n *networkTopology) replicaMap(hosts []*HostInfo, tokens []hostToken) map[token][]*HostInfo { + dcRacks := make(map[string]map[string]struct{}) + + for _, h := range hosts { + dc := h.DataCenter() + rack := h.Rack() + + racks, ok := dcRacks[dc] + if !ok { + racks = make(map[string]struct{}) + dcRacks[dc] = racks + } + racks[rack] = struct{}{} + } + + tokenRing := make(map[token][]*HostInfo, len(tokens)) + + var totalRF int + for _, rf := range n.dcs { + totalRF += rf + } + + for i, th := range tokens { + // number of replicas per dc + // TODO: recycle these + replicasInDC := make(map[string]int, len(n.dcs)) + // dc -> racks + seenDCRacks := make(map[string]map[string]struct{}, len(n.dcs)) + // skipped hosts in a dc + skipped := make(map[string][]*HostInfo, len(n.dcs)) + + replicas := make([]*HostInfo, 0, totalRF) + for j := 0; j < len(tokens) && !n.haveRF(replicasInDC); j++ { + // TODO: ensure we dont add the same host twice + h := tokens[(i+j)%len(tokens)].host + + dc := h.DataCenter() + rack := h.Rack() + + rf, ok := n.dcs[dc] + if !ok { + // skip this DC, dont know about it + continue + } else if replicasInDC[dc] >= rf { + if replicasInDC[dc] > rf { + panic(fmt.Sprintf("replica overflow. rf=%d have=%d in dc %q", rf, replicasInDC[dc], dc)) + } + + // have enough replicas in this DC + continue + } else if _, ok := dcRacks[dc][rack]; !ok { + // dont know about this rack + continue + } else if len(replicas) >= totalRF { + if replicasInDC[dc] > rf { + panic(fmt.Sprintf("replica overflow. 
total rf=%d have=%d", totalRF, len(replicas))) + } + + // we now have enough replicas + break + } + + racks := seenDCRacks[dc] + if _, ok := racks[rack]; ok && len(racks) == len(dcRacks[dc]) { + // we have been through all the racks and dont have RF yet, add this + replicas = append(replicas, h) + replicasInDC[dc]++ + } else if !ok { + if racks == nil { + racks = make(map[string]struct{}, 1) + seenDCRacks[dc] = racks + } + + // new rack + racks[rack] = struct{}{} + replicas = append(replicas, h) + replicasInDC[dc]++ + + if len(racks) == len(dcRacks[dc]) { + // if we have been through all the racks, drain the rest of the skipped + // hosts until we have RF. The next iteration will skip in the block + // above + skippedHosts := skipped[dc] + var k int + for ; k < len(skippedHosts) && replicasInDC[dc] < rf; k++ { + sh := skippedHosts[k] + replicas = append(replicas, sh) + replicasInDC[dc]++ + } + skipped[dc] = skippedHosts[k:] + } + } else { + // already seen this rack, keep hold of this host incase + // we dont get enough for rf + skipped[dc] = append(skipped[dc], h) + } + } + + if len(replicas) == 0 || replicas[0] != th.host { + panic("first replica is not the primary replica for the token") + } + + tokenRing[th.token] = replicas + } + + if len(tokenRing) != len(tokens) { + panic(fmt.Sprintf("token map different size to token ring: got %d expected %d", len(tokenRing), len(tokens))) + } + + return tokenRing +} diff --git a/topology_test.go b/topology_test.go new file mode 100644 index 000000000..23e4fe3dc --- /dev/null +++ b/topology_test.go @@ -0,0 +1,163 @@ +package gocql + +import ( + "fmt" + "sort" + "testing" +) + +func TestPlacementStrategy_SimpleStrategy(t *testing.T) { + host0 := &HostInfo{hostId: "0"} + host25 := &HostInfo{hostId: "25"} + host50 := &HostInfo{hostId: "50"} + host75 := &HostInfo{hostId: "75"} + + tokenRing := []hostToken{ + {intToken(0), host0}, + {intToken(25), host25}, + {intToken(50), host50}, + {intToken(75), host75}, + } + + hosts := []*HostInfo{host0, host25, host50, host75} + + strat := &simpleStrategy{rf: 2} + tokenReplicas := strat.replicaMap(hosts, tokenRing) + if len(tokenReplicas) != len(tokenRing) { + t.Fatalf("expected replica map to have %d items but has %d", len(tokenRing), len(tokenReplicas)) + } + + for token, replicas := range tokenReplicas { + if len(replicas) != strat.rf { + t.Errorf("expected to have %d replicas got %d for token=%v", strat.rf, len(replicas), token) + } + } + + for i, token := range tokenRing { + replicas, ok := tokenReplicas[token.token] + if !ok { + t.Errorf("token %v not in replica map", token) + } + + for j, replica := range replicas { + exp := tokenRing[(i+j)%len(tokenRing)].host + if exp != replica { + t.Errorf("expected host %v to be a replica of %v got %v", exp, token, replica) + } + } + } +} + +func TestPlacementStrategy_NetworkStrategy(t *testing.T) { + var ( + hosts []*HostInfo + tokens []hostToken + ) + + const ( + totalDCs = 3 + racksPerDC = 3 + hostsPerDC = 5 + ) + + dcRing := make(map[string][]hostToken, totalDCs) + for i := 0; i < totalDCs; i++ { + var dcTokens []hostToken + dc := fmt.Sprintf("dc%d", i+1) + + for j := 0; j < hostsPerDC; j++ { + rack := fmt.Sprintf("rack%d", (j%racksPerDC)+1) + + h := &HostInfo{hostId: fmt.Sprintf("%s:%s:%d", dc, rack, j), dataCenter: dc, rack: rack} + + token := hostToken{ + token: orderedToken([]byte(h.hostId)), + host: h, + } + + tokens = append(tokens, token) + dcTokens = append(dcTokens, token) + + hosts = append(hosts, h) + } + + sort.Sort(&tokenRing{tokens: dcTokens}) + 
dcRing[dc] = dcTokens + } + + if len(tokens) != hostsPerDC*totalDCs { + t.Fatalf("expected %d tokens in the ring got %d", hostsPerDC*totalDCs, len(tokens)) + } + sort.Sort(&tokenRing{tokens: tokens}) + + strat := &networkTopology{ + dcs: map[string]int{ + "dc1": 1, + "dc2": 2, + "dc3": 3, + }, + } + + var expReplicas int + for _, rf := range strat.dcs { + expReplicas += rf + } + + tokenReplicas := strat.replicaMap(hosts, tokens) + if len(tokenReplicas) != len(tokens) { + t.Fatalf("expected replica map to have %d items but has %d", len(tokens), len(tokenReplicas)) + } + + for token, replicas := range tokenReplicas { + if len(replicas) != expReplicas { + t.Fatalf("expected to have %d replicas got %d for token=%v", expReplicas, len(replicas), token) + } + } + + for dc, rf := range strat.dcs { + dcTokens := dcRing[dc] + for i, th := range dcTokens { + token := th.token + allReplicas, ok := tokenReplicas[token] + if !ok { + t.Fatalf("token %v not in replica map", token) + } + + var replicas []*HostInfo + for _, replica := range allReplicas { + if replica.dataCenter == dc { + replicas = append(replicas, replica) + } + } + + if len(replicas) != rf { + t.Fatalf("expected %d replicas in dc %q got %d", rf, dc, len(replicas)) + } + + var lastRack string + for j, replica := range replicas { + // expected is in the next rack + var exp *HostInfo + if lastRack == "" { + // primary, first replica + exp = dcTokens[(i+j)%len(dcTokens)].host + } else { + for k := 0; k < len(dcTokens); k++ { + // walk around the ring from i + j to find the next host the + // next rack + p := (i + j + k) % len(dcTokens) + h := dcTokens[p].host + if h.rack != lastRack { + exp = h + break + } + } + if exp.rack == lastRack { + panic("no more racks") + } + } + lastRack = replica.rack + } + } + } +}
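
Illustrative usage sketch (not part of the diff): a minimal program showing how an application could opt into the behaviour added by this patch, assuming the gocql public API referenced above (TokenAwareHostPolicy's new variadic options, ShuffleReplicas, and ClusterConfig.PoolConfig.HostSelectionPolicy); the contact point and keyspace are placeholders. NewSession calls policy.Init(s) itself, so no extra wiring is needed for the policy to fetch keyspace metadata.

package main

import (
	"log"

	"github.com/gocql/gocql"
)

func main() {
	// Placeholder contact point and keyspace.
	cluster := gocql.NewCluster("127.0.0.1")
	cluster.Keyspace = "gocql_test"

	// Prefer the replicas that own each partition, computed per keyspace from
	// its placement strategy; shuffle replicas to spread load and fall back to
	// round-robin when routing information is unavailable.
	cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(
		gocql.RoundRobinHostPolicy(),
		gocql.ShuffleReplicas(),
	)

	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()

	var release string
	if err := session.Query(`SELECT release_version FROM system.local`).Scan(&release); err != nil {
		log.Fatal(err)
	}
	log.Println("connected, release_version =", release)
}

Queries built from this session report the session keyspace via Query.Keyspace(), which Pick uses to look up the replica set for the routing key's token.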
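
Illustrative interface sketch (not part of the diff): HostSelectionPolicy gains KeyspaceChanged and Init in this patch, so custom policies maintained outside this repository need both methods to keep compiling. Below is a minimal conforming wrapper under that assumption; myPolicy and its delegation to a fallback are hypothetical.

package example

import "github.com/gocql/gocql"

// Compile-time check that myPolicy satisfies the extended interface.
var _ gocql.HostSelectionPolicy = (*myPolicy)(nil)

// myPolicy is a hypothetical custom policy that simply wraps a fallback.
type myPolicy struct {
	fallback gocql.HostSelectionPolicy
}

// New hooks introduced by this patch: Init receives the session so the policy
// can resolve keyspace metadata; KeyspaceChanged fires on schema change events.
func (p *myPolicy) Init(*gocql.Session)                       {}
func (p *myPolicy) KeyspaceChanged(gocql.KeyspaceUpdateEvent) {}

// Pre-existing methods, delegated to the wrapped policy.
func (p *myPolicy) SetPartitioner(partitioner string) { p.fallback.SetPartitioner(partitioner) }
func (p *myPolicy) AddHost(host *gocql.HostInfo)      { p.fallback.AddHost(host) }
func (p *myPolicy) RemoveHost(host *gocql.HostInfo)   { p.fallback.RemoveHost(host) }
func (p *myPolicy) HostUp(host *gocql.HostInfo)       { p.fallback.HostUp(host) }
func (p *myPolicy) HostDown(host *gocql.HostInfo)     { p.fallback.HostDown(host) }

func (p *myPolicy) Pick(qry gocql.ExecutableQuery) gocql.NextHost {
	return p.fallback.Pick(qry)
}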
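
Illustrative background sketch (not part of the diff): getReplicationFactorFromOpts parses strings because replication factors arrive through KeyspaceMetadata.StrategyOptions, where the schema tables typically report them as text (hence the "rf is a string" note in the commit message). A hypothetical keyspace like the one below, once its schema event reaches KeyspaceChanged, would yield a networkTopology strategy with dc1=3 and dc2=2; the keyspace and DC names are placeholders.

package example

import "github.com/gocql/gocql"

// createExampleKeyspace creates a keyspace whose per-DC replication factors
// are written as strings ('3', '2'), matching what the driver later reads
// back from the schema tables and parses in getReplicationFactorFromOpts.
func createExampleKeyspace(session *gocql.Session) error {
	return session.Query(`CREATE KEYSPACE IF NOT EXISTS example_ks
		WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': '3', 'dc2': '2'}`).Exec()
}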