diff --git a/virtual/activation_cache.go b/virtual/activation_cache.go index ef41392..4cccc0e 100644 --- a/virtual/activation_cache.go +++ b/virtual/activation_cache.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "runtime" + "strings" "sync" "time" @@ -83,16 +84,19 @@ func (a *activationsCache) ensureActivation( moduleID, actorID string, - blacklistedServerID string, + extraReplicas uint64, + blacklistedServerIDs []string, ) ([]types.ActorReference, error) { // Ensure we have a short timeout when communicating with registry. ctx, cc := context.WithTimeout(ctx, defaultActivationCacheTimeout) defer cc() + isServerIDBlacklisted := types.StringSliceToSet(blacklistedServerIDs) + if a.c == nil { // Cache disabled, load directly. return a.ensureActivationAndUpdateCache( - ctx, namespace, moduleID, actorID, nil, blacklistedServerID) + ctx, namespace, moduleID, actorID, extraReplicas, nil, isServerIDBlacklisted, blacklistedServerIDs) } var ( @@ -102,23 +106,56 @@ func (a *activationsCache) ensureActivation( bufIface, cacheKey = actorCacheKeyUnsafePooled(namespace, moduleID, actorID) aceI, ok := a.c.Get(cacheKey) bufPool.Put(bufIface) - // Cache miss, fill the cache. - if !ok || + + var ( + cachedReferences []types.ActorReference + nonBlacklistedCachedReferences []types.ActorReference + currentBlacklistedIDsAreInvalid = false + ) + + // Check if any of the servers the current request wants to blacklist are not marked as blacklisted in the cache. + // If any server is not blacklisted in the cache, it suggests that the cache entry might be stale and could potentially + // route us back to the blacklisted server ID. In such cases, the cache needs to be refreshed, and the existing entry should be ignored. + // + // Additionally, create a new slice `nonBlacklistedCachedReferences` that only includes the references from the cache + // that belong to non-blacklisted servers. This filtered slice will be used for subsequent processing. 
+ if ok { + blacklistedIDsFromCache := aceI.(activationCacheEntry).blacklistedServerIDs + cachedReferences = aceI.(activationCacheEntry).references + + currentBlacklistedIDsAreInvalid = len(blacklistedIDsFromCache) != len(blacklistedServerIDs) + if !currentBlacklistedIDsAreInvalid { + for _, id := range blacklistedIDsFromCache { + if !isServerIDBlacklisted[id] { + currentBlacklistedIDsAreInvalid = true + break + } + } + } + + for _, ref := range cachedReferences { + if !isServerIDBlacklisted[ref.Physical.ServerID] { + nonBlacklistedCachedReferences = append(nonBlacklistedCachedReferences, ref) + } + } + } + + // Cache miss, not enough non-blacklisted replicas, or invalid blacklistedIDs list, then fill the cache. + // If there is a cache entry but it was satisfied by a request with a different blacklistedServerID, + // we must ignore the entry to avoid routing to a potentially stale blacklisted server. + // By forcing a cache update, we prevent routing to the blacklisted server ID and ensure fresh data. + if !ok || (1+extraReplicas) > uint64(len(nonBlacklistedCachedReferences)) || // There is an existing cache entry, however, it was satisfied by a request that did not provide // the same blacklistedServerID we have currently. We must ignore this entry because it could be // stale and end up routing us back to the blacklisted server ID. - (blacklistedServerID != "" && aceI.(activationCacheEntry).blacklistedServerID != blacklistedServerID) { - var cachedReferences []types.ActorReference - if ok { - cachedReferences = aceI.(activationCacheEntry).references - } + currentBlacklistedIDsAreInvalid { + // Force cache update and ignore the existing entry to prevent routing to blacklisted server ID. 
return a.ensureActivationAndUpdateCache( - ctx, namespace, moduleID, actorID, cachedReferences, blacklistedServerID) + ctx, namespace, moduleID, actorID, extraReplicas, cachedReferences, isServerIDBlacklisted, blacklistedServerIDs) } // Cache hit, return result from cache but check if we should proactively refresh // the cache also. - ace := aceI.(activationCacheEntry) // TODO: Jitter here. if time.Since(ace.cachedAt) > a.idealCacheStaleness { @@ -126,7 +163,7 @@ func (a *activationsCache) ensureActivation( go func() { defer cc() _, err := a.ensureActivationAndUpdateCache( - ctx, namespace, moduleID, actorID, ace.references, blacklistedServerID) + ctx, namespace, moduleID, actorID, extraReplicas, ace.references, isServerIDBlacklisted, blacklistedServerIDs) if err != nil { a.logger.Error( "error refreshing activation cache in background", @@ -135,7 +172,14 @@ func (a *activationsCache) ensureActivation( }() } - return ace.references, nil + return limit(nonBlacklistedCachedReferences, 1+extraReplicas), nil +} + +func limit(slice []types.ActorReference, min uint64) []types.ActorReference { + if len(slice) > int(min) { + return slice[:min] + } + return slice } func (a *activationsCache) delete( @@ -156,8 +200,10 @@ func (a *activationsCache) ensureActivationAndUpdateCache( moduleID, actorID string, + extraReplicas uint64, cachedReferences []types.ActorReference, - blacklistedServerID string, + isServerIDBlacklisted map[string]bool, + blacklistedServerIDs []string, ) ([]types.ActorReference, error) { // Since this method is less common (cache miss) we just allocate instead of messing // around with unsafe object pooling. @@ -166,11 +212,11 @@ func (a *activationsCache) ensureActivationAndUpdateCache( // Include blacklistedServerID in the dedupeKey so that "force refreshes" due to a // server blacklist / load-shedding an actor can be initiated *after* a regular // refresh has already started, but *before* it has completed. 
- dedupeKey := fmt.Sprintf("%s::%s", cacheKey, blacklistedServerID) + dedupeKey := fmt.Sprintf("%s::%s", cacheKey, strings.Join(blacklistedServerIDs, ",")) referencesI, err, _ := a.deduper.Do(dedupeKey, func() (any, error) { var cachedServerIDs []string for _, ref := range cachedReferences { - cachedServerIDs = append(cachedServerIDs, ref.ServerID()) + cachedServerIDs = append(cachedServerIDs, ref.Physical.ServerID) } // Acquire the semaphore before making the network call to avoid DDOSing the @@ -185,7 +231,8 @@ func (a *activationsCache) ensureActivationAndUpdateCache( ModuleID: moduleID, ActorID: actorID, - BlacklistedServerID: blacklistedServerID, + ExtraReplicas: extraReplicas, + BlacklistedServerIDs: blacklistedServerIDs, CachedActivationServerIDs: cachedServerIDs, }) // Release the semaphore as soon as we're done with the network call since the purpose @@ -209,10 +256,10 @@ func (a *activationsCache) ensureActivationAndUpdateCache( } for _, ref := range references.References { - if ref.ServerID() == blacklistedServerID { + if isServerIDBlacklisted[ref.Physical.ServerID] { return nil, fmt.Errorf( "[invariant violated] registry returned blacklisted server ID: %s in references", - blacklistedServerID) + ref.Physical.ServerID) } } @@ -225,7 +272,7 @@ func (a *activationsCache) ensureActivationAndUpdateCache( references: references.References, cachedAt: time.Now(), registryVersionStamp: references.VersionStamp, - blacklistedServerID: blacklistedServerID, + blacklistedServerIDs: blacklistedServerIDs, } // a.c is internally synchronized, but we use a lock here so we can do an atomic @@ -264,5 +311,5 @@ type activationCacheEntry struct { references []types.ActorReference cachedAt time.Time registryVersionStamp int64 - blacklistedServerID string + blacklistedServerIDs []string } diff --git a/virtual/activations.go b/virtual/activations.go index fd13f39..9b99138 100644 --- a/virtual/activations.go +++ b/virtual/activations.go @@ -118,13 +118,13 @@ func (a 
*activations) invoke( invokePayload []byte, isTimer bool, ) (io.ReadCloser, error) { - if err := a.isBlacklisted(reference); err != nil { + if err := a.isServerIDBlacklisted(reference); err != nil { return nil, err } // First check if the actor is already activated. a.Lock() - actorF, ok := a._actors[reference.ActorID()] + actorF, ok := a._actors[reference.ActorIDWithNamespace()] if !ok { if isTimer { // Timers should invoke already activated actors, but not instantiate them @@ -172,7 +172,7 @@ func (a *activations) invoke( // The actor is/was activated without error, but we still need to check the generation // count before we're allowed to invoke it. - if actor.reference().Generation() >= reference.Generation() { + if actor.reference().Generation >= reference.Generation { // The activated actor's generation count is high enough, we can just invoke now. return a.invokeActivatedActor(ctx, actor, operation, invokePayload) } @@ -184,7 +184,7 @@ func (a *activations) invoke( // Next, we check if the actor is still in the map (since we released and re-acquired // the lock, anything could have happened in the meantime). - actorF2, ok := a._actors[reference.ActorID()] + actorF2, ok := a._actors[reference.ActorIDWithNamespace()] if !ok { // Actor is no longer in the map. We can just proceed with a normal activation then. 
return a.invokeNotExistWithLock( @@ -205,7 +205,7 @@ func (a *activations) invoke( // The future has changed, the generation count should be high enough now and // we can just ignore the old actor (whichever Goroutine increased the generation // count will have closed it already) - if actor.reference().Generation() >= reference.Generation() { + if actor.reference().Generation >= reference.Generation { return a.invokeActivatedActor(ctx, actor, operation, invokePayload) } @@ -223,7 +223,7 @@ func (a *activations) invokeNotExistWithLock( prevActor *activatedActor, ) (io.ReadCloser, error) { fut := futures.New[*activatedActor]() - a._actors[reference.ActorID()] = fut + a._actors[reference.ActorIDWithNamespace()] = fut a.Unlock() // GoSync since this goroutine needs to wait anyways. @@ -232,7 +232,7 @@ func (a *activations) invokeNotExistWithLock( if err := prevActor.close(ctx); err != nil { a.log.Error("error closing previous instance of actor", slog.Any("actor", reference), slog.Any("error", err)) } - a._actorResourceTracker.track(reference.ActorID(), 0) + a._actorResourceTracker.track(reference.ActorIDWithNamespace(), 0) } defer func() { @@ -241,11 +241,11 @@ func (a *activations) invokeNotExistWithLock( // the future gets cleared from the map so that subsequent // invocations will try to recreate the actor instead of // receiving the same hard-coded over and over again. 
- delete(a._actors, reference.ActorID()) + delete(a._actors, reference.ActorIDWithNamespace()) } }() - module, err := a.ensureModule(ctx, reference.ModuleID()) + module, err := a.ensureModule(ctx, reference.ModuleIDWithNamespace()) if err != nil { return nil, fmt.Errorf( "error ensuring module for reference: %v, err: %w", @@ -258,19 +258,19 @@ func (a *activations) invokeNotExistWithLock( if err != nil { return nil, fmt.Errorf( "error instantiating actor: %s from module: %s, err: %w", - reference.ActorID(), reference.ModuleID(), err) + reference.ActorID, reference.ModuleID, err) } if err := assertActorIface(iActor); err != nil { return nil, fmt.Errorf( "error instantiating actor: %s from module: %s, err: %w", - reference.ActorID(), reference.ModuleID(), err) + reference.ActorID, reference.ModuleID, err) } onGc := func() { a.Lock() defer a.Unlock() - existing, ok := a._actors[reference.ActorID()] + existing, ok := a._actors[reference.ActorIDWithNamespace()] if !ok { // Actor has already been removed from the map, nothing else to do. return @@ -285,8 +285,8 @@ func (a *activations) invokeNotExistWithLock( // The actor is in the map and the future pointers match so we know its the same // instance of the actor that created this onGc function so we should remove it. 
- delete(a._actors, reference.ActorID()) - a._actorResourceTracker.track(reference.ActorID(), 0) + delete(a._actors, reference.ActorIDWithNamespace()) + a._actorResourceTracker.track(reference.ActorIDWithNamespace(), 0) } var currMemUsage int @@ -295,7 +295,7 @@ func (a *activations) invokeNotExistWithLock( if err != nil { return nil, fmt.Errorf("error activating actor: %w", err) } - a._actorResourceTracker.track(reference.ActorID(), currMemUsage) + a._actorResourceTracker.track(reference.ActorIDWithNamespace(), currMemUsage) return actor, nil }) @@ -318,7 +318,9 @@ func (a *activations) invokeActivatedActor( if err != nil { return nil, err } - a._actorResourceTracker.track(actor.reference().ActorID(), currMemUsage) + + ref := actor.reference() + a._actorResourceTracker.track(ref.ActorIDWithNamespace(), currMemUsage) return stream, nil } @@ -607,19 +609,19 @@ func (a *activations) close(ctx context.Context, numWorkers int) error { return nil } -func (a *activations) isBlacklisted( +func (a *activations) isServerIDBlacklisted( reference types.ActorReferenceVirtual, ) error { bufIface, cacheKey := actorCacheKeyUnsafePooled( - reference.Namespace(), reference.ModuleID().ID, reference.ActorID().ID) + reference.Namespace, reference.ModuleID, reference.ActorID) _, ok := a._blacklist.Get(cacheKey) // Immediately return to the pool cause we're done with it now regardless. 
bufPool.Put(bufIface) if ok { err := fmt.Errorf( - "actor %s is blacklisted on this server", reference.ActorID()) + "actor %s is blacklisted on this server", reference.ActorID) serverID, _ := a.getServerState() - return NewBlacklistedActivationError(err, serverID) + return NewBlacklistedActivationError(err, []string{serverID}) } return nil @@ -711,7 +713,7 @@ func (a *activatedActor) invoke( if a._closed { return 0, nil, fmt.Errorf( - "tried to invoke actor: %s which has already been closed", a._reference.ActorID()) + "tried to invoke actor: %s which has already been closed", a._reference.ActorID) } // Set a._lastInvoke to now so that if the timer function runs after we release the lock it will diff --git a/virtual/client.go b/virtual/client.go index 7e1835b..5b764eb 100644 --- a/virtual/client.go +++ b/virtual/client.go @@ -28,12 +28,12 @@ func (h *httpClient) InvokeActorRemote( ) (io.ReadCloser, error) { ir := invokeActorDirectRequest{ VersionStamp: versionStamp, - ServerID: reference.ServerID(), - ServerVersion: reference.ServerVersion(), - Namespace: reference.Namespace(), - ModuleID: reference.ModuleID().ID, - ActorID: reference.ActorID().ID, - Generation: reference.Generation(), + ServerID: reference.Physical.ServerID, + ServerVersion: reference.Physical.ServerVersion, + Namespace: reference.Virtual.Namespace, + ModuleID: reference.Virtual.ModuleID, + ActorID: reference.Virtual.ActorID, + Generation: reference.Virtual.Generation, Operation: operation, Payload: payload, CreateIfNotExist: create, @@ -45,7 +45,7 @@ func (h *httpClient) InvokeActorRemote( req, err := http.NewRequestWithContext( ctx, "POST", - fmt.Sprintf("http://%s/api/v1/invoke-actor-direct", reference.Address()), + fmt.Sprintf("http://%s/api/v1/invoke-actor-direct", reference.Physical.ServerState.Address), bytes.NewReader(marshaled)) if err != nil { return nil, fmt.Errorf("HTTPClient: InvokeDirect: error constructing request: %w", err) @@ -69,7 +69,7 @@ func (h *httpClient) InvokeActorRemote( 
// in statusCodeToErrorWrapper will be converted back to the proper in memory // error type if sent by a server to a client. if wrapper, ok := statusCodeToErrorWrapper[resp.StatusCode]; ok { - err = wrapper(err, reference.ServerID()) + err = wrapper(err, []string{reference.Physical.ServerID}) } return nil, err } @@ -125,5 +125,5 @@ func (n *noopClient) InvokeActorRemote( create types.CreateIfNotExist, ) (io.ReadCloser, error) { return nil, fmt.Errorf( - "noopClient: tried to invoke actor(%s) remotely using noop client. Instantiate Environment with a real client instead", reference.ActorID()) + "noopClient: tried to invoke actor(%s) remotely using noop client. Instantiate Environment with a real client instead", reference.Virtual.ActorID) } diff --git a/virtual/environment.go b/virtual/environment.go index 16b7dcb..e5e7472 100644 --- a/virtual/environment.go +++ b/virtual/environment.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "io/ioutil" + "math/rand" "net" "runtime" "sync" @@ -34,6 +35,9 @@ var ( // Var so can be modified by tests. defaultActivationsCacheTTL = heartbeatTimeout DefaultGCActorsAfterDurationWithNoInvocations = time.Minute + + randEnv = rand.New(rand.NewSource(time.Now().UnixNano())) + muRand = sync.Mutex{} ) type environment struct { @@ -411,7 +415,7 @@ func (r *environment) InvokeActorStream( create types.CreateIfNotExist, ) (io.ReadCloser, error) { resp, err := r.invokeActorStreamHelper( - ctx, namespace, actorID, moduleID, operation, payload, create, "") + ctx, namespace, actorID, moduleID, operation, payload, create, nil) if err == nil { return resp, nil } @@ -424,15 +428,15 @@ func (r *environment) InvokeActorStream( // that actor to ensure we get an activation on a different serer since the registry // may not know about the blacklist yet. 
r.activationsCache.delete(namespace, moduleID, actorID) - blacklistedServerID := err.(BlacklistedActivationErr).ServerID() + blacklistedServerIDs := err.(BlacklistedActivationErr).ServerIDs() r.log.Warn( "encountered blacklisted actor, forcing activation cache refresh and retrying", slog.String("actor_id", fmt.Sprintf("%s::%s::%s", namespace, moduleID, actorID)), - slog.String("blacklisted_server_id", blacklistedServerID)) + slog.Any("blacklisted_server_ids", blacklistedServerIDs)) return r.invokeActorStreamHelper( - ctx, namespace, actorID, moduleID, operation, payload, create, blacklistedServerID) + ctx, namespace, actorID, moduleID, operation, payload, create, blacklistedServerIDs) } return nil, err @@ -446,7 +450,7 @@ func (r *environment) invokeActorStreamHelper( operation string, payload []byte, create types.CreateIfNotExist, - blacklistedServerID string, + blacklistedServerIDs []string, ) (io.ReadCloser, error) { if r.isClosed() { return nil, ErrEnvironmentClosed @@ -468,7 +472,7 @@ func (r *environment) invokeActorStreamHelper( } references, err := r.activationsCache.ensureActivation( - ctx, namespace, moduleID, actorID, blacklistedServerID) + ctx, namespace, moduleID, actorID, create.Options.ExtraReplicas, blacklistedServerIDs) if err != nil { return nil, fmt.Errorf("error ensuring actor activation: %w", err) } @@ -764,18 +768,20 @@ func (r *environment) invokeReferences( payload []byte, create types.CreateIfNotExist, ) (io.ReadCloser, error) { - // TODO: Load balancing or some other strategy if the number of references is > 1? - ref := references[0] + ref, ok := pickServerForInvocation(references, create) + if !ok { + return nil, errors.New("failed to pick server") + } if !r.opts.ForceRemoteProcedureCalls { // First check the global localEnvironmentsRouter map for scenarios where we're // potentially trying to communicate between multiple different in-memory // instances of Environment. 
localEnvironmentsRouterLock.RLock() - localEnv, ok := localEnvironmentsRouter[ref.Address()] + localEnv, ok := localEnvironmentsRouter[ref.Physical.ServerState.Address] localEnvironmentsRouterLock.RUnlock() if ok { return localEnv.InvokeActorDirectStream( - ctx, versionStamp, ref.ServerID(), ref.ServerVersion(), ref, + ctx, versionStamp, ref.Physical.ServerID, ref.Physical.ServerVersion, ref.Virtual, operation, payload, create) } @@ -795,9 +801,9 @@ func (r *environment) invokeReferences( // always return dnsregistry.Localhost as the address for all actor references and // thus ensure that tests can be written without having to also ensure that a NOLA // server is running on the appropriate port, among other things. - if ref.Address() == Localhost || ref.Address() == dnsregistry.Localhost { + if ref.Physical.ServerState.Address == Localhost || ref.Physical.ServerState.Address == dnsregistry.Localhost { return localEnv.InvokeActorDirectStream( - ctx, versionStamp, ref.ServerID(), ref.ServerVersion(), ref, + ctx, versionStamp, ref.Physical.ServerID, ref.Physical.ServerVersion, ref.Virtual, operation, payload, create) } } @@ -899,3 +905,14 @@ func formatActorCacheKey( dst = append(dst, []byte(actorID)...) return dst } + +func pickServerForInvocation(references []types.ActorReference, create types.CreateIfNotExist) (types.ActorReference, bool) { + // TODO: implement invokation strategies in 'create' e.g. region-based, multi-invoke... + if len(references) == 0 { + return types.ActorReference{}, false + } + + muRand.Lock() + defer muRand.Unlock() + return references[randEnv.Intn(len(references))], true +} diff --git a/virtual/environment_test.go b/virtual/environment_test.go index b629d56..d7e203f 100644 --- a/virtual/environment_test.go +++ b/virtual/environment_test.go @@ -696,6 +696,104 @@ func testHeartbeatAndRebalancingWithMemory( } } +// TestReplicationRandomGoModule tests the random replication logic of the environment. 
+// This test specifically examines how actors are replicated when the ExtraReplicas option is set to a value greater than 0. +// The test verifies that the actor is replicated across multiple environments in a random manner. +func TestReplicationRandomGoModule(t *testing.T) { + var ( + reg = localregistry.NewLocalRegistryWithOptions(registry.KVRegistryOptions{ + RebalanceMemoryThreshold: 1 << 24, + }) + moduleStore = newTestModuleStore() + ctx = context.Background() + ) + _, err := moduleStore.RegisterModule(ctx, "ns-1", "test-module", utilWasmBytes, registry.ModuleOptions{}) + require.NoError(t, err) + + // Create 3 environments backed by the same registry to simulate 3 different servers. Each environment + // needs its own port so it looks unique. + opts1 := defaultOptsWASM + opts1.Discovery.Port = 1 + env1, err := NewEnvironment(ctx, "serverID1", reg, moduleStore, nil, opts1) + require.NoError(t, err) + defer env1.Close(context.Background()) + + opts2 := defaultOptsWASM + opts2.Discovery.Port = 2 + env2, err := NewEnvironment(ctx, "serverID2", reg, moduleStore, nil, opts2) + require.NoError(t, err) + defer env2.Close(context.Background()) + + opts3 := defaultOptsWASM + opts3.Discovery.Port = 3 + env3, err := NewEnvironment(ctx, "serverID3", reg, moduleStore, nil, opts3) + require.NoError(t, err) + defer env3.Close(context.Background()) + + testReplicationRandom(t, env1, env2, env3) +} + +func TestReplicationRandomWASMModule(t *testing.T) { + var ( + reg = localregistry.NewLocalRegistryWithOptions(registry.KVRegistryOptions{ + RebalanceMemoryThreshold: 1 << 24, + }) + moduleStore = newTestModuleStore() + ctx = context.Background() + ) + _, err := moduleStore.RegisterModule(ctx, "ns-1", "test-module", utilWasmBytes, registry.ModuleOptions{}) + require.NoError(t, err) + + // Create 3 environments backed by the same registry to simulate 3 different servers. Each environment + // needs its own port so it looks unique. 
+ opts1 := defaultOptsWASM + opts1.Discovery.Port = 1 + env1, err := NewEnvironment(ctx, "serverID1", reg, moduleStore, nil, opts1) + require.NoError(t, err) + defer env1.Close(context.Background()) + + opts2 := defaultOptsWASM + opts2.Discovery.Port = 2 + env2, err := NewEnvironment(ctx, "serverID2", reg, moduleStore, nil, opts2) + require.NoError(t, err) + defer env2.Close(context.Background()) + + opts3 := defaultOptsWASM + opts3.Discovery.Port = 3 + env3, err := NewEnvironment(ctx, "serverID3", reg, moduleStore, nil, opts3) + require.NoError(t, err) + defer env3.Close(context.Background()) + + testReplicationRandom(t, env1, env2, env3) +} + +// testReplicationRandom is a test function that verifies the random replication of actors across multiple environments. +// +// The test logic is as follows: +// - Invoke an actor with the ExtraReplicas option set to 2. +// - Continuously check if the actor has been replicated in all three environments. +// - If the actor is not activated in all three environments, invoke the actor again. +// - Repeat the check until the actor is activated in all three environments or until a certain timeout is reached. +// - If the actor is not activated in all three environments within the specified time, the test fails. 
+func testReplicationRandom( + t *testing.T, + env1, env2, env3 Environment, +) { + ctx := context.Background() + + require.Eventually(t, func() bool { + numActivatedActors := env1.NumActivatedActors() + env2.NumActivatedActors() + env3.NumActivatedActors() + if numActivatedActors == 3 { + return true + } + + _, err := env1.InvokeActor(ctx, "ns-1", "actor-0", "test-module", "inc", nil, types.CreateIfNotExist{Options: types.ActorOptions{ExtraReplicas: 2}}) + require.NoError(t, err) + + return false + }, time.Minute, time.Microsecond, "actor is not replicated") +} + // TestVersionStampIsHonored ensures that the interaction between the client and server // around versionstamp coordination works by preventing the server from updating its // internal versionstamp and ensuring that eventually RPCs start to fail because the diff --git a/virtual/errs.go b/virtual/errs.go index 66ba86b..48853de 100644 --- a/virtual/errs.go +++ b/virtual/errs.go @@ -7,12 +7,12 @@ import ( ) var ( - statusCodeToErrorWrapper = map[int]func(err error, serverID string) error{ + statusCodeToErrorWrapper = map[int]func(err error, serverID []string) error{ 410: NewBlacklistedActivationError, } // Make sure it implements interface. - _ HTTPError = NewBlacklistedActivationError(errors.New("n/a"), "n/a").(HTTPError) + _ HTTPError = NewBlacklistedActivationError(errors.New("n/a"), []string{"n/a"}).(HTTPError) ) // HTTPError is the interface implemented by errors that map to a specific @@ -28,22 +28,22 @@ type HTTPError interface { // blacklisted on this specific server temporarily (usually due to resource // usage or balancing reasons). type BlacklistedActivationErr struct { - err error - serverID string + err error + serverIDs []string } // NewBlacklistedActivationError creates a new BlacklistedActivationErr. 
-func NewBlacklistedActivationError(err error, serverID string) error { - if serverID == "" { +func NewBlacklistedActivationError(err error, serverIDs []string) error { + if len(serverIDs) <= 0 { panic("[invariant violated] serverID cannot be empty") } - return BlacklistedActivationErr{err: err, serverID: serverID} + return BlacklistedActivationErr{err: err, serverIDs: serverIDs} } func (b BlacklistedActivationErr) Error() string { return fmt.Sprintf( "BlacklistedActivationError(ServerID:%s): %s", - b.serverID, b.err.Error()) + b.serverIDs, b.err.Error()) } func (b BlacklistedActivationErr) Is(target error) bool { @@ -60,8 +60,8 @@ func (b BlacklistedActivationErr) HTTPStatusCode() int { return http.StatusGone } -func (b BlacklistedActivationErr) ServerID() string { - return b.serverID +func (b BlacklistedActivationErr) ServerIDs() []string { + return b.serverIDs } // IsBlacklistedActivationError returns a boolean indicating whether the error diff --git a/virtual/errs_test.go b/virtual/errs_test.go index fa23775..b7db17b 100644 --- a/virtual/errs_test.go +++ b/virtual/errs_test.go @@ -13,14 +13,14 @@ func TestBlacklistedActivationError(t *testing.T) { require.False(t, errors.Is(errors.New("random"), BlacklistedActivationErr{})) require.False(t, IsBlacklistedActivationError(errors.New("random"))) - require.True(t, errors.Is(NewBlacklistedActivationError(errors.New("random"), "abc"), &BlacklistedActivationErr{})) - require.True(t, errors.Is(NewBlacklistedActivationError(errors.New("random"), "abc"), BlacklistedActivationErr{})) - require.True(t, IsBlacklistedActivationError(NewBlacklistedActivationError(errors.New("random"), "abc"))) - require.True(t, IsBlacklistedActivationError(fmt.Errorf("wrapped: %w", NewBlacklistedActivationError(errors.New("random"), "abc")))) + require.True(t, errors.Is(NewBlacklistedActivationError(errors.New("random"), []string{"abc"}), &BlacklistedActivationErr{})) + require.True(t, 
errors.Is(NewBlacklistedActivationError(errors.New("random"), []string{"abc"}), BlacklistedActivationErr{})) + require.True(t, IsBlacklistedActivationError(NewBlacklistedActivationError(errors.New("random"), []string{"abc"}))) + require.True(t, IsBlacklistedActivationError(fmt.Errorf("wrapped: %w", NewBlacklistedActivationError(errors.New("random"), []string{"abc"})))) - require.Equal(t, "abc", NewBlacklistedActivationError(errors.New("random"), "abc").(BlacklistedActivationErr).ServerID()) + require.Contains(t, NewBlacklistedActivationError(errors.New("random"), []string{"abc"}).(BlacklistedActivationErr).ServerIDs(), "abc") var httpErr HTTPError require.True(t, errors.As( - NewBlacklistedActivationError(errors.New("random"), "abc"), &httpErr)) + NewBlacklistedActivationError(errors.New("random"), []string{"abc"}), &httpErr)) } diff --git a/virtual/host_capabilities.go b/virtual/host_capabilities.go index 1f082ef..28604cd 100644 --- a/virtual/host_capabilities.go +++ b/virtual/host_capabilities.go @@ -47,7 +47,7 @@ func (h *hostCapabilities) InvokeActor( req types.InvokeActorRequest, ) ([]byte, error) { return h.env.InvokeActor( - ctx, h.reference.Namespace(), req.ActorID, req.ModuleID, + ctx, h.reference.Namespace, req.ActorID, req.ModuleID, req.Operation, req.Payload, req.CreateIfNotExist) } @@ -98,5 +98,5 @@ func (h *hostCapabilities) CustomFn( } return nil, fmt.Errorf( "unknown host function: %s::%s::%s", - h.reference.Namespace(), operation, payload) + h.reference.Namespace, operation, payload) } diff --git a/virtual/registry/dnsregistry/dns_registry.go b/virtual/registry/dnsregistry/dns_registry.go index 807234f..185e382 100644 --- a/virtual/registry/dnsregistry/dns_registry.go +++ b/virtual/registry/dnsregistry/dns_registry.go @@ -133,8 +133,8 @@ func (d *dnsRegistry) EnsureActivation( serverIP := ring.Get(fmt.Sprintf("%s::%s", req.ActorID, req.ModuleID)) ref, err := types.NewActorReference( - DNSServerID, DNSServerVersion, serverIP, req.Namespace, - 
req.ModuleID, req.ActorID, DNS_ACTOR_GENERATION) + DNSServerID, DNSServerVersion, req.Namespace, + req.ModuleID, req.ActorID, DNS_ACTOR_GENERATION, types.ServerState{Address: serverIP}) if err != nil { return registry.EnsureActivationResult{}, fmt.Errorf( "error creating actor reference: %w", err) diff --git a/virtual/registry/dnsregistry/dns_registry_test.go b/virtual/registry/dnsregistry/dns_registry_test.go index 3d10677..83ef448 100644 --- a/virtual/registry/dnsregistry/dns_registry_test.go +++ b/virtual/registry/dnsregistry/dns_registry_test.go @@ -62,11 +62,11 @@ func TestDNSRegistrySimple(t *testing.T) { } require.Equal(t, 1, len(activations.References)) - require.Equal(t, "a", activations.References[0].ActorID().ID) - require.Equal(t, "test-module", activations.References[0].ModuleID().ID) - require.Equal(t, "127.0.0.3:9090", activations.References[0].Address()) - require.Equal(t, DNSServerID, activations.References[0].ServerID()) - require.Equal(t, DNSServerVersion, activations.References[0].ServerVersion()) + require.Equal(t, "a", activations.References[0].Virtual.ActorID) + require.Equal(t, "test-module", activations.References[0].Virtual.ModuleID) + require.Equal(t, "127.0.0.3:9090", activations.References[0].Physical.ServerState.Address) + require.Equal(t, DNSServerID, activations.References[0].Physical.ServerID) + require.Equal(t, DNSServerVersion, activations.References[0].Physical.ServerVersion) break } } @@ -94,9 +94,9 @@ func TestDNSRegistrySingleNode(t *testing.T) { require.NoError(t, err) require.Equal(t, 1, len(activations.References)) - require.Equal(t, "a", activations.References[0].ActorID().ID) - require.Equal(t, "test-module", activations.References[0].ModuleID().ID) - require.Equal(t, "127.0.0.1:9090", activations.References[0].Address()) - require.Equal(t, DNSServerID, activations.References[0].ServerID()) - require.Equal(t, DNSServerVersion, activations.References[0].ServerVersion()) + require.Equal(t, "a", 
activations.References[0].Virtual.ActorID) + require.Equal(t, "test-module", activations.References[0].Virtual.ModuleID) + require.Equal(t, "127.0.0.1:9090", activations.References[0].Physical.ServerState.Address) + require.Equal(t, DNSServerID, activations.References[0].Physical.ServerID) + require.Equal(t, DNSServerVersion, activations.References[0].Physical.ServerVersion) } diff --git a/virtual/registry/kv_registry.go b/virtual/registry/kv_registry.go index 063d010..6e33209 100644 --- a/virtual/registry/kv_registry.go +++ b/virtual/registry/kv_registry.go @@ -221,150 +221,117 @@ func (k *kvRegistry) EnsureActivation( req EnsureActivationRequest, ) (EnsureActivationResult, error) { actorKey := getActorKey(req.Namespace, req.ActorID, req.ModuleID) - references, err := k.kv.Transact(func(tr kv.Transaction) (any, error) { - ra, ok, err := k.getActor(ctx, tr, actorKey) - if err == nil && !ok { - _, err := k.createActor( - ctx, tr, req.Namespace, req.ActorID, req.ModuleID, types.ActorOptions{}) - if err != nil { - return nil, fmt.Errorf("EnsureActivation: error creating actor: %w", err) - } - ra, ok, err = k.getActor(ctx, tr, actorKey) - if err != nil { - return nil, fmt.Errorf("EnsureActivation: error getting actor: %w", err) - } - if !ok { - return nil, fmt.Errorf( - "[invariant violated] error ensuring activation of actor with ID: %s, does not exist in namespace: %s, err: %w", - req.ActorID, req.Namespace, errActorDoesNotExist) - } - } - if err != nil { - return nil, fmt.Errorf("EnsureActivation: error getting actor: %w", err) - } - if !ok { - // Make sure we use %w to wrap the errActorDoesNotExist so the caller can use - // errors.Is() on it. 
- return nil, fmt.Errorf( - "[invariant violated] error ensuring activation of actor with ID: %s, does not exist in namespace: %s, err: %w", - req.ActorID, req.Namespace, errActorDoesNotExist) - } - serverKey := getServerKey(ra.Activation.ServerID) - v, ok, err := tr.Get(ctx, serverKey) + // Perform a transaction to ensure atomicity. + references, err := k.kv.Transact(func(tr kv.Transaction) (any, error) { + // First, check if the actor exists already, and if not create it. + ra, err := k.getOrCreateActor(ctx, req, actorKey, tr) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get/create actor: %w", err) } - var ( - server serverState - serverExists bool - ) - if ok { - if err := json.Unmarshal(v, &server); err != nil { - return nil, fmt.Errorf("error unmarsaling server state with ID: %s", req.ActorID) - } - serverExists = true - } + // Next we try to get unblacklisted servers where the actor is currently running. + // We don't activate the actor on a new server unless there are not enough live replicas. + // Get the version stamp for validations of Heartbeat TTLs. vs, err := tr.GetVersionStamp() if err != nil { return nil, fmt.Errorf("error getting versionstamp: %w", err) } + // Convert blacklisted server IDs to a set for efficient lookup. + isServerIDBlacklisted := types.StringSliceToSet(req.BlacklistedServerIDs) - var ( - currActivation, activationExists = ra.Activation, ra.Activation.ServerID != "" - timeSinceLastHeartbeat = versionSince(vs, server.LastHeartbeatedAt) - serverID string - serverAddress string - serverVersion int64 - ) - if activationExists && - serverExists && - timeSinceLastHeartbeat < HeartbeatTTL && - currActivation.ServerID != req.BlacklistedServerID { - // We have an existing activation and the server is still alive, so just use that. 
- - // It is acceptable to look up the ServerVersion from the server discovery key directly, - // as long as the activation is still active, it guarantees that the server's version - // has not changed since the activation was first created. - serverVersion = server.ServerVersion - serverID = currActivation.ServerID - serverAddress = server.HeartbeatState.Address - } else { - // We need to create a new activation because either: - // 1. There is no activation OR - // 2. The server the actor is currently activated on has stopped heartbeating OR - // 3. The server the actor is currently activated on has blacklisted this actor (most - // likely for balancing reasons) - liveServers, err := getLiveServers(ctx, vs, tr) - if err != nil { - return nil, err - } - if len(liveServers) == 0 { - return nil, fmt.Errorf("0 live servers available for new activation") - } + // Get servers from currently running actor activations + refs, activations, err := k.getExistingUnblacklistedActivations(ctx, tr, req, isServerIDBlacklisted, vs, ra) + if err != nil { + return nil, fmt.Errorf("failed getting existing unblacklisted references from the kv store: %w", err) + } - maxNumHeartbeats := 0 - for _, s := range liveServers { - if s.NumHeartbeats > maxNumHeartbeats { - maxNumHeartbeats = s.NumHeartbeats - } - } - if maxNumHeartbeats < k.opts.MinSuccessiveHeartbeatsBeforeAllowActivations { - return nil, fmt.Errorf( - "maxNumHeartbeats: %d < MinSuccessiveHeartbeatsBeforeAllowActivations(%d)", - maxNumHeartbeats, k.opts.MinSuccessiveHeartbeatsBeforeAllowActivations) - } + // We reset activations because the activations slice may have been filtered to + // exclude activations associated with blacklisted server IDs. + ra.Activations = activations - // TODO: Update this code once we support configurable replication. 
- var cachedServerID string - if len(req.CachedActivationServerIDs) > 0 { - cachedServerID = req.CachedActivationServerIDs[0] - } + // If we already have enough replicas, return the references. + if uint64(len(refs)) >= 1+req.ExtraReplicas { + return EnsureActivationResult{ + References: refs, + VersionStamp: vs, + }, nil + } - selected, selectionReason := pickServerForActivation( - liveServers, k.opts, req.BlacklistedServerID, cachedServerID, !activationExists) - serverID = selected.ServerID - serverAddress = selected.HeartbeatState.Address - serverVersion = selected.ServerVersion - currActivation = newActivation(serverID, serverVersion) + // We need to create a new activation because we don't have the desired number of replicas. + // This can happen in the following scenarios: + // 1. There is no existing activation for the actor. + // 2. One or more of the servers where the actor is currently activated has stopped heartbeating. + // 3. One or more of the servers where the actor is currently activated has blacklisted the actor, typically for load balancing purposes. - ra.Activation = currActivation - marshaled, err := json.Marshal(&ra) - if err != nil { - return nil, fmt.Errorf("error marshaling activation: %w", err) - } + // First, get the list of all live servers to decide where the new replicas should be activated. + liveServers, err := getLiveServers(ctx, vs, tr) + if err != nil { + return nil, fmt.Errorf("failed to get live servers: %w", err) + } + if len(liveServers) == 0 { + return nil, fmt.Errorf("0 live servers available for new activation") + } - tr.Put(ctx, actorKey, marshaled) + // Then, we find the maximum number of heartbeats among live servers, to ensure there are no stale entries. 
+ maxNumHeartbeats := findMaxNumHeartbeats(liveServers) - if !k.opts.DisableHighConflictOperations { - selected.HeartbeatState.NumActivatedActors++ - marshaled, err := json.Marshal(&selected) - if err != nil { - return nil, fmt.Errorf("error marshaling server state: %w", err) - } + // Check if the maximum number of heartbeats satisfies the required threshold. + if maxNumHeartbeats < k.opts.MinSuccessiveHeartbeatsBeforeAllowActivations { + return nil, fmt.Errorf( + "maxNumHeartbeats: %d < MinSuccessiveHeartbeatsBeforeAllowActivations(%d)", + maxNumHeartbeats, k.opts.MinSuccessiveHeartbeatsBeforeAllowActivations) + } + + // We create a set with the unblacklisted existing servers, + // to avoid filling the selection with servers the actor is already activated on. + isActivatedOnServer := make(map[string]bool, len(activations)) + for _, ref := range refs { + isActivatedOnServer[ref.Physical.ServerID] = true + } + // Pick the remaining servers needed to comply with the replication criteria. + selected, selectionReason := pickServersForActivation( + (1+req.ExtraReplicas)-uint64(len(refs)), + liveServers, + k.opts, + isServerIDBlacklisted, + req.CachedActivationServerIDs, + isActivatedOnServer, + ) - tr.Put(ctx, getServerKey(selected.ServerID), marshaled) + // For every selected server, update the required information to reflect the activation, + // create a reference for it, and add it to the 'refs' result. 
+ for _, server := range selected { + if err := k.activateActor(ctx, tr, server, &ra); err != nil { + return nil, fmt.Errorf("failed activating actor: %w", err) } - k.opts.Logger.Info( "activated actor on server", slog.String("actor_id", fmt.Sprintf("%s::%s:%s", req.Namespace, req.ModuleID, req.ActorID)), - slog.String("server_id", selected.ServerID), - slog.String("server_address", selected.HeartbeatState.Address), + slog.String("server_id", server.ServerID), + slog.String("server_address", server.HeartbeatState.Address), slog.String("selection_reason", selectionReason), ) + + ref, err := types.NewActorReference( + server.ServerID, server.ServerVersion, req.Namespace, ra.ModuleID, req.ActorID, ra.Generation, types.ServerState{Address: server.HeartbeatState.Address}) + if err != nil { + return nil, fmt.Errorf("error creating new actor reference: %w", err) + } + + refs = append(refs, ref) } - ref, err := types.NewActorReference( - serverID, serverVersion, serverAddress, req.Namespace, ra.ModuleID, req.ActorID, ra.Generation) + // Store the newly updated actor, to reflect the latest changes of its activations. + marshaled, err := json.Marshal(&ra) if err != nil { - return nil, fmt.Errorf("error creating new actor reference: %w", err) + return nil, fmt.Errorf("error marshaling activation: %w", err) } + tr.Put(ctx, actorKey, marshaled) return EnsureActivationResult{ - References: []types.ActorReference{ref}, + References: refs, VersionStamp: vs, }, nil }) @@ -375,6 +342,141 @@ func (k *kvRegistry) EnsureActivation( return references.(EnsureActivationResult), nil } +// getOrCreateActor retrieves an existing actor from the registry or creates a new one if it doesn't exist. +// After creating the actor, it attempts to get the actor again to ensure its existence. 
+func (k *kvRegistry) getOrCreateActor( + ctx context.Context, + req EnsureActivationRequest, + actorKey []byte, + tr kv.Transaction, +) (registeredActor, error) { + ra, ok, err := k.getActor(ctx, tr, actorKey) + if err == nil && !ok { + _, err := k.createActor( + ctx, tr, req.Namespace, req.ActorID, req.ModuleID, types.ActorOptions{}) + if err != nil { + return registeredActor{}, fmt.Errorf("EnsureActivation: error creating actor: %w", err) + } + ra, ok, err = k.getActor(ctx, tr, actorKey) + if err != nil { + return registeredActor{}, fmt.Errorf("EnsureActivation: error getting actor: %w", err) + } + if !ok { + return registeredActor{}, fmt.Errorf( + "[invariant violated] error ensuring activation of actor with ID: %s, does not exist in namespace: %s, err: %w", + req.ActorID, req.Namespace, errActorDoesNotExist) + } + } + if err != nil { + return registeredActor{}, fmt.Errorf("EnsureActivation: error getting actor: %w", err) + } + if !ok { + // Make sure we use %w to wrap the errActorDoesNotExist so the caller can use + // errors.Is() on it. + return registeredActor{}, fmt.Errorf( + "[invariant violated] error ensuring activation of actor with ID: %s, does not exist in namespace: %s, err: %w", + req.ActorID, req.Namespace, errActorDoesNotExist) + } + + return ra, nil +} + +// getExistingUnblacklistedActivations retrieves existing unblacklisted activations for a given actor from the registry. +// It iterates through the current activations and converts them into actor references until the desired number of replicas is achieved. +// For each activation, it checks if the server is still alive and within the heartbeat TTL. 
+func (k *kvRegistry) getExistingUnblacklistedActivations( + ctx context.Context, + tr kv.Transaction, + req EnsureActivationRequest, + isServerIDBlacklisted map[string]bool, + vs int64, + ra registeredActor, +) ([]types.ActorReference, []activation, error) { + var ( + activations = ra.Activations + refs = make([]types.ActorReference, 0, len(activations)) + validActivactions = make([]activation, 0, len(activations)) + ) + // Iterate through the current activations and convert them into references until the desired number of replicas is reached. + for _, a := range activations { + if uint64(len(refs)) >= 1+req.ExtraReplicas { + break + } + + if isServerIDBlacklisted[a.ServerID] { + continue + } + + serverKey := getServerKey(a.ServerID) + v, ok, err := tr.Get(ctx, serverKey) + if err != nil { + return nil, nil, err + } + + if !ok { + // Server doesn't exist so this activation is invalid. + continue + } + + // Server exists. + + var server serverState + if err := json.Unmarshal(v, &server); err != nil { + return nil, nil, fmt.Errorf("error unmarshaling server state with ID: %s: %w", a.ServerID, err) + } + if versionSince(vs, server.LastHeartbeatedAt) > HeartbeatTTL { + // Server "exists" but has not heartbeated recently. Assume it's dead and ignore this activation. + continue + } + + // We have an existing activation and the server is still alive, so just use that. + // It is acceptable to look up the ServerVersion from the server discovery key directly, + // as long as the activation is still active, it guarantees that the server's version + // has not changed since the activation was first created. + ref, err := types.NewActorReference( + server.ServerID, server.ServerVersion, req.Namespace, ra.ModuleID, req.ActorID, ra.Generation, types.ServerState{Address: server.HeartbeatState.Address}) + if err != nil { + return nil, nil, fmt.Errorf("error creating new actor reference: %w", err) + } + + // Activation is assigned to an active, non-blacklisted server so we can use it. 
+ refs = append(refs, ref) + validActivactions = append(validActivactions, a) + } + return refs, validActivactions, nil +} + +func findMaxNumHeartbeats(servers []serverState) int { + maxNumHeartbeats := 0 + for _, s := range servers { + if s.NumHeartbeats > maxNumHeartbeats { + maxNumHeartbeats = s.NumHeartbeats + } + } + + return maxNumHeartbeats +} + +// activateActor updates the registry to indicate that a server has a newly activated actor. +// This function is responsible for updating the necessary information in the registry to reflect the activation +// of an actor on a specific server. +func (k *kvRegistry) activateActor(ctx context.Context, tr kv.Transaction, server serverState, ra *registeredActor) error { + a := newActivation(server.ServerID, server.ServerVersion) + ra.Activations = append(ra.Activations, a) + + if !k.opts.DisableHighConflictOperations { + server.HeartbeatState.NumActivatedActors++ + marshaled, err := json.Marshal(&server) + if err != nil { + return fmt.Errorf("error marshaling server state: %w", err) + } + + tr.Put(ctx, getServerKey(server.ServerID), marshaled) + } + + return nil +} + func (k *kvRegistry) GetVersionStamp( ctx context.Context, ) (int64, error) { @@ -500,7 +602,7 @@ func (k *kvRegistry) getActorBytes( actorBytes, ok, err := tr.Get(ctx, actorKey) if err != nil { return nil, false, fmt.Errorf( - "error getting actor bytes for key: %s", string(actorBytes)) + "error getting actor bytes for key '%s': %w", string(actorKey), err) } if !ok { return nil, false, nil @@ -550,10 +652,10 @@ func getServersPrefix() []byte { } type registeredActor struct { - Opts types.ActorOptions - ModuleID string - Generation uint64 - Activation activation + Opts types.ActorOptions + ModuleID string + Generation uint64 + Activations []activation } type registeredModule struct { @@ -630,24 +732,31 @@ func getLiveServers( return liveServers, nil } -// pickServerForActivation is responsible for deciding which server to activate an actor on. 
It +// pickServersForActivation is responsible for deciding which server(s) to activate an actor on. It // prioritizes activating actors on the server that currently has the lowest memory usage. However, // all else being equal, it will tiebreak by selecting the server with the lowest number of activated // actors. // // TODO: Would be nice to make this function pluggable/injectable for easier testing and to make the // system more flexible for more use cases. -func pickServerForActivation( +func pickServersForActivation( + n uint64, available []serverState, opts KVRegistryOptions, - blacklistedServerID string, - cachedServerID string, - isFirstTimeObservingActor bool, -) (serverState, string) { + isServerIDBlacklisted map[string]bool, + cachedServerIDs []string, + seen map[string]bool, +) (result []serverState, reason string) { if len(available) == 0 { panic("[invariant violated] pickServerForActivation should not be called with empty slice") } + // These variables are initialized as boolean values to indicate if the selection + // is derived from the cache (fromCache) or from heartbeat messages (fromHeartbeat). + var ( + fromCache, fromHeartbeat bool + ) + // If the caller told us which server the actor was previously activated on *and* that server // is still alive *and* that server is not the blacklisted server *and* this is the first time // this registry has seen this actor before then we "trust" the cache activation and activate @@ -656,12 +765,25 @@ func pickServerForActivation( // despite the new leader having very little state to go off of. Note that for this feature to // work properly the MinSuccessiveHeartbeatsBeforeAllowActivations option must be set to some // reasonable value (3 or 4 at least). 
- if cachedServerID != "" && cachedServerID != blacklistedServerID { - for _, s := range available { - if s.ServerID == cachedServerID { - return s, "from_client_cache" + serverCanHostActor := func(serverID string) bool { + return !isServerIDBlacklisted[serverID] && !seen[serverID] + } + + for _, cachedServerID := range cachedServerIDs { + if serverCanHostActor(cachedServerID) { + for _, server := range available { + if server.ServerID == cachedServerID { + result = append(result, server) + seen[cachedServerID] = true + fromCache = true + break + } } } + + if uint64(len(result)) >= n { + return result, selectionReason(fromCache, fromHeartbeat) + } } sort.Slice(available, func(i, j int) bool { @@ -688,11 +810,17 @@ func pickServerForActivation( return sI.HeartbeatState.NumActivatedActors < sJ.HeartbeatState.NumActivatedActors }) - selected := available[0] - if len(available) > 1 && selected.ServerID == blacklistedServerID { - selected = available[1] + for _, server := range available { + if serverCanHostActor(server.ServerID) { + result = append(result, server) + seen[server.ServerID] = true + fromHeartbeat = true + } + if uint64(len(result)) >= n { + return result, selectionReason(fromCache, fromHeartbeat) + } } - return selected, "based_on_heartbeat_state" + return result, selectionReason(fromCache, fromHeartbeat) } func minMaxMemUsage(available []serverState) (serverState, serverState) { @@ -715,3 +843,18 @@ func minMaxMemUsage(available []serverState) (serverState, serverState) { return minMemUsage, maxMemUsage } + +// selectionReason determines the reason for the server selection based on the provided flags. +// It returns a string indicating whether the selection is from cache, heartbeat, both, or none. 
+func selectionReason(fromCache bool, fromHeartbeat bool) string { + if fromCache && fromHeartbeat { + return "from_client_cache_and_heartbeat" + } + if fromCache { + return "from_client_cache" + } + if fromHeartbeat { + return "from_heartbeat" + } + return "none" +} diff --git a/virtual/registry/leaderregistry/leader_registry.go b/virtual/registry/leaderregistry/leader_registry.go index cb3d33f..2ad5482 100644 --- a/virtual/registry/leaderregistry/leader_registry.go +++ b/virtual/registry/leaderregistry/leader_registry.go @@ -293,7 +293,7 @@ func (a *leaderActor) handleEnsureActivation( activations := make([][]byte, 0, len(result.References)) for _, a := range result.References { - marshaled, err := a.MarshalJSON() + marshaled, err := json.Marshal(a) if err != nil { return nil, fmt.Errorf("error marshaling JSON for activation: %w", err) } diff --git a/virtual/registry/test_common.go b/virtual/registry/test_common.go index ce4c8ab..f770daf 100644 --- a/virtual/registry/test_common.go +++ b/virtual/registry/test_common.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/richardartoul/nola/virtual/types" "github.com/stretchr/testify/require" ) @@ -17,6 +18,14 @@ func TestAllCommon(t *testing.T, registryCtor func() Registry) { t.Run("service discovery and ensure activation", func(t *testing.T) { testRegistryServiceDiscoveryAndEnsureActivation(t, registryCtor()) }) + + t.Run("test registry replication", func(t *testing.T) { + testRegistryReplication(t, registryCtor()) + }) + + t.Run("test ensure activations persistence", func(t *testing.T) { + testEnsureActivationPersistence(t, registryCtor()) + }) } // testRegistryServiceDiscoveryAndEnsureActivation tests the combination of the @@ -60,14 +69,13 @@ func testRegistryServiceDiscoveryAndEnsureActivation(t *testing.T, registry Regi }) require.NoError(t, err) require.Equal(t, 1, len(activations.References)) - require.Equal(t, "server1", activations.References[0].ServerID()) - require.Equal(t, "server1_address", 
activations.References[0].Address()) - require.Equal(t, "ns1", activations.References[0].Namespace()) - require.Equal(t, "ns1", activations.References[0].ModuleID().Namespace) - require.Equal(t, "test-module1", activations.References[0].ModuleID().ID) - require.Equal(t, "ns1", activations.References[0].ActorID().Namespace) - require.Equal(t, "a", activations.References[0].ActorID().ID) - require.Equal(t, uint64(1), activations.References[0].Generation()) + require.Equal(t, "server1", activations.References[0].Physical.ServerID) + require.Equal(t, "server1_address", activations.References[0].Physical.ServerState.Address) + require.Equal(t, "ns1", activations.References[0].Virtual.Namespace) + require.Equal(t, "ns1", activations.References[0].Virtual.Namespace) + require.Equal(t, "test-module1", activations.References[0].Virtual.ModuleID) + require.Equal(t, "a", activations.References[0].Virtual.ActorID) + require.Equal(t, uint64(1), activations.References[0].Virtual.Generation) require.True(t, activations.VersionStamp > 0) prevVS := activations.VersionStamp @@ -78,14 +86,12 @@ func testRegistryServiceDiscoveryAndEnsureActivation(t *testing.T, registry Regi }) require.NoError(t, err) require.Equal(t, 1, len(activations.References)) - require.Equal(t, "server1", activations.References[0].ServerID()) - require.Equal(t, "server1_address", activations.References[0].Address()) - require.Equal(t, "ns1", activations.References[0].Namespace()) - require.Equal(t, "ns1", activations.References[0].ModuleID().Namespace) - require.Equal(t, "test-module1", activations.References[0].ModuleID().ID) - require.Equal(t, "ns1", activations.References[0].ActorID().Namespace) - require.Equal(t, "a", activations.References[0].ActorID().ID) - require.Equal(t, uint64(1), activations.References[0].Generation()) + require.Equal(t, "server1", activations.References[0].Physical.ServerID) + require.Equal(t, "server1_address", activations.References[0].Physical.ServerState.Address) + 
require.Equal(t, "ns1", activations.References[0].Virtual.Namespace) + require.Equal(t, "test-module1", activations.References[0].Virtual.ModuleID) + require.Equal(t, "a", activations.References[0].Virtual.ActorID) + require.Equal(t, uint64(1), activations.References[0].Virtual.Generation) require.True(t, activations.VersionStamp > prevVS) prevVS = activations.VersionStamp @@ -108,13 +114,11 @@ func testRegistryServiceDiscoveryAndEnsureActivation(t *testing.T, registry Regi }) require.NoError(t, err) require.Equal(t, 1, len(activations.References)) - require.Equal(t, "server1", activations.References[0].ServerID()) - require.Equal(t, "server1_address", activations.References[0].Address()) - require.Equal(t, "ns1", activations.References[0].Namespace()) - require.Equal(t, "ns1", activations.References[0].ModuleID().Namespace) - require.Equal(t, "test-module1", activations.References[0].ModuleID().ID) - require.Equal(t, "ns1", activations.References[0].ActorID().Namespace) - require.Equal(t, "a", activations.References[0].ActorID().ID) + require.Equal(t, "server1", activations.References[0].Physical.ServerID) + require.Equal(t, "server1_address", activations.References[0].Physical.ServerState.Address) + require.Equal(t, "ns1", activations.References[0].Virtual.Namespace) + require.Equal(t, "test-module1", activations.References[0].Virtual.ModuleID) + require.Equal(t, "a", activations.References[0].Virtual.ActorID) } // Reuse the same actor ID, but with a different module. 
The registry should consider @@ -126,13 +130,11 @@ func testRegistryServiceDiscoveryAndEnsureActivation(t *testing.T, registry Regi }) require.NoError(t, err) require.Equal(t, 1, len(activations.References)) - require.Equal(t, "server2", activations.References[0].ServerID()) - require.Equal(t, "server2_address", activations.References[0].Address()) - require.Equal(t, "ns1", activations.References[0].Namespace()) - require.Equal(t, "ns1", activations.References[0].ModuleID().Namespace) - require.Equal(t, "test-module2", activations.References[0].ModuleID().ID) - require.Equal(t, "ns1", activations.References[0].ActorID().Namespace) - require.Equal(t, "a", activations.References[0].ActorID().ID) + require.Equal(t, "server2", activations.References[0].Physical.ServerID) + require.Equal(t, "server2_address", activations.References[0].Physical.ServerState.Address) + require.Equal(t, "ns1", activations.References[0].Virtual.Namespace) + require.Equal(t, "test-module2", activations.References[0].Virtual.ModuleID) + require.Equal(t, "a", activations.References[0].Virtual.ActorID) require.True(t, activations.VersionStamp > prevVS) // Next 10 activations should all go to server2 for balancing purposes. 
@@ -145,7 +147,7 @@ func testRegistryServiceDiscoveryAndEnsureActivation(t *testing.T, registry Regi }) require.NoError(t, err) require.Equal(t, 1, len(activations.References)) - require.Equal(t, "server2", activations.References[0].ServerID()) + require.Equal(t, "server2", activations.References[0].Physical.ServerID) _, err = registry.Heartbeat(ctx, "server2", HeartbeatState{ NumActivatedActors: i + 1, @@ -168,16 +170,16 @@ func testRegistryServiceDiscoveryAndEnsureActivation(t *testing.T, registry Regi if lastServerID == "" { } else if lastServerID == "server1" { - require.Equal(t, "server2", activations.References[0].ServerID()) + require.Equal(t, "server2", activations.References[0].Physical.ServerID) } else { - require.Equal(t, "server1", activations.References[0].ServerID()) + require.Equal(t, "server1", activations.References[0].Physical.ServerID) } - _, err = registry.Heartbeat(ctx, activations.References[0].ServerID(), HeartbeatState{ + _, err = registry.Heartbeat(ctx, activations.References[0].Physical.ServerID, HeartbeatState{ NumActivatedActors: 10 + i + 1, - Address: fmt.Sprintf("%s_address", activations.References[0].ServerID()), + Address: fmt.Sprintf("%s_address", activations.References[0].Physical.ServerID), }) require.NoError(t, err) - lastServerID = activations.References[0].ServerID() + lastServerID = activations.References[0].Physical.ServerID } // Wait for server1's heartbeat to expire. @@ -204,6 +206,160 @@ func testRegistryServiceDiscoveryAndEnsureActivation(t *testing.T, registry Regi }) require.NoError(t, err) require.Equal(t, 1, len(activations.References)) - require.Equal(t, "server2", activations.References[0].ServerID()) + require.Equal(t, "server2", activations.References[0].Physical.ServerID) + } +} + +// testRegistryReplication is a test function that verifies the replication behavior of a registry implementation. +// The steps performed by this function are as follows: +// +// 1. 
Ensure Activation - Single Replica: +// The function calls the `EnsureActivation` method of the registry without requesting any additional replicas. +// This step verifies the successful activation of an actor with a single replica. The number of returned +// activation references is validated to ensure that only one reference is returned. +// +// 2. Ensure Activation - Extra Replica: +// The function calls the `EnsureActivation` method of the registry to ensure the activation of an actor with the +// given namespace, actor ID, and module ID. In this step, an additional replica is requested for the actor. +// The number of returned activation references is validated to ensure that two references are returned. +// +// 3. Ensure Activation - Extra Replicas: +// The function calls the `EnsureActivation` method once more, but this time requests two additional replicas +// for the actor. It's important to note that even though the function requests three replicas in total, the +// replication behavior is limited by the number of available servers. In this case, since there are only two +// servers ("server1" and "server2"), the maximum number of replicas that can be created is also limited to two. +// The purpose of this step is to test the successful activation of an actor with the maximum number of replicas +// that the available servers can accommodate. The number of returned activation references is validated to ensure +// that two references are returned, indicating that the replication behavior respects the available server +// resources and doesn't exceed the limit. +func testRegistryReplication(t *testing.T, registry Registry) { + ctx := context.Background() + defer registry.Close(ctx) + + for i := 0; i < 5; i++ { + // Heartbeat 5 times because some registry implementations (like the + // LeaderRegistry) require multiple successful heartbeats from at least + // 1 server before any actors can be placed. 
+ heartbeatResult, err := registry.Heartbeat(ctx, "server1", HeartbeatState{ + NumActivatedActors: 10, + Address: "server1_address", + }) + require.NoError(t, err) + require.True(t, heartbeatResult.VersionStamp > 0) + require.Equal(t, HeartbeatTTL.Microseconds(), heartbeatResult.HeartbeatTTL) + + heartbeatResult, err = registry.Heartbeat(ctx, "server2", HeartbeatState{ + NumActivatedActors: 10, + Address: "server2_address", + }) + require.NoError(t, err) + require.True(t, heartbeatResult.VersionStamp > 0) + require.Equal(t, HeartbeatTTL.Microseconds(), heartbeatResult.HeartbeatTTL) } + + activations, err := registry.EnsureActivation(ctx, EnsureActivationRequest{ + Namespace: "ns1", + ActorID: "a", + ModuleID: "test-module1", + }) + require.NoError(t, err) + require.Equal(t, 1, len(activations.References)) + + activations, err = registry.EnsureActivation(ctx, EnsureActivationRequest{ + Namespace: "ns1", + ActorID: "b", + ModuleID: "test-module1", + ExtraReplicas: 1, + }) + require.NoError(t, err) + require.Equal(t, 2, len(activations.References)) + + activations, err = registry.EnsureActivation(ctx, EnsureActivationRequest{ + Namespace: "ns1", + ActorID: "c", + ModuleID: "test-module1", + ExtraReplicas: 2, + }) + require.NoError(t, err) + require.Equal(t, 2, len(activations.References)) +} + +// The purpose of this test function is to verify the persistence of actor activations across consecutive calls to the EnsureActivation function. +// The logic of the test involves calling the EnsureActivation function every microsecond for 5 seconds and expecting to consistently receive +// the same actor reference (server) in return. +// +// The test is designed to check whether activations are persisted correctly, meaning that the same actor reference should be returned unless +// the server is blacklisted or goes down. 
It assumes that if activations are persisted, the EnsureActivation function will consistently return +// the same reference, unless exceptional circumstances such as server blacklisting or failure occur, which are not expected during the test. +func testEnsureActivationPersistence(t *testing.T, registry Registry) { + const testDuration = 5 * time.Second + + ctx, cc := context.WithCancel(context.Background()) + defer cc() + defer registry.Close(ctx) + + for i := 0; i < 5; i++ { + // Heartbeat 5 times because some registry implementations (like the + // LeaderRegistry) require multiple successful heartbeats from at least + // 1 server before any actors can be placed. + heartbeatResult, err := registry.Heartbeat(ctx, "server1", HeartbeatState{ + NumActivatedActors: 10, + Address: "server1_address", + }) + require.NoError(t, err) + require.True(t, heartbeatResult.VersionStamp > 0) + require.Equal(t, HeartbeatTTL.Microseconds(), heartbeatResult.HeartbeatTTL) + + heartbeatResult, err = registry.Heartbeat(ctx, "server2", HeartbeatState{ + NumActivatedActors: 10, + Address: "server2_address", + }) + require.NoError(t, err) + require.True(t, heartbeatResult.VersionStamp > 0) + require.Equal(t, HeartbeatTTL.Microseconds(), heartbeatResult.HeartbeatTTL) + } + + go func() { + // This goroutine simulates heartbeats while the test is running. 
+ for ctx.Err() == nil { + // Perform a heartbeat for "server1" + registry.Heartbeat(ctx, "server1", HeartbeatState{ + NumActivatedActors: 10, + Address: "server1_address", + }) + + // Perform a heartbeat for "server2" + registry.Heartbeat(ctx, "server2", HeartbeatState{ + NumActivatedActors: 10, + Address: "server2_address", + }) + + // Wait for HeartbeatTTL / 2 before sending the next heartbeat + select { + case <-time.After(HeartbeatTTL / 2): + case <-ctx.Done(): + return + } + } + }() + + var ref types.ActorReference + require.Never(t, func() bool { + activations, err := registry.EnsureActivation(ctx, EnsureActivationRequest{ + Namespace: "ns1", + ActorID: "a", + ModuleID: "test-module1", + }) + require.NoError(t, err) + require.Equal(t, 1, len(activations.References)) + differentActivation := !(ref == types.ActorReference{} || ref == activations.References[0]) + ref = activations.References[0] + return differentActivation + }, testDuration, time.Microsecond, "actor has been activated in more than one server") + + // Sleeping for a second is necessary to ensure that the last call to registry.EnsureActivation + // finishes executing. This is important because the condition is called asynchronously, and + // there is a possibility of encountering an error if the registry has been closed before + // completion. + time.Sleep(time.Second) } diff --git a/virtual/registry/types.go b/virtual/registry/types.go index aaa771f..715d76d 100644 --- a/virtual/registry/types.go +++ b/virtual/registry/types.go @@ -116,7 +116,12 @@ type EnsureActivationRequest struct { ModuleID string `json:"module_id"` ActorID string `json:"actor_id"` - // BlacklistedServerID is set if the caller is calling the EnsureActivation method + // ExtraReplicas represents the number of additional replicas requested for an actor. + // It specifies the desired number of replicas, in addition to the primary replica, + // that should be created during actor activation. 
+ // The value of ExtraReplicas should be a non-negative integer. + ExtraReplicas uint64 `json:"extra_replicas"` + // BlacklistedServerIDs is set if the caller is calling the EnsureActivation method // after receiving an error from the server the actor is *supposed* to be activated // on that the server has blacklisted the actor. The server may blacklist the actor // temporarily due to excessive resource consumption and/or to accomplish balancing @@ -124,8 +129,8 @@ type EnsureActivationRequest struct { // the ID of the server that the actor was blacklisted on so the registry can keep // track of that information and ensure the actor is activated elsewhere / balanced // properly. - BlacklistedServerID string `json:"blacklisted_server_id"` - CachedActivationServerIDs []string `json:"cached_activation_server_id"` + BlacklistedServerIDs []string `json:"blacklisted_server_ids"` + CachedActivationServerIDs []string `json:"cached_activation_server_ids"` } // EnsureActivationResult contains the result of invoking the EnsureActivation method. diff --git a/virtual/types/ref.go b/virtual/types/ref.go index 2fc075f..224213b 100644 --- a/virtual/types/ref.go +++ b/virtual/types/ref.go @@ -6,41 +6,36 @@ import ( "fmt" ) -type actorRef struct { - virtualRef *virtualRef - serverID string - serverVersion int64 - address string -} - // NewActorReference creates an ActorReference. 
func NewActorReference( serverID string, serverVersion int64, - address string, namespace string, moduleID string, actorID string, generation uint64, + serverState ServerState, ) (ActorReference, error) { virtual, err := NewVirtualActorReference(namespace, moduleID, actorID, generation) if err != nil { - return nil, fmt.Errorf("NewActorReference: error creating new virtual reference: %w", err) + return ActorReference{}, fmt.Errorf("NewActorReference: error creating new virtual reference: %w", err) } if serverID == "" { - return nil, errors.New("serverID cannot be empty") + return ActorReference{}, errors.New("serverID cannot be empty") } - if address == "" { - return nil, errors.New("address cannot be empty") + if serverState.Address == "" { + return ActorReference{}, errors.New("address cannot be empty") } - vr := virtual.(virtualRef) - return &actorRef{ - virtualRef: &vr, - serverID: serverID, - serverVersion: serverVersion, - address: address, + return ActorReference{ + Virtual: virtual, + Physical: ActorReferencePhysical{ + ServerID: serverID, + ServerVersion: serverVersion, + ServerState: serverState, + }, + Type: ReferenceTypeLocal, }, nil } @@ -48,83 +43,78 @@ func NewActorReference( // the in-memory representation of the ActorReference that was previously marshaled // by calling MarshalJSON on the ActorReference. 
func NewActorReferenceFromJSON(data []byte) (ActorReference, error) { - var serializable serializableActorRef - if err := json.Unmarshal(data, &serializable); err != nil { - return nil, err + var ref ActorReference + if err := json.Unmarshal(data, &ref); err != nil { + return ActorReference{}, fmt.Errorf("error creating new actor reference after JSON unmarshal: %w", err) } - ref, err := NewActorReference( - serializable.ServerID, - serializable.ServerVersion, - serializable.Address, - serializable.Namespace, - serializable.ModuleID, - serializable.ActorID, - serializable.Generation) - if err != nil { - return nil, fmt.Errorf("error creating new actor reference after JSON unmarshal: %w", err) - } - ref.(*actorRef).virtualRef.idType = serializable.IDType + ref.Type = ReferenceTypeLocal return ref, nil } -func (l actorRef) Type() ReferenceType { - return ReferenceTypeLocal +// ActorReference abstracts over different forms of ReferenceType. It provides all the +// necessary information for communicating with an actor. Some of the fields are "logical" (see Virtual) and the others are "physical" (see Physical). +type ActorReference struct { + Type ReferenceType `json:"-"` + Virtual ActorReferenceVirtual `json:"virtual"` + Physical ActorReferencePhysical `json:"physical"` } -func (l actorRef) ServerID() string { - return l.serverID -} - -func (l actorRef) ServerVersion() int64 { - return l.serverVersion -} - -func (l actorRef) Namespace() string { - return l.virtualRef.Namespace() -} +// ActorReferenceVirtual is the subset of data in ActorReference that is "virtual" and has +// nothing to do with the physical location of the actor's activation. The virtual fields +// are all that is required for the Registry to resolve a physical reference. +type ActorReferenceVirtual struct { + // Namespace is the namespace to which this ActorReference belongs. + Namespace string `json:"namespace"` + // ModuleID is the ID of the WASM module that this actor is instantiated from. 
+ ModuleID string `json:"module_id"` + // The ID of the referenced actor. + ActorID string `json:"actor_id"` + // Generation represents the generation count for the actor's activation. This value + // may be bumped by the registry at any time to signal to the rest of the system that + // all outstanding activations should be recreated for whatever reason. + Generation uint64 `json:"generation"` -func (l actorRef) ActorID() NamespacedActorID { - return l.virtualRef.ActorID() -} + // IDType allows us to ensure that an actor and a worker with the + // same (namespace, moduleID, actorID) tuple are still + // namespaced away from each other in any in-memory datastructures. + IDType string `json:"id_type"` -func (l actorRef) ModuleID() NamespacedID { - return l.virtualRef.ModuleID() + // Buffers for Namespaced ActorID and ModuleID + actorIDWithNamespace NamespacedActorID `json:"-"` + moduleIDWithNamespace NamespacedID `json:"-"` } -func (l actorRef) Address() string { - return l.address +func (ref *ActorReferenceVirtual) ActorIDWithNamespace() NamespacedActorID { + if ref.actorIDWithNamespace.ID == "" { + ref.actorIDWithNamespace = NewNamespacedActorID(ref.Namespace, ref.ActorID, ref.ModuleID, ref.IDType) + } + return ref.actorIDWithNamespace } -func (l actorRef) Generation() uint64 { - return l.virtualRef.Generation() +func (ref *ActorReferenceVirtual) ModuleIDWithNamespace() NamespacedID { + if ref.moduleIDWithNamespace.ID == "" { + ref.moduleIDWithNamespace = NewNamespacedID(ref.Namespace, ref.ModuleID, ref.IDType) + } + return ref.moduleIDWithNamespace } -func (l actorRef) MarshalJSON() ([]byte, error) { - // This is terrible, I'm sorry. 
- return json.Marshal(&serializableActorRef{ - Namespace: l.Namespace(), - ModuleID: l.virtualRef.ModuleID().ID, - ActorID: l.virtualRef.ActorID().ID, - Generation: l.virtualRef.Generation(), - IDType: l.virtualRef.idType, - ServerID: l.serverID, - ServerVersion: l.serverVersion, - Address: l.address, - }) +// ActorReferencePhysical is the subset of data in ActorReference that is "physical" and +// that is used to actually find and communicate with the actor's current activation. +type ActorReferencePhysical struct { + // ServerID is the ID of the physical server that this reference targets. + ServerID string `json:"server_id"` + // ServerVersion is incremented every time a server's heartbeat expires and resumes, + // guaranteeing the server's ability to identify periods of inactivity/death for correctness purposes. + ServerVersion int64 `json:"server_version"` + + // The state of the physical server that this reference targets. + // Contains information that is sent in the heartbeat. + ServerState ServerState `json:"server_state"` } -type serializableActorRef struct { - Namespace string `json:"namespace"` - ModuleID string `json:"moduleID"` - ActorID string `json:"actorID"` - Generation uint64 `json:"generation"` - // idType allows us to ensure that an actor and a worker with the - // same tuple of are still - // namespaced away from each other in any in-memory datastructures. - IDType string `json:"idType"` - ServerID string `json:"server_id"` - ServerVersion int64 `json:"server_version"` - Address string `json:"address"` +type ServerState struct { + // Address is the address at which the server can be reached. 
+ Address string `json:"address"` } diff --git a/virtual/types/ref_test.go b/virtual/types/ref_test.go index 7d2942d..6635ae7 100644 --- a/virtual/types/ref_test.go +++ b/virtual/types/ref_test.go @@ -1,26 +1,25 @@ package types import ( + "encoding/json" "testing" "github.com/stretchr/testify/require" ) func TestNewActorReference(t *testing.T) { - ref, err := NewActorReference("server1", 0, "server1path", "a", "b", "c", 1) + ref, err := NewActorReference("server1", 0, "a", "b", "c", 1, ServerState{Address: "server1path"}) require.NoError(t, err) - require.Equal(t, "server1", ref.ServerID()) - require.Equal(t, "server1path", ref.Address()) - require.Equal(t, "a", ref.Namespace()) - require.Equal(t, "a", ref.ActorID().Namespace) - require.Equal(t, "c", ref.ActorID().ID) - require.Equal(t, "a", ref.ModuleID().Namespace) - require.Equal(t, "b", ref.ModuleID().ID) - require.Equal(t, uint64(1), ref.Generation()) - require.Equal(t, IDTypeActor, ref.ActorID().IDType) - - marshaled, err := ref.MarshalJSON() + require.Equal(t, "server1", ref.Physical.ServerID) + require.Equal(t, "server1path", ref.Physical.ServerState.Address) + require.Equal(t, "a", ref.Virtual.Namespace) + require.Equal(t, "c", ref.Virtual.ActorID) + require.Equal(t, "b", ref.Virtual.ModuleID) + require.Equal(t, uint64(1), ref.Virtual.Generation) + require.Equal(t, IDTypeActor, ref.Virtual.IDType) + + marshaled, err := json.Marshal(ref) require.NoError(t, err) unmarshaled, err := NewActorReferenceFromJSON(marshaled) @@ -30,21 +29,19 @@ func TestNewActorReference(t *testing.T) { } func TestNewWorkerReference(t *testing.T) { - ref, err := NewActorReference("server1", 0, "server1path", "a", "b", "c", 1) + ref, err := NewActorReference("server1", 0, "a", "b", "c", 1, ServerState{Address: "server1path"}) require.NoError(t, err) - ref.(*actorRef).virtualRef.idType = IDTypeWorker - - require.Equal(t, "server1", ref.ServerID()) - require.Equal(t, "server1path", ref.Address()) - require.Equal(t, "a", 
ref.Namespace()) - require.Equal(t, "a", ref.ActorID().Namespace) - require.Equal(t, "c", ref.ActorID().ID) - require.Equal(t, "a", ref.ModuleID().Namespace) - require.Equal(t, "b", ref.ModuleID().ID) - require.Equal(t, uint64(1), ref.Generation()) - require.Equal(t, IDTypeWorker, ref.ActorID().IDType) - - marshaled, err := ref.MarshalJSON() + ref.Virtual.IDType = IDTypeWorker + + require.Equal(t, "server1", ref.Physical.ServerID) + require.Equal(t, "server1path", ref.Physical.ServerState.Address) + require.Equal(t, "a", ref.Virtual.Namespace) + require.Equal(t, "b", ref.Virtual.ModuleID) + require.Equal(t, "c", ref.Virtual.ActorID) + require.Equal(t, uint64(1), ref.Virtual.Generation) + require.Equal(t, IDTypeWorker, ref.Virtual.IDType) + + marshaled, err := json.Marshal(ref) require.NoError(t, err) unmarshaled, err := NewActorReferenceFromJSON(marshaled) diff --git a/virtual/types/req.go b/virtual/types/req.go index 580bb4d..d232b4d 100644 --- a/virtual/types/req.go +++ b/virtual/types/req.go @@ -29,4 +29,9 @@ type CreateIfNotExist struct { // ActorOptions contains the options for a given actor. type ActorOptions struct { + // ExtraReplicas represents the number of additional replicas requested for an actor. + // It specifies the desired number of replicas, in addition to the primary replica, + // that should be created during actor activation. + // The value of ExtraReplicas should be a non-negative integer. + ExtraReplicas uint64 `json:"extra_replicas"` } diff --git a/virtual/types/types.go b/virtual/types/types.go index 49ba46d..97da1fd 100644 --- a/virtual/types/types.go +++ b/virtual/types/types.go @@ -1,7 +1,5 @@ package types -import "encoding/json" - // ReferenceType is an enum type that indicates what the underlying type of Reference is, // see the different ReferenceType's below. 
type ReferenceType string @@ -15,39 +13,14 @@ const ( ReferenceTypeRemoteHTTP ReferenceType = "remote-http" ) -// ActorReference abstracts over different forms of ReferenceType. It provides all the -// necessary information for communicating with an actor. Some of the fields are "logical" -type ActorReference interface { - json.Marshaler - - ActorReferenceVirtual - ActorReferencePhysical -} - -// ActorReferenceVirtual is the subset of data in ActorReference that is "virtual" and has -// nothing to do with the physical location of the actor's activation. The virtual fields -// are all that is required for the Registry to resolve a physical reference. -type ActorReferenceVirtual interface { - // Namespace is the namespace to which this ActorReference belongs. - Namespace() string - // ModuleID is the ID of the WASM module that this actor is instantiated from. - ModuleID() NamespacedID - // The ID of the referenced actor. - ActorID() NamespacedActorID - // Generation represents the generation count for the actor's activation. This value - // may be bumped by the registry at any time to signal to the rest of the system that - // all outstanding activations should be recreated for whatever reason. - Generation() uint64 -} +func StringSliceToSet(slice []string) map[string]bool { + if slice == nil { + return make(map[string]bool) + } -// ActorReferencePhysical is the subset of data in ActorReference that is "physical" and -// that is used to actually find and communicate with the actor's current activation. -type ActorReferencePhysical interface { - // ServerID is the ID of the physical server that this reference targets. - ServerID() string - // The address of the referenced actor. - Address() string - // ServerVersion is incremented every time a server's heartbeat expires and resumes, - // guaranteeing the server's ability to identify periods of inactivity/death for correctness purposes. 
- ServerVersion() int64 + set := make(map[string]bool, len(slice)) + for _, item := range slice { + set[item] = true + } + return set } diff --git a/virtual/types/virtual_ref.go b/virtual/types/virtual_ref.go index 12dd17f..739043d 100644 --- a/virtual/types/virtual_ref.go +++ b/virtual/types/virtual_ref.go @@ -10,17 +10,6 @@ const ( IDTypeWorker = "worker" ) -type virtualRef struct { - namespace string - moduleID string - actorID string - generation uint64 - // idType allows us to ensure that an actor and a worker with the - // same tuple of are still - // namespaced away from each other in any in-memory datastructures. - idType string -} - // func NewVirtualWorkerReference creates a new VirtualActorReference // for a given worker. func NewVirtualWorkerReference( @@ -56,39 +45,23 @@ func newVirtualActorReference( idType string, ) (ActorReferenceVirtual, error) { if namespace == "" { - return nil, errors.New("namespace cannot be empty") + return ActorReferenceVirtual{}, errors.New("namespace cannot be empty") } if moduleID == "" { - return nil, errors.New("moduleID cannot be empty") + return ActorReferenceVirtual{}, errors.New("moduleID cannot be empty") } if actorID == "" { - return nil, errors.New("moduleID cannot be empty") + return ActorReferenceVirtual{}, errors.New("moduleID cannot be empty") } if generation <= 0 { - return nil, errors.New("generation must be >0") + return ActorReferenceVirtual{}, errors.New("generation must be >0") } - return virtualRef{ - namespace: namespace, - moduleID: moduleID, - actorID: actorID, - generation: generation, - idType: idType, + return ActorReferenceVirtual{ + Namespace: namespace, + ModuleID: moduleID, + ActorID: actorID, + Generation: generation, + IDType: idType, }, nil } - -func (l virtualRef) Namespace() string { - return l.namespace -} - -func (l virtualRef) ActorID() NamespacedActorID { - return NewNamespacedActorID(l.namespace, l.actorID, l.moduleID, l.idType) -} - -func (l virtualRef) ModuleID() NamespacedID { - 
return NewNamespacedID(l.namespace, l.moduleID, l.idType) -} - -func (l virtualRef) Generation() uint64 { - return l.generation -} diff --git a/virtual/wazero.go b/virtual/wazero.go index c74f3d0..d2b2801 100644 --- a/virtual/wazero.go +++ b/virtual/wazero.go @@ -108,11 +108,11 @@ func newHostFnRouter( func extractActorRef(ctx context.Context) (types.ActorReferenceVirtual, error) { actorRefIface := ctx.Value(hostFnActorReferenceCtxKey{}) if actorRefIface == nil { - return nil, fmt.Errorf("wazeroHostFnRouter: could not find non-empty actor reference in context") + return types.ActorReferenceVirtual{}, fmt.Errorf("wazeroHostFnRouter: could not find non-empty actor reference in context") } actorRef, ok := actorRefIface.(types.ActorReferenceVirtual) if !ok { - return nil, fmt.Errorf("wazeroHostFnRouter: wrong type for actor reference in context: %T", actorRef) + return types.ActorReferenceVirtual{}, fmt.Errorf("wazeroHostFnRouter: wrong type for actor reference in context: %T", actorRef) } return actorRef, nil } @@ -127,7 +127,7 @@ func (w wazeroModule) Instantiate( instantiatePayload []byte, host HostCapabilities, ) (Actor, error) { - obj, err := w.m.Instantiate(ctx, reference.ActorID().ID) + obj, err := w.m.Instantiate(ctx, reference.ActorID) if err != nil { return nil, err }